Repository: tky823/ssspy Branch: main Commit: 38b9389e8b19 Files: 270 Total size: 1.3 MB Directory structure: gitextract_c37qrstt/ ├── .github/ │ ├── PULL_REQUEST_TEMPLATE.md │ ├── release.yaml │ └── workflows/ │ ├── lint.yaml │ ├── test_docs.yaml │ ├── test_package_macos-13.yaml │ ├── test_package_macos-13_python-3.10.yaml │ ├── test_package_macos-13_python-3.11.yaml │ ├── test_package_macos-13_python-3.12.yaml │ ├── test_package_macos-13_python-3.8.yaml │ ├── test_package_macos-13_python-3.9.yaml │ ├── test_package_macos-latest.yaml │ ├── test_package_macos-latest_python-3.10.yaml │ ├── test_package_macos-latest_python-3.11.yaml │ ├── test_package_macos-latest_python-3.12.yaml │ ├── test_package_macos-latest_python-3.8.yaml │ ├── test_package_macos-latest_python-3.9.yaml │ ├── test_package_main.yaml │ ├── test_package_ubuntu-latest.yaml │ ├── test_package_ubuntu-latest_python-3.10.yaml │ ├── test_package_ubuntu-latest_python-3.11.yaml │ ├── test_package_ubuntu-latest_python-3.12.yaml │ ├── test_package_ubuntu-latest_python-3.8.yaml │ ├── test_package_ubuntu-latest_python-3.9.yaml │ ├── test_package_windows-latest.yaml │ ├── test_package_windows-latest_python-3.10.yaml │ ├── test_package_windows-latest_python-3.11.yaml │ ├── test_package_windows-latest_python-3.12.yaml │ ├── test_package_windows-latest_python-3.8.yaml │ ├── test_package_windows-latest_python-3.9.yaml │ └── upload_package.yaml ├── .gitignore ├── .readthedocs.yaml ├── CHANGELOG.rst ├── LICENSE ├── MANIFEST.in ├── README.md ├── codecov.yaml ├── docs/ │ ├── Makefile │ ├── api.rst │ ├── changelog.rst │ ├── conf.py │ ├── index.rst │ ├── make.bat │ ├── pre_build.sh │ ├── ssspy.algorithm.rst │ ├── ssspy.bss.admmbss.rst │ ├── ssspy.bss.base.rst │ ├── ssspy.bss.cacgmm.rst │ ├── ssspy.bss.fdica.rst │ ├── ssspy.bss.hva.rst │ ├── ssspy.bss.ica.rst │ ├── ssspy.bss.ilrma.rst │ ├── ssspy.bss.iva.rst │ ├── ssspy.bss.mnmf.rst │ ├── ssspy.bss.pdsbss.rst │ ├── ssspy.bss.proxbss.rst │ ├── ssspy.bss.rst │ ├── 
ssspy.linalg.rst │ ├── ssspy.special.rst │ └── ssspy.transform.rst ├── notebooks/ │ ├── BSS/ │ │ ├── ADMMBSS/ │ │ │ ├── ADMMBSS.ipynb │ │ │ └── ADMMBSS_multi-penalty.ipynb │ │ ├── CACGMM/ │ │ │ └── CACGMM.ipynb │ │ ├── FDICA/ │ │ │ ├── AuxFDICA-IP1.ipynb │ │ │ ├── AuxFDICA-IP2.ipynb │ │ │ ├── AuxLaplaceFDICA-IP1.ipynb │ │ │ ├── AuxLaplaceFDICA-IP2.ipynb │ │ │ ├── GradFDICA.ipynb │ │ │ ├── GradLaplaceFDICA.ipynb │ │ │ ├── NaturalGradFDICA.ipynb │ │ │ └── NaturalGradLaplaceFDICA.ipynb │ │ ├── HVA/ │ │ │ ├── ADMM-HVA.ipynb │ │ │ └── HVA.ipynb │ │ ├── ICA/ │ │ │ ├── FastICA.ipynb │ │ │ ├── GradICA.ipynb │ │ │ └── NaturalGradICA.ipynb │ │ ├── ILRMA/ │ │ │ ├── GGDILRMA-IP1-MM.ipynb │ │ │ ├── GGDILRMA-IP2-MM.ipynb │ │ │ ├── GGDILRMA-ISS1-MM.ipynb │ │ │ ├── GGDILRMA-ISS2-MM.ipynb │ │ │ ├── GaussILRMA-IP1-ME.ipynb │ │ │ ├── GaussILRMA-IP1-MM.ipynb │ │ │ ├── GaussILRMA-IP2-ME.ipynb │ │ │ ├── GaussILRMA-IP2-MM.ipynb │ │ │ ├── GaussILRMA-IPA-ME.ipynb │ │ │ ├── GaussILRMA-IPA-MM.ipynb │ │ │ ├── GaussILRMA-ISS1-ME.ipynb │ │ │ ├── GaussILRMA-ISS1-MM.ipynb │ │ │ ├── GaussILRMA-ISS2-ME.ipynb │ │ │ ├── GaussILRMA-ISS2-MM.ipynb │ │ │ ├── TILRMA-IP1-ME.ipynb │ │ │ ├── TILRMA-IP1-MM.ipynb │ │ │ ├── TILRMA-IP2-ME.ipynb │ │ │ ├── TILRMA-IP2-MM.ipynb │ │ │ ├── TILRMA-ISS1-ME.ipynb │ │ │ ├── TILRMA-ISS1-MM.ipynb │ │ │ ├── TILRMA-ISS2-ME.ipynb │ │ │ └── TILRMA-ISS2-MM.ipynb │ │ ├── IPSDTA/ │ │ │ ├── GaussIPSDTA-VCD.ipynb │ │ │ └── TIPSDTA-VCD.ipynb │ │ ├── IVA/ │ │ │ ├── AuxGaussIVA-IP1.ipynb │ │ │ ├── AuxGaussIVA-IP2.ipynb │ │ │ ├── AuxGaussIVA-IPA.ipynb │ │ │ ├── AuxGaussIVA-ISS1.ipynb │ │ │ ├── AuxGaussIVA-ISS2.ipynb │ │ │ ├── AuxIVA-IP1.ipynb │ │ │ ├── AuxIVA-IP2.ipynb │ │ │ ├── AuxIVA-IPA.ipynb │ │ │ ├── AuxIVA-ISS1.ipynb │ │ │ ├── AuxIVA-ISS2.ipynb │ │ │ ├── AuxLaplaceIVA-IP1.ipynb │ │ │ ├── AuxLaplaceIVA-IP2.ipynb │ │ │ ├── AuxLaplaceIVA-IPA.ipynb │ │ │ ├── AuxLaplaceIVA-ISS1.ipynb │ │ │ ├── AuxLaplaceIVA-ISS2.ipynb │ │ │ ├── FastIVA.ipynb │ │ │ ├── FasterIVA.ipynb │ │ │ ├── 
GradGaussIVA.ipynb │ │ │ ├── GradIVA.ipynb │ │ │ ├── GradLaplaceIVA.ipynb │ │ │ ├── NaturalGradGaussIVA.ipynb │ │ │ ├── NaturalGradIVA.ipynb │ │ │ └── NaturalGradLaplaceIVA.ipynb │ │ ├── MNMF/ │ │ │ ├── FastGaussMNMF-IP1.ipynb │ │ │ ├── FastGaussMNMF-IP2.ipynb │ │ │ └── GaussMNMF.ipynb │ │ └── PDSBSS/ │ │ ├── PDSBSS.ipynb │ │ ├── PDSBSS_masking.ipynb │ │ └── PDSBSS_multi-penalty.ipynb │ └── Examples/ │ └── Getting-Started.ipynb ├── pyproject.toml ├── ssspy/ │ ├── __init__.py │ ├── algorithm/ │ │ ├── __init__.py │ │ ├── minimal_distortion_principle.py │ │ ├── permutation_alignment.py │ │ └── projection_back.py │ ├── bss/ │ │ ├── __init__.py │ │ ├── _flooring.py │ │ ├── _psd.py │ │ ├── _select_pair.py │ │ ├── _solve_permutation.py │ │ ├── _update_spatial_model.py │ │ ├── admmbss.py │ │ ├── base.py │ │ ├── cacgmm.py │ │ ├── fdica.py │ │ ├── hva.py │ │ ├── ica.py │ │ ├── ilrma.py │ │ ├── ipsdta.py │ │ ├── iva.py │ │ ├── mnmf.py │ │ ├── pdsbss.py │ │ └── proxbss.py │ ├── io/ │ │ └── __init__.py │ ├── linalg/ │ │ ├── __init__.py │ │ ├── _solve.py │ │ ├── cubic.py │ │ ├── eigh.py │ │ ├── inv.py │ │ ├── lqpqm.py │ │ ├── mean.py │ │ ├── polynomial.py │ │ ├── prox.py │ │ ├── quadratic.py │ │ └── sqrtm.py │ ├── special/ │ │ ├── __init__.py │ │ ├── flooring.py │ │ ├── logsumexp.py │ │ ├── psd.py │ │ └── softmax.py │ ├── transform/ │ │ ├── __init__.py │ │ ├── pca.py │ │ └── whiten.py │ └── utils/ │ ├── __init__.py │ ├── dataset/ │ │ ├── __init__.py │ │ ├── mird.py │ │ └── sisec2010.py │ ├── flooring.py │ └── select_pair.py └── tests/ ├── conftest.py ├── dummy/ │ ├── callback.py │ ├── io.py │ └── utils/ │ └── dataset/ │ └── __init__.py ├── mock/ │ └── regression/ │ └── bss/ │ ├── cacgmm/ │ │ └── url.json │ ├── fdica/ │ │ ├── aux_laplace_fdica/ │ │ │ ├── IP1/ │ │ │ │ └── url.json │ │ │ └── IP2/ │ │ │ └── url.json │ │ ├── grad_laplace_fdica/ │ │ │ ├── holonomic/ │ │ │ │ └── url.json │ │ │ └── nonholonomic/ │ │ │ └── url.json │ │ └── natural_grad_laplace_fdica/ │ │ ├── holonomic/ │ 
│ │ └── url.json │ │ └── nonholonomic/ │ │ └── url.json │ ├── ilrma/ │ │ ├── gauss_ilrma/ │ │ │ ├── IP1/ │ │ │ │ ├── ME/ │ │ │ │ │ └── url.json │ │ │ │ └── MM/ │ │ │ │ └── url.json │ │ │ ├── IP2/ │ │ │ │ ├── ME/ │ │ │ │ │ └── url.json │ │ │ │ └── MM/ │ │ │ │ └── url.json │ │ │ ├── IPA/ │ │ │ │ ├── ME/ │ │ │ │ │ └── url.json │ │ │ │ └── MM/ │ │ │ │ └── url.json │ │ │ ├── ISS1/ │ │ │ │ ├── ME/ │ │ │ │ │ └── url.json │ │ │ │ └── MM/ │ │ │ │ └── url.json │ │ │ └── ISS2/ │ │ │ ├── ME/ │ │ │ │ └── url.json │ │ │ └── MM/ │ │ │ └── url.json │ │ ├── ggd_ilrma/ │ │ │ ├── IP1/ │ │ │ │ └── MM/ │ │ │ │ └── url.json │ │ │ ├── IP2/ │ │ │ │ └── MM/ │ │ │ │ └── url.json │ │ │ ├── ISS1/ │ │ │ │ └── MM/ │ │ │ │ └── url.json │ │ │ └── ISS2/ │ │ │ └── MM/ │ │ │ └── url.json │ │ └── t_ilrma/ │ │ ├── IP1/ │ │ │ ├── ME/ │ │ │ │ └── url.json │ │ │ └── MM/ │ │ │ └── url.json │ │ ├── IP2/ │ │ │ ├── ME/ │ │ │ │ └── url.json │ │ │ └── MM/ │ │ │ └── url.json │ │ ├── ISS1/ │ │ │ ├── ME/ │ │ │ │ └── url.json │ │ │ └── MM/ │ │ │ └── url.json │ │ └── ISS2/ │ │ ├── ME/ │ │ │ └── url.json │ │ └── MM/ │ │ └── url.json │ ├── ipsdta/ │ │ ├── gauss_ipsdta/ │ │ │ └── VCD/ │ │ │ └── MM/ │ │ │ └── url.json │ │ └── t_ipsdta/ │ │ └── VCD/ │ │ └── MM/ │ │ └── url.json │ ├── iva/ │ │ ├── aux_iva/ │ │ │ ├── IP1/ │ │ │ │ └── url.json │ │ │ ├── IP2/ │ │ │ │ └── url.json │ │ │ ├── IPA/ │ │ │ │ └── url.json │ │ │ ├── ISS1/ │ │ │ │ └── url.json │ │ │ └── ISS2/ │ │ │ └── url.json │ │ ├── fast_iva/ │ │ │ └── url.json │ │ ├── grad_iva/ │ │ │ ├── holonomic/ │ │ │ │ └── url.json │ │ │ └── nonholonomic/ │ │ │ └── url.json │ │ └── natural_grad_iva/ │ │ ├── holonomic/ │ │ │ └── url.json │ │ └── nonholonomic/ │ │ └── url.json │ └── mnmf/ │ ├── fast_gauss_mnmf/ │ │ ├── IP1/ │ │ │ └── url.json │ │ └── IP2/ │ │ └── url.json │ └── gauss_mnmf/ │ └── url.json ├── package/ │ ├── algorithm/ │ │ ├── test_minimal_distortion_principle.py │ │ ├── test_permutation_alignment.py │ │ └── test_projection_back.py │ ├── bss/ │ │ ├── 
test_admmbss.py │ │ ├── test_base.py │ │ ├── test_cacgmm.py │ │ ├── test_fdica.py │ │ ├── test_hva.py │ │ ├── test_ica.py │ │ ├── test_ilrma.py │ │ ├── test_ipsdta.py │ │ ├── test_iterative_methods.py │ │ ├── test_iva.py │ │ ├── test_mnmf.py │ │ ├── test_pair_selector.py │ │ ├── test_pdsbss.py │ │ ├── test_proxbss.py │ │ ├── test_psd_legacy.py │ │ ├── test_solve_permutation.py │ │ └── test_update_spatial_model.py │ ├── io/ │ │ └── test_wavread.py │ ├── linalg/ │ │ ├── test_cubic.py │ │ ├── test_eigh.py │ │ ├── test_gmean.py │ │ ├── test_inv.py │ │ ├── test_lqpqm.py │ │ ├── test_polynomial.py │ │ └── test_sqrtm.py │ ├── special/ │ │ ├── test_logsumexp.py │ │ ├── test_psd.py │ │ └── test_softmax.py │ ├── transform/ │ │ ├── test_pca.py │ │ └── test_whiten.py │ └── utils/ │ ├── test_dataset.py │ └── test_select_pair.py ├── regression/ │ └── bss/ │ ├── test_cacgmm.py │ ├── test_fdica.py │ ├── test_ilrma.py │ ├── test_ipsdta.py │ ├── test_iva.py │ └── test_mnmf.py └── scripts/ └── download_all.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ ## Summary ================================================ FILE: .github/release.yaml ================================================ changelog: categories: - title: Breaking Changes 🛠 labels: - breaking changes - title: New Features 🎉 labels: - new feature - title: Bug Fixes 🐛 labels: - bug - bug fix - title: Notebooks labels: - notebooks - title: Other Changes labels: - "*" ================================================ FILE: .github/workflows/lint.yaml ================================================ name: lint on: push: branches: - main pull_request: branches: - main jobs: lint: name: Run linters runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v4 - name: Set up Python 3.12 uses: actions/setup-python@v4 with: 
python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip pip install ".[dev]" - name: Run linters run: | # See pyproject.toml isort --line-length 100 ssspy tests flake8 --max-line-length=100 --ignore=E203,W503,W504 --exclude ssspy/_version.py ssspy tests - name: Run formatters run: | python -m black --config pyproject.toml --check ssspy tests ================================================ FILE: .github/workflows/test_docs.yaml ================================================ name: tests for docs on: push: branches: - main pull_request: branches: - main jobs: build: name: Build docs runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v4 - name: Set up Python 3.12 uses: actions/setup-python@v4 with: python-version: "3.12" - name: Install dependencies run: | sudo apt-get update sudo apt-get install pandoc python -m pip install --upgrade pip pip install ".[docs,notebooks]" - name: Build docs run: | . ./docs/pre_build.sh cd docs/ sphinx-build -W ./ ./_build/html/ ================================================ FILE: .github/workflows/test_package_macos-13.yaml ================================================ name: macos-13 on: workflow_call: inputs: python-version: required: true type: string secrets: CODECOV_TOKEN: required: true TEST_PYPI_API_TOKEN: required: true jobs: package: uses: ./.github/workflows/test_package_main.yaml with: # macos-13: x86_64, macos-latest: arm # See https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners os: macos-13 python-version: ${{ inputs.python-version }} secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-13_python-3.10.yaml ================================================ name: macos-13/3.10 on: push: branches: - main pull_request: 
branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-13.yaml with: python-version: "3.10" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-13_python-3.11.yaml ================================================ name: macos-13/3.11 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-13.yaml with: python-version: "3.11" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-13_python-3.12.yaml ================================================ name: macos-13/3.12 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-13.yaml with: python-version: "3.12" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-13_python-3.8.yaml ================================================ name: macos-13/3.8 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-13.yaml with: python-version: "3.8" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-13_python-3.9.yaml ================================================ name: macos-13/3.9 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-13.yaml with: python-version: 
"3.9" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-latest.yaml ================================================ name: macos-latest on: workflow_call: inputs: python-version: required: true type: string secrets: CODECOV_TOKEN: required: true TEST_PYPI_API_TOKEN: required: true jobs: package: uses: ./.github/workflows/test_package_main.yaml with: # macos-13: x86_64, macos-latest: arm # See https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners os: macos-latest python-version: ${{ inputs.python-version }} secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-latest_python-3.10.yaml ================================================ name: macos-latest/3.10 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-latest.yaml with: python-version: "3.10" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-latest_python-3.11.yaml ================================================ name: macos-latest/3.11 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-latest.yaml with: python-version: "3.11" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-latest_python-3.12.yaml 
================================================ name: macos-latest/3.12 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-latest.yaml with: python-version: "3.12" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-latest_python-3.8.yaml ================================================ name: macos-latest/3.8 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-latest.yaml with: python-version: "3.8" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_macos-latest_python-3.9.yaml ================================================ name: macos-latest/3.9 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_macos-latest.yaml with: python-version: "3.9" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_main.yaml ================================================ name: test package on: workflow_call: inputs: os: required: true type: string python-version: required: true type: string secrets: CODECOV_TOKEN: required: true TEST_PYPI_API_TOKEN: required: true jobs: build: name: Run tests with pytest runs-on: ${{ inputs.os }} steps: - name: Checkout uses: actions/checkout@v4 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip 
pip install -e ".[dev,tests]" - name: Preparation of pytest run: | python tests/scripts/download_all.py - name: Pytest (run all tests including redundant ones) id: run_redundant if: startsWith(github.head_ref, 'release/') run: | pytest --run-redundant -vvv -n 16 --cov=./ssspy --cov-report=xml tests/package/ - name: Pytest (skip redundant tests) if: steps.run_redundant.conclusion == 'skipped' run: | pytest -vvv -n 16 --cov=./ssspy --cov-report=xml tests/package/ - name: Pytest (regression tests) run: | pytest -vvv -n 16 tests/regression/ - name: Upload coverage reports to Codecov if: inputs.python-version == '3.12' && inputs.os == 'ubuntu-latest' uses: codecov/codecov-action@v3 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: fail_ci_if_error: true upload_package: needs: - build permissions: id-token: write secrets: TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} if: github.event_name == 'pull_request' && inputs.python-version == '3.12' && inputs.os == 'ubuntu-latest' uses: ./.github/workflows/upload_package.yaml ================================================ FILE: .github/workflows/test_package_ubuntu-latest.yaml ================================================ name: ubuntu-latest on: workflow_call: inputs: python-version: required: true type: string secrets: CODECOV_TOKEN: required: true TEST_PYPI_API_TOKEN: required: true jobs: package: uses: ./.github/workflows/test_package_main.yaml with: os: ubuntu-latest python-version: ${{ inputs.python-version }} secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_ubuntu-latest_python-3.10.yaml ================================================ name: ubuntu-latest/3.10 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_ubuntu-latest.yaml with: python-version: "3.10" secrets: 
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_ubuntu-latest_python-3.11.yaml ================================================ name: ubuntu-latest/3.11 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_ubuntu-latest.yaml with: python-version: "3.11" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_ubuntu-latest_python-3.12.yaml ================================================ name: ubuntu-latest/3.12 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_ubuntu-latest.yaml with: python-version: "3.12" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_ubuntu-latest_python-3.8.yaml ================================================ name: ubuntu-latest/3.8 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_ubuntu-latest.yaml with: python-version: "3.8" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_ubuntu-latest_python-3.9.yaml ================================================ name: ubuntu-latest/3.9 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_ubuntu-latest.yaml with: python-version: "3.9" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 
TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_windows-latest.yaml ================================================ name: windows-latest on: workflow_call: inputs: python-version: required: true type: string secrets: CODECOV_TOKEN: required: true TEST_PYPI_API_TOKEN: required: true jobs: package: uses: ./.github/workflows/test_package_main.yaml with: os: windows-latest python-version: ${{ inputs.python-version }} secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_windows-latest_python-3.10.yaml ================================================ name: windows-latest/3.10 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_windows-latest.yaml with: python-version: "3.10" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_windows-latest_python-3.11.yaml ================================================ name: windows-latest/3.11 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_windows-latest.yaml with: python-version: "3.11" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_windows-latest_python-3.12.yaml ================================================ name: windows-latest/3.12 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_windows-latest.yaml with: 
python-version: "3.12" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_windows-latest_python-3.8.yaml ================================================ name: windows-latest/3.8 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_windows-latest.yaml with: python-version: "3.8" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/test_package_windows-latest_python-3.9.yaml ================================================ name: windows-latest/3.9 on: push: branches: - main pull_request: branches: - main jobs: package: uses: ./.github/workflows/test_package_windows-latest.yaml with: python-version: "3.9" secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} permissions: id-token: write ================================================ FILE: .github/workflows/upload_package.yaml ================================================ # based on # https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python # https://github.com/pypa/gh-action-pypi-publish # TODO: update this config for practical use name: Upload package to PyPI on: workflow_call: secrets: TEST_PYPI_API_TOKEN: required: true jobs: build: name: Build and upload package runs-on: ubuntu-latest permissions: id-token: write steps: - name: Checkout uses: actions/checkout@v4 with: # to retrieve tags fetch-depth: 0 - name: Set up Python 3.x uses: actions/setup-python@v4 with: python-version: '3.x' - name: Show git tags run: | git tag - name: Install dependencies run: | python -m pip install --upgrade pip pip install build wheel twine - name: Build run: 
| python -m build - name: Publish distribution to TestPyPI env: TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} run: | python -m twine upload --repository testpypi --username __token__ --password ${TEST_PYPI_API_TOKEN} dist/* ================================================ FILE: .gitignore ================================================ # For building docs docs/_notebooks/ # For local .data/ _version.py # For Mac .DS_Store # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ ================================================ FILE: .readthedocs.yaml ================================================ # .readthedocs.yaml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 # Set the version of Python and other tools you might need build: os: ubuntu-20.04 tools: python: "3.11" # You can also specify other tool versions: # nodejs: "16" # rust: "1.55" # golang: "1.17" jobs: pre_build: - . ./docs/pre_build.sh # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/conf.py # If using Sphinx, optionally build your docs in additional formats such as PDF # formats: # - pdf # Optionally declare the Python requirements required to build your docs # python: # install: # - method: pip # path: . 
# extra_requirements: # - docs # - notebooks ================================================ FILE: CHANGELOG.rst ================================================ Changelog ######### v0.2.0 ****** What's Changed ============== Breaking Changes 🛠 ------------------- * Rename `aux` to `auxiliary` by @tky823 in https://github.com/tky823/ssspy/pull/268 * Detailed build status by @tky823 in https://github.com/tky823/ssspy/pull/288 New Features 🎉 --------------- * Implementation of harmonic vector analysis by @tky823 in https://github.com/tky823/ssspy/pull/271 * Implementation of ADMM-HVA by @tky823 in https://github.com/tky823/ssspy/pull/281 Bug Fixes 🐛 ------------ * Fix test coverage by @tky823 in https://github.com/tky823/ssspy/pull/269 * Fix timing of uploading package by @tky823 in https://github.com/tky823/ssspy/pull/273 * Remove status badge of lint by @tky823 in https://github.com/tky823/ssspy/pull/274 Other Changes ------------- * Upload package to TestPyPI by @tky823 in https://github.com/tky823/ssspy/pull/267 * Remove duplicate uploads to TestPyPI by @tky823 in https://github.com/tky823/ssspy/pull/270 * Use flooring function to compute norm. by @tky823 in https://github.com/tky823/ssspy/pull/276 * Regression tests by @tky823 in https://github.com/tky823/ssspy/pull/238 * Add `needs` to upload_package job in GHA. by @tky823 in https://github.com/tky823/ssspy/pull/277 * Update actions/checkout in GitHub actions by @tky823 in https://github.com/tky823/ssspy/pull/279 * Hugging Face demo by @tky823 in https://github.com/tky823/ssspy/pull/282 * Set permissions in workflows by @tky823 in https://github.com/tky823/ssspy/pull/289 * Bump up version to 0.2.0 by @tky823 in https://github.com/tky823/ssspy/pull/290 **Full Changelog**: https://github.com/tky823/ssspy/compare/v0.1.7...v0.2.0 v0.1.7 ****** Summary ======= In this version, we improve the management of the package. As a new BSS method, ADMM-BSS is newly added. 
What's Changed ============== Breaking Changes 🛠 ------------------- * Include ssspy only as package by @tky823 in https://github.com/tky823/ssspy/pull/253 * Add ``MANIFEST.in`` by @tky823 in https://github.com/tky823/ssspy/pull/257 New Features 🎉 --------------- * Implementation of ADMM-IVA by @tky823 in https://github.com/tky823/ssspy/pull/263 * Support ADMM-BSS_multi-penalty by @tky823 in https://github.com/tky823/ssspy/pull/265 Bug Fixes 🐛 ------------ * Fix document deployment by @tky823 in https://github.com/tky823/ssspy/pull/255 * Update some variables depending on ``demix_filter`` instead of ``self.algorithm``. by @tky823 in https://github.com/tky823/ssspy/pull/260 Other Changes ------------- * Release notes by @tky823 in https://github.com/tky823/ssspy/pull/246 * Add label for breaking changes by @tky823 in https://github.com/tky823/ssspy/pull/247 * Notebooks/getting started by @tky823 in https://github.com/tky823/ssspy/pull/248 * Update docs and notebooks to install ``ssspy`` from pypi by @tky823 in https://github.com/tky823/ssspy/pull/251 * Detect reformatting by @tky823 in https://github.com/tky823/ssspy/pull/258 * Make PDSBSSBase inherit IterativeMethodBase by @tky823 in https://github.com/tky823/ssspy/pull/262 **Full Changelog**: `v0.1.6...v0.1.7 <https://github.com/tky823/ssspy/compare/v0.1.6...v0.1.7>`_ v0.1.6 ****** Summary ======= In this version, the following BSS methods are newly added 🚀 - Fast MNMF - IVA-IPA - ILRMA-IPA What's Changed ============== * Bump up version to v0.1.5 by @tky823 in https://github.com/tky823/ssspy/pull/222 * Rename "XXXbase" to "XXXBase" by @tky823 in https://github.com/tky823/ssspy/pull/224 * Move default pair_selector by @tky823 in https://github.com/tky823/ssspy/pull/225 * Implement Fast MNMF by @tky823 in https://github.com/tky823/ssspy/pull/226 * Score-based permutation solver by @tky823 in https://github.com/tky823/ssspy/pull/221 * Specify flooring function in each method by @tky823 in https://github.com/tky823/ssspy/pull/228 * Solver for cubic equations. 
by @tky823 in https://github.com/tky823/ssspy/pull/230 * Consider corner case of cubic polynomial by @tky823 in https://github.com/tky823/ssspy/pull/233 * Use pytest-xdist by @tky823 in https://github.com/tky823/ssspy/pull/235 * Implement IVA-IPA by @tky823 in https://github.com/tky823/ssspy/pull/234 * Update links to reference by @tky823 in https://github.com/tky823/ssspy/pull/237 * Fix shape of varphi in tests of IVA by @tky823 in https://github.com/tky823/ssspy/pull/240 * End support of python=3.7 by @tky823 in https://github.com/tky823/ssspy/pull/243 * Stabilize IVA-IPA related algorithms by @tky823 in https://github.com/tky823/ssspy/pull/241 * Implementation of ILRMA-IPA by @tky823 in https://github.com/tky823/ssspy/pull/244 **Full Changelog**: `v0.1.5...v0.1.6 <https://github.com/tky823/ssspy/compare/v0.1.5...v0.1.6>`_ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright 2022 Takuya Hasumi Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MANIFEST.in ================================================ exclude .gitignore exclude *.yaml recursive-include ssspy *.py prune .github prune tests prune docs prune notebooks ================================================ FILE: README.md ================================================ # ssspy [![Documentation Status](https://readthedocs.org/projects/sound-source-separation-python/badge/?version=latest)](https://sound-source-separation-python.readthedocs.io/en/latest/?badge=latest) [![codecov](https://codecov.io/gh/tky823/ssspy/branch/main/graph/badge.svg)](https://codecov.io/gh/tky823/ssspy) [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://tky823-ssspy-demo.hf.space/) A Python toolkit for sound source separation. 
## Build Status | Python | Ubuntu | MacOS (x86_64) | MacOS (arm64) | Windows | |:-:|:-:|:-:|:-:|:-:| | 3.9 | [![ubuntu-latest/3.9](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.9.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.9.yaml) | [![macos-13/3.9](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.9.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.9.yaml) | [![macos-latest/3.9](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.9.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.9.yaml) | [![windows-latest/3.9](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.9.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.9.yaml) | | 3.10 | [![ubuntu-latest/3.10](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.10.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.10.yaml) | [![macos-13/3.10](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.10.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.10.yaml) | [![macos-latest/3.10](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.10.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.10.yaml) | [![windows-latest/3.10](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.10.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.10.yaml) | | 3.11 | 
[![ubuntu-latest/3.11](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.11.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.11.yaml) | [![macos-13/3.11](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.11.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.11.yaml) | [![macos-latest/3.11](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.11.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.11.yaml) | [![windows-latest/3.11](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.11.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.11.yaml) | | 3.12 | [![ubuntu-latest/3.12](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.12.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.12.yaml) | [![macos-13/3.12](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.12.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.12.yaml) | [![macos-latest/3.12](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.12.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.12.yaml) | [![windows-latest/3.12](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.12.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.12.yaml) | ## Installation You can install by pip. 
```shell pip install ssspy ``` To install the latest version, ```shell pip install git+https://github.com/tky823/ssspy.git ``` Alternatively, you can build the package from source. ```shell git clone https://github.com/tky823/ssspy.git cd ssspy pip install . ``` If you cannot install `ssspy` because building the wheel for numpy fails, please install numpy in advance. ## Build Documentation Locally (optional) To build the documentation locally, you have to include the `docs` and `notebooks` extras when installing `ssspy`. ```shell pip install -e ".[docs,notebooks]" ``` You need to convert some notebooks with the following command: ```shell # in ssspy/ . ./docs/pre_build.sh ``` To build the documentation, run the following command. ```shell cd docs/ make html ``` Alternatively, you can build the documentation automatically using `sphinx-autobuild`. ```shell # in ssspy/ sphinx-autobuild docs docs/_build/html ``` ## Blind Source Separation Methods | Method | Notebooks | |:-:|:-:| | Independent Component Analysis (ICA) [1-3] | Gradient-descent-based ICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ICA/GradICA.ipynb)
Natural-gradient-descent-based ICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ICA/NaturalGradICA.ipynb)
Fast ICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ICA/FastICA.ipynb) | | Frequency-Domain Independent Component Analysis (FDICA) [4-6] | Gradient-descent-based FDICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/GradFDICA.ipynb)
Natural-gradient-descent-based FDICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/NaturalGradFDICA.ipynb)
Auxiliary-function-based FDICA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/AuxFDICA-IP1.ipynb)
Auxiliary-function-based FDICA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/AuxFDICA-IP2.ipynb)
Gradient-descent-based Laplace-FDICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/GradLaplaceFDICA.ipynb)
Natural-gradient-descent-based Laplace-FDICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/NaturalGradLaplaceFDICA.ipynb)
Auxiliary-function-based Laplace-FDICA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/AuxLaplaceFDICA-IP1.ipynb)
Auxiliary-function-based Laplace-FDICA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/AuxLaplaceFDICA-IP2.ipynb) | | Independent Vector Analysis (IVA) [7-14] | Gradient-descent-based IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/GradIVA.ipynb)
Natural-gradient-descent-based IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/NaturalGradIVA.ipynb)
Fast IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/FastIVA.ipynb)
Faster IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/FasterIVA.ipynb)
Auxiliary-function-based IVA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-IP1.ipynb)
Auxiliary-function-based IVA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-IP2.ipynb)
Auxiliary-function-based IVA (ISS1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-ISS1.ipynb)
Auxiliary-function-based IVA (ISS2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-ISS2.ipynb)
Auxiliary-function-based IVA (IPA): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-IPA.ipynb)
Gradient-descent-based Laplace-IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/GradLaplaceIVA.ipynb)
Natural-gradient-descent-based Laplace-IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/NaturalGradLaplaceIVA.ipynb)
Auxiliary-function-based Laplace-IVA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-IP1.ipynb)
Auxiliary-function-based Laplace-IVA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-IP2.ipynb)
Auxiliary-function-based Laplace-IVA (ISS1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-ISS1.ipynb)
Auxiliary-function-based Laplace-IVA (ISS2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-ISS2.ipynb)
Auxiliary-function-based Laplace-IVA (IPA): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-IPA.ipynb)
Gradient-descent-based Gauss-IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/GradGaussIVA.ipynb)
Natural-gradient-descent-based Gauss-IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/NaturalGradGaussIVA.ipynb)
Auxiliary-function-based Gauss-IVA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-IP1.ipynb)
Auxiliary-function-based Gauss-IVA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-IP2.ipynb)
Auxiliary-function-based Gauss-IVA (ISS1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-ISS1.ipynb)
Auxiliary-function-based Gauss-IVA (ISS2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-ISS2.ipynb)
Auxiliary-function-based Gauss-IVA (IPA): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-IPA.ipynb) | | Independent Low-Rank Matrix Analysis (ILRMA) [15-18] | Gauss-ILRMA (IP1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IP1-MM.ipynb)
Gauss-ILRMA (IP1/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IP1-ME.ipynb)
Gauss-ILRMA (IP2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IP2-MM.ipynb)
Gauss-ILRMA (IP2/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IP2-ME.ipynb)
Gauss-ILRMA (ISS1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-ISS1-MM.ipynb)
Gauss-ILRMA (ISS1/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-ISS1-ME.ipynb)
Gauss-ILRMA (ISS2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-ISS2-MM.ipynb)
Gauss-ILRMA (ISS2/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-ISS2-ME.ipynb)
Gauss-ILRMA (IPA/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IPA-MM.ipynb)
Gauss-ILRMA (IPA/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IPA-ME.ipynb)
*t*-ILRMA (IP1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-IP1-MM.ipynb)
*t*-ILRMA (IP1/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-IP1-ME.ipynb)
*t*-ILRMA (IP2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-IP2-MM.ipynb)
*t*-ILRMA (IP2/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-IP2-ME.ipynb)
*t*-ILRMA (ISS1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-ISS1-MM.ipynb)
*t*-ILRMA (ISS1/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-ISS1-ME.ipynb)
*t*-ILRMA (ISS2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-ISS2-MM.ipynb)
*t*-ILRMA (ISS2/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-ISS2-ME.ipynb)
GGD-ILRMA (IP1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GGDILRMA-IP1-MM.ipynb)
GGD-ILRMA (IP2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GGDILRMA-IP2-MM.ipynb)
GGD-ILRMA (ISS1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GGDILRMA-ISS1-MM.ipynb)
GGD-ILRMA (ISS2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GGDILRMA-ISS2-MM.ipynb) | | Independent Positive Semidefinite Tensor Analysis (IPSDTA) [19, 20] | Gauss-IPSDTA (VCD): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IPSDTA/GaussIPSDTA-VCD.ipynb)
*t*-IPSDTA (VCD): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IPSDTA/TIPSDTA-VCD.ipynb) | | Multichannel Nonnegative Matrix Factorization (MNMF) [21-24] | Gauss-MNMF: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/MNMF/GaussMNMF.ipynb)
*t*-MNMF: coming soon
Fast Gauss-MNMF (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/MNMF/FastGaussMNMF-IP1.ipynb)
Fast Gauss-MNMF (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/MNMF/FastGaussMNMF-IP2.ipynb) | | Blind Source Separation via Primal-Dual Splitting Algorithm (PDS-BSS) [25,26] | PDS-BSS: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/PDSBSS/PDSBSS.ipynb)
PDS-BSS-multiPenalty: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/PDSBSS/PDSBSS_multi-penalty.ipynb)
PDS-BSS-masking: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/PDSBSS/PDSBSS_masking.ipynb) | | Blind Source Separation via Alternating Direction Method of Multipliers (ADMM-BSS) | ADMM-BSS: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ADMMBSS/ADMMBSS.ipynb) | | Harmonic Vector Analysis (HVA) [27] | HVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/HVA/HVA.ipynb) | | Complex Angular Central Gaussian Mixture Model (cACGMM) [28] | cACGMM: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/CACGMM/CACGMM.ipynb) | - [1] [P. Comon, "Independent component analysis, a new concept?" *Signal Processing*, vol. 36, no. 3, pp. 287-314, 1994.](https://www.sciencedirect.com/science/article/pii/0165168494900299) - [2] [S. Amari, A. Cichocki, and H. H. Yang, "A new learning algorithm for blind signal separation," in *Proc. NIPS*, 1996, pp. 757-763.](https://proceedings.neurips.cc/paper/1995/hash/e19347e1c3ca0c0b97de5fb3b690855a-Abstract.html) - [3] [A. Hyvärinen, "Fast and robust fixed-point algorithms for independent component analysis," *IEEE Trans. on Neural Netw.*, vol. 10, no. 3, pp. 626-634, 1999.](https://www.cs.helsinki.fi/u/ahyvarin/papers/TNN99new.pdf) - [4] [N. Murata, S. Ikeda, and A. Ziehe, "An approach to blind source separation based on temporal structure of speech signals," in *Neurocomputing*, vol. 41, no. 1, pp. 1-24, 2001.](https://www.sciencedirect.com/science/article/pii/S0925231200003453) - [5] [H. Sawada, S. Araki, and S.
Makino, "Underdetermined convolutive blind source separation via frequency bin-wise clustering and permutation alignment," *IEEE Trans. ASLP*, vol. 19, no. 3, pp. 516-527, 2011.](https://ieeexplore.ieee.org/document/5473129) - [6] [N. Ono and S. Miyabe, "Auxiliary-function-based independent component analysis for super-Gaussian sources," in *Proc. LVA/ICA*, 2010, pp. 165-172.](https://link.springer.com/chapter/10.1007/978-3-642-15995-4_21) - [7] [T. Kim, H. T. Attias, S.-Y. Lee, and T.-W. Lee, "Blind source separation exploiting higher-order frequency dependencies," *IEEE Trans. ASLP*, vol. 15, no. 1, pp. 70-79, 2007.](https://link.springer.com/chapter/10.1007/11679363_21) - [8] I. Lee, T. Kim, and T.-W. Lee, "Fast fixed-point independent vector analysis algorithms for convolutive blind source separation," *Signal Processing*, vol. 87, no. 8, pp. 1859-1871, 2007. - [9] [N. Ono, "Stable and fast update rules for independent vector analysis based on auxiliary function technique," in *Proc. WASPAA*, 2011, pp. 189-192.](https://ieeexplore.ieee.org/document/6082320) - [10] [N. Ono, "Auxiliary-function-based independent vector analysis with power of vector-norm type weighting functions," in *Proc. APSIPA ASC*, 2012, pp. 1-4.](https://ieeexplore.ieee.org/document/6411886) - [11] [R. Scheibler and N. Ono, "Fast and stable blind source separation with rank-1 updates," in *Proc. ICASSP*, 2020, pp. 236-240.](https://ieeexplore.ieee.org/document/9053556) - [12] [A. Brendel and W. Kellermann, "Faster IVA: Update rules for independent vector analysis based on negentropy and the majorize-minimize principle," in *Proc. WASPAA*, 2021, pp. 131–135.](https://arxiv.org/abs/2003.09531) - [13] [R. Scheibler, "Independent vector analysis via log-quadratically penalized quadratic minimization," *IEEE Trans. Signal Processing*, vol. 69, pp. 2509-2524, 2021.](https://ieeexplore.ieee.org/document/9399809) - [14] [R. Ikeshita and T.
Nakatani, "ISS2: An extension of iterative source steering algorithm for majorization-minimization-based independent vector analysis," in *Proc. EUSIPCO*, 2022, pp. 65-69.](https://arxiv.org/abs/2202.00875) - [15] [D. Kitamura, N. Ono, H. Sawada, H. Kameoka, and H. Saruwatari, "Determined blind source separation unifying independent vector analysis and nonnegative matrix factorization," *IEEE/ACM Trans. ASLP*, vol. 24, no. 9, pp. 1626-1641, 2016.](https://ieeexplore.ieee.org/document/7486081) - [16] [D. Kitamura, S. Mogami, Y. Mitsui, N. Takamune, H. Saruwatari, N. Ono, Y. Takahashi, and K. Kondo, "Generalized independent low-rank matrix analysis using heavy-tailed distributions for blind source separation," *EURASIP J. Adv. in Signal Processing*, vol. 2018, no. 28, 25 pages, 2018.](https://link.springer.com/article/10.1186/s13634-018-0549-5) - [17] [T. Nakashima, R. Scheibler, Y. Wakabayashi, and N. Ono, "Faster independent low-rank matrix analysis with pairwise updates of demixing vectors," in *Proc. EUSIPCO*, 2021, pp. 301-305.](https://ieeexplore.ieee.org/document/9287508) - [18] [Y. Mitsui, D. Kitamura, N. Takamune, H. Saruwatari, Y. Takahashi, K. Kondo, "Independent low-rank matrix analysis based on parametric majorization-equalization algorithm", in *Proc. CAMSAP*, 2017, pp. 98-102.](https://arxiv.org/abs/1710.01589) - [19] [R. Ikeshita, "Independent positive semidefinite tensor analysis in blind source separation," in *Proc. EUSIPCO*, 2018, pp. 1652-1656.](https://ieeexplore.ieee.org/document/8553546) - [20] [T. Kondo, K. Fukushige, N. Takamune, D. Kitamura, H. Saruwatari, R. Ikeshita, and T. Nakatani, "Convergence-guaranteed independent positive semidefinite tensor analysis based on Student's t distribution," in *Proc ICASSP*, 2020, pp. 681-685.](https://ieeexplore.ieee.org/document/9054150) - [21] [A. Ozerov and C. Fevotte, "Multichannel nonnegative matrix factorization in convolutive mixtures for audio source separation," *IEEE Trans. ASLP*, vol. 18, no. 
3, pp. 550-563, 2010.](https://ieeexplore.ieee.org/document/5229304) - [22] [H. Sawada, H. Kameoka, S. Araki, and N. Ueda, "Multichannel extensions of non-negative matrix factorization with complex-valued data," *IEEE Trans. ASLP*, vol. 21, no. 5, pp. 971-982, 2013.](https://ieeexplore.ieee.org/document/6410389) - [23] [K. Yoshii, K. Itoyama, and M. Goto, "Student's T nonnegative matrix factorization and positive semidefinite tensor factorization for single-channel audio source separation," in *Proc. ICASSP*, 2016, pp. 51-55.](https://ieeexplore.ieee.org/document/7471635) - [24] [K. Sekiguchi, A. A. Nugraha, Y. Bando, and K. Yoshii, "Fast multichannel source separation based on jointly diagonalizable spatial covariance matrices," in *Proc. EUSIPCO*, 2019, pp. 1-5.](https://arxiv.org/abs/1903.03237) - [25] [K. Yatabe and D. Kitamura, "Determined blind source separation via proximal splitting algorithm," in *Proc. ICASSP*, 2018, pp. 776-780.](https://ieeexplore.ieee.org/document/8462338) - [26] [K. Yatabe and D. Kitamura, "Time-frequency-masking-based determined BSS with application to sparse IVA," in *Proc. ICASSP*, 2019, pp. 715-719.](https://ieeexplore.ieee.org/document/8682217) - [27] [K. Yatabe and D. Kitamura, "Determined BSS based on time-frequency masking and its application to harmonic vector analysis," *IEEE/ACM Trans. ASLP*, vol. 29, pp. 1609–1625, 2021.](https://dl.acm.org/doi/abs/10.1109/TASLP.2021.3073863) - [28] [N. Ito, S. Araki, and T. Nakatani, "Complex angular central Gaussian mixture model for directional statistics in mask-based microphone array signal processing," in *Proc. EUSIPCO*, 2016, pp. 
1153-1157.](https://ieeexplore.ieee.org/document/7760429) ## LICENSE Apache License 2.0 ================================================ FILE: codecov.yaml ================================================ coverage: status: project: default: target: auto threshold: 1% patch: default: target: auto threshold: 5% ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/api.rst ================================================ APIs ==== Introduction ------------ .. 
code-block:: python import numpy as np import scipy.signal as ss import IPython.display as ipd import matplotlib.pyplot as plt from ssspy.utils.dataset import download_sample_speech_data from ssspy.transform import whiten from ssspy.algorithm import projection_back from ssspy.bss.fdica import AuxFDICA n_fft, hop_length = 4096, 2048 window = "hann" waveform_src_img = download_sample_speech_data(n_sources=3) waveform_mix = np.sum(waveform_src_img, axis=1) _, _, spectrogram_mix = ss.stft( waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft-hop_length ) def contrast_fn(y): return 2 * np.abs(y) def d_contrast_fn(y): return 2 * np.ones_like(y) fdica = AuxFDICA( contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, ) spectrogram_mix_whitened = whiten(spectrogram_mix) spectrogram_est = fdica(spectrogram_mix_whitened) spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix) _, waveform_est = ss.istft( spectrogram_est, window=window, nperseg=n_fft, noverlap=n_fft-hop_length ) for idx, waveform in enumerate(waveform_est): print("Estimated source: {}".format(idx + 1)) ipd.display(ipd.Audio(waveform, rate=16000)) print() plt.figure() plt.plot(fdica.loss) plt.show() Submodules ---------- .. toctree:: :maxdepth: 1 ssspy.bss ssspy.algorithm ssspy.transform ssspy.linalg ssspy.special ================================================ FILE: docs/changelog.rst ================================================ .. include:: ../CHANGELOG.rst ================================================ FILE: docs/conf.py ================================================ # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options.
For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # # import os # import sys # sys.path.insert(0, os.path.abspath('.')) # -- Project information ----------------------------------------------------- project = "ssspy" copyright = "2022, Takuya Hasumi" author = "Takuya Hasumi" # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_autodoc_typehints", "nbsphinx", ] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = "furo" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ["_static"] ================================================ FILE: docs/index.rst ================================================ .. 
ssspy documentation master file, created by sphinx-quickstart on Fri Apr 29 20:59:12 2022. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Welcome to ssspy's documentation! ================================= .. image:: https://readthedocs.org/projects/sound-source-separation-python/badge/?version=latest :target: https://sound-source-separation-python.readthedocs.io/en/latest/?badge=latest .. image:: https://github.com/tky823/ssspy/actions/workflows/lint.yaml/badge.svg :target: https://github.com/tky823/ssspy/actions/workflows/lint.yaml .. image:: https://codecov.io/gh/tky823/ssspy/branch/main/graph/badge.svg?token=IZ89MTV64G :target: https://codecov.io/gh/tky823/ssspy ``ssspy`` is a Python toolkit for sound source separation. Build status ------------ +--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+ | Python | Ubuntu | MacOS (x86_64) | MacOS (arm64) | Windows | +--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+ | 3.9 | .. 
image:: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.9.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.9.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.9.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.9.yaml/badge.svg?branch=main | | | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.9.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.9.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.9.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.9.yaml | +--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+---------------------------------+ | 3.10 | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.10.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.10.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.10.yaml/badge.svg?branch=main | .. 
image:: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.10.yaml/badge.svg?branch=main | | | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.10.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.10.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.10.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.10.yaml | +--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+---------------------------------+ | 3.11 | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.11.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.11.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.11.yaml/badge.svg?branch=main | .. 
image:: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.11.yaml/badge.svg?branch=main | | | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.11.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.11.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.11.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.11.yaml | +--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+---------------------------------+ | 3.12 | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.12.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.12.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.12.yaml/badge.svg?branch=main | .. 
image:: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.12.yaml/badge.svg?branch=main | | | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.12.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.12.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.12.yaml | :target: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.12.yaml | +--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+---------------------------------+ Installation ------------ You can install ``ssspy`` by pip. .. code-block:: shell pip install ssspy To install latest version, .. code-block:: shell pip install git+https://github.com/tky823/ssspy.git Instead, you can build package from source. .. code-block:: shell git clone https://github.com/tky823/ssspy.git cd ssspy pip install -e . .. note:: If you fail to install ``ssspy``, please update ``setuptools`` by .. code-block:: shell python -m pip install --upgrade setuptools .. note:: If you cannot install `ssspy` due to failure in building wheel for numpy, please install numpy in advance. Build Documentation Locally (optional) -------------------------------------- To build the documentation locally, you have to include ``docs`` and ``notebooks`` when installing ``ssspy``. .. code-block:: shell pip install -e ".[docs,notebooks]" You need to convert some notebooks by the following command: .. code-block:: shell . 
./docs/pre_build.sh To build the documentation, run the following commands. .. code-block:: shell cd docs/ make html Alternatively, you can build the documentation automatically using ``sphinx-autobuild``. .. code-block:: shell # in ssspy/ sphinx-autobuild docs docs/_build/html .. toctree:: :maxdepth: 1 :caption: Contents: _notebooks/Getting-Started.rst changelog api Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.https://www.sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: docs/pre_build.sh ================================================ #!/bin/bash # TODO: unify .readthedocs.yaml pip install -e ".[docs,notebooks]" # execute the notebook and save the executed .ipynb to docs/_notebooks/. jupyter nbconvert --execute notebooks/Examples/Getting-Started.ipynb --to notebook --output-dir docs/_notebooks/ ================================================ FILE: docs/ssspy.algorithm.rst ================================================ ssspy.algorithm =============== ``ssspy.algorithm`` provides algorithms related to source separation. Algorithms ~~~~~~~~~~ .. 
autofunction:: ssspy.algorithm.projection_back ================================================ FILE: docs/ssspy.bss.admmbss.rst ================================================ ssspy.bss.admmbss ================= Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.admmbss.ADMMBSSBase .. autoclass:: ssspy.bss.admmbss.ADMMBSS :special-members: __call__ :members: update_once ================================================ FILE: docs/ssspy.bss.base.rst ================================================ ssspy.bss.base ============== In this module, we provide base class of blind source separation methods. Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.base.IterativeMethodBase :special-members: __call__ :members: :undoc-members: ================================================ FILE: docs/ssspy.bss.cacgmm.rst ================================================ ssspy.bss.cacgmm ================ Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.cacgmm.CACGMM :special-members: __call__ :members: separate, normalize_covariance, update_once, update_posterior, update_parameters, compute_loss, solve_permutation ================================================ FILE: docs/ssspy.bss.fdica.rst ================================================ ssspy.bss.fdica =============== In this module, we separate multichannel signals using frequency-domain independent component analysis (FDICA). We denote the number of sources and microphones as :math:`N` and :math:`M`, respectively. We also denote short-time Fourier transforms of source, observed, and separated signals as :math:`\boldsymbol{s}_{ij}`, :math:`\boldsymbol{x}_{ij}`, and :math:`\boldsymbol{y}_{ij}`, respectively. .. 
math:: \boldsymbol{s}_{ij} &= (s_{ij1},\ldots,s_{ijn},\ldots,s_{ijN})^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \boldsymbol{x}_{ij} &= (x_{ij1},\ldots,x_{ijm},\ldots,x_{ijM})^{\mathsf{T}}\in\mathbb{C}^{M}, \\ \boldsymbol{y}_{ij} &= (y_{ij1},\ldots,y_{ijn},\ldots,y_{ijN})^{\mathsf{T}}\in\mathbb{C}^{N}, where :math:`i=1,\ldots,I` and :math:`j=1,\ldots,J` are indices of frequency bins and time frames, respectively. When a mixing system is time-invariant, :math:`\boldsymbol{x}_{ij}` is represented as follows: .. math:: \boldsymbol{x}_{ij} = \boldsymbol{A}_{i}\boldsymbol{s}_{ij}, where :math:`\boldsymbol{A}_{i}=(\boldsymbol{a}_{i1},\ldots,\boldsymbol{a}_{in},\ldots,\boldsymbol{a}_{iN})\in\mathbb{C}^{M\times N}` is a mixing matrix. If :math:`M=N` and :math:`\boldsymbol{A}_{i}` is non-singular, a demixing system is represented as .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij}, where :math:`\boldsymbol{W}_{i}=(\boldsymbol{w}_{i1},\ldots,\boldsymbol{w}_{in},\ldots,\boldsymbol{w}_{iN})^{\mathsf{H}}\in\mathbb{C}^{N\times M}` is a demixing matrix. The negative log-likelihood of observed signals (divided by :math:`J`) is computed as follows: .. math:: \mathcal{L} &= -\frac{1}{J}\log p(\mathcal{X}) \\ &= -\frac{1}{J}\left(\log p(\mathcal{Y}) \ + \sum_{i}\log|\det\boldsymbol{W}_{i}|^{2J} \right) \\ &= -\frac{1}{J}\sum_{i,j,n}\log p(y_{ijn}) - 2\sum_{i}\log|\det\boldsymbol{W}_{i}| \\ &= \sum_{i}\mathcal{L}^{[i]}, \\ \mathcal{L}^{[i]} \ &= \frac{1}{J}\sum_{j,n}G(y_{ijn}) - 2\log|\det\boldsymbol{W}_{i}|, \\ G(y_{ijn}) &= -\log p(y_{ijn}), where :math:`G(y_{ijn})` is a contrast function. The derivative of :math:`G(y_{ijn})` is called a score function. .. math:: \phi(y_{ijn}) = \frac{\partial G(y_{ijn})}{\partial y_{ijn}^{*}}. Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.fdica.FDICABase :special-members: __call__ :members: separate, compute_loss, compute_logdet, restore_scale, apply_projection_back, solve_permutation .. 
autoclass:: ssspy.bss.fdica.GradFDICABase :special-members: __call__ .. autoclass:: ssspy.bss.fdica.GradFDICA :members: update_once .. autoclass:: ssspy.bss.fdica.NaturalGradFDICA :members: update_once .. autoclass:: ssspy.bss.fdica.AuxFDICA :special-members: __call__ :members: update_once, update_once_ip1, update_once_ip2 .. autoclass:: ssspy.bss.fdica.GradLaplaceFDICA .. autoclass:: ssspy.bss.fdica.NaturalGradLaplaceFDICA .. autoclass:: ssspy.bss.fdica.AuxLaplaceFDICA ================================================ FILE: docs/ssspy.bss.hva.rst ================================================ ssspy.bss.hva ============= Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.hva.MaskingPDSHVA .. autoclass:: ssspy.bss.hva.MaskingADMMHVA .. autoclass:: ssspy.bss.hva.HVA ================================================ FILE: docs/ssspy.bss.ica.rst ================================================ ssspy.bss.ica ============= In this module, we separate time-domain multichannel signals using independent component analysis (ICA) [#comon1994independent]_. We denote the number of sources and microphones as :math:`N` and :math:`M`, respectively. We also denote source, observed, and separated signals (in time-domain) as :math:`\boldsymbol{s}_{t}`, :math:`\boldsymbol{x}_{t}`, and :math:`\boldsymbol{y}_{t}`, respectively. .. math:: \boldsymbol{s}_{t} &= (s_{t1},\ldots,s_{tn},\ldots,s_{tN})^{\mathsf{T}}\in\mathbb{R}^{N}, \\ \boldsymbol{x}_{t} &= (x_{t1},\ldots,x_{tm},\ldots,x_{tM})^{\mathsf{T}}\in\mathbb{R}^{M}, \\ \boldsymbol{y}_{t} &= (y_{t1},\ldots,y_{tn},\ldots,y_{tN})^{\mathsf{T}}\in\mathbb{R}^{N}, where :math:`t=1,\ldots,T` is an index of time samples. When a mixing system is time-invariant, :math:`\boldsymbol{x}_{t}` is represented as follows: .. math:: \boldsymbol{x}_{t} = \boldsymbol{A}\boldsymbol{s}_{t}, where :math:`\boldsymbol{A}=(\boldsymbol{a}_{1},\ldots,\boldsymbol{a}_{n},\ldots,\boldsymbol{a}_{N})\in\mathbb{R}^{M\times N}` is a mixing matrix. 
If :math:`M=N` and :math:`\boldsymbol{A}` is non-singular, a demixing system is represented as .. math:: \boldsymbol{y}_{t} = \boldsymbol{W}\boldsymbol{x}_{t}, where :math:`\boldsymbol{W}=(\boldsymbol{w}_{1},\ldots,\boldsymbol{w}_{n},\ldots,\boldsymbol{w}_{N})^{\mathsf{T}}\in\mathbb{R}^{N\times M}` is a demixing matrix. The negative log-likelihood of observed signals (divided by :math:`T`) is computed as follows: .. math:: \mathcal{L} &= -\frac{1}{T}\log p(\mathcal{X}) \\ &= -\frac{1}{T}\left(\log p(\mathcal{Y}) \ + \log|\det\boldsymbol{W}|^{T} \right) \\ &= -\frac{1}{T}\sum_{t,n}\log p(y_{tn}) - \log|\det\boldsymbol{W}| \\ &= \frac{1}{T}\sum_{t,n}G(y_{tn}) - \log|\det\boldsymbol{W}|, \\ G(y_{tn}) &= -\log p(y_{tn}), where :math:`G(y_{tn})` is a contrast function. The derivative of :math:`G(y_{tn})` is called a score function. .. math:: \phi(y_{tn}) = \frac{\partial G(y_{tn})}{\partial y_{tn}}. .. [#comon1994independent] P. Comon, "Independent component analysis, a new concept?" *Signal Processing*, vol. 36, no. 3, pp. 287-314, 1994. Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.ica.GradICABase :special-members: __call__ :members: separate, compute_loss, compute_logdet .. autoclass:: ssspy.bss.ica.FastICABase :special-members: __call__ :members: separate, compute_loss .. autoclass:: ssspy.bss.ica.GradICA :members: update_once .. autoclass:: ssspy.bss.ica.NaturalGradICA :members: update_once .. autoclass:: ssspy.bss.ica.FastICA :members: update_once .. autoclass:: ssspy.bss.ica.GradLaplaceICA :members: update_once, compute_loss .. autoclass:: ssspy.bss.ica.NaturalGradLaplaceICA :members: update_once, compute_loss ================================================ FILE: docs/ssspy.bss.ilrma.rst ================================================ ssspy.bss.ilrma =============== In this module, we separate multichannel signals using independent low-rank matrix analysis (ILRMA). We denote the number of sources and microphones as :math:`N` and :math:`M`, respectively.
We also denote short-time Fourier transforms of source, observed, and separated signals as :math:`\boldsymbol{s}_{ij}`, :math:`\boldsymbol{x}_{ij}`, and :math:`\boldsymbol{y}_{ij}`, respectively. .. math:: \boldsymbol{s}_{ij} &= (s_{ij1},\ldots,s_{ijn},\ldots,s_{ijN})^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \boldsymbol{x}_{ij} &= (x_{ij1},\ldots,x_{ijm},\ldots,x_{ijM})^{\mathsf{T}}\in\mathbb{C}^{M}, \\ \boldsymbol{y}_{ij} &= (y_{ij1},\ldots,y_{ijn},\ldots,y_{ijN})^{\mathsf{T}}\in\mathbb{C}^{N}, where :math:`i=1,\ldots,I` and :math:`j=1,\ldots,J` are indices of frequency bins and time frames, respectively. When a mixing system is time-invariant, :math:`\boldsymbol{x}_{ij}` is represented as follows: .. math:: \boldsymbol{x}_{ij} = \boldsymbol{A}_{i}\boldsymbol{s}_{ij}, where :math:`\boldsymbol{A}_{i}=(\boldsymbol{a}_{i1},\ldots,\boldsymbol{a}_{in},\ldots,\boldsymbol{a}_{iN})\in\mathbb{C}^{M\times N}` is a mixing matrix. If :math:`M=N` and :math:`\boldsymbol{A}_{i}` is non-singular, a demixing system is represented as .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij}, where :math:`\boldsymbol{W}_{i}=(\boldsymbol{w}_{i1},\ldots,\boldsymbol{w}_{in},\ldots,\boldsymbol{w}_{iN})^{\mathsf{H}}\in\mathbb{C}^{N\times M}` is a demixing matrix. The negative log-likelihood of observed signals (divided by :math:`J`) is computed as follows: .. math:: \mathcal{L} &= -\frac{1}{J}\log p(\mathcal{X}) \\ &= -\frac{1}{J}\left(\log p(\mathcal{Y}) \ + \sum_{i}\log|\det\boldsymbol{W}_{i}|^{2J} \right) \\ &= -\frac{1}{J}\sum_{i,j,n}\log p(y_{ijn}) - 2\sum_{i}\log|\det\boldsymbol{W}_{i}|. Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.ilrma.ILRMABase :special-members: __call__ :members: _init_nmf, separate, reconstruct_nmf, update_once, normalize, normalize_by_power, normalize_by_projection_back, compute_loss, compute_logdet, restore_scale, apply_projection_back .. 
autoclass:: ssspy.bss.ilrma.GaussILRMA :special-members: __call__ :members: update_once, update_source_model, update_source_model_mm, update_source_model_me, update_latent_mm, update_basis_mm, update_activation_mm, update_latent_me, update_basis_me, update_activation_me, update_spatial_model, update_spatial_model_ip1, update_spatial_model_ip2, update_spatial_model_iss1, update_spatial_model_iss2, update_spatial_model_ipa, compute_loss, apply_projection_back .. autoclass:: ssspy.bss.ilrma.TILRMA :special-members: __call__ :members: update_once, update_source_model, update_source_model_mm, update_source_model_me, update_latent_mm, update_basis_mm, update_activation_mm, update_latent_me, update_basis_me, update_activation_me, update_spatial_model, update_spatial_model_ip1, update_spatial_model_ip2, update_spatial_model_iss1, update_spatial_model_iss2, compute_loss, apply_projection_back .. autoclass:: ssspy.bss.ilrma.GGDILRMA :special-members: __call__ :members: update_once, update_source_model, update_source_model_mm, update_latent_mm, update_basis_mm, update_activation_mm, update_spatial_model, update_spatial_model_ip1, update_spatial_model_ip2, update_spatial_model_iss1, update_spatial_model_iss2, compute_loss, apply_projection_back ================================================ FILE: docs/ssspy.bss.iva.rst ================================================ ssspy.bss.iva ============= In this module, we separate multichannel signals using independent vector analysis (IVA). We denote the number of sources and microphones as :math:`N` and :math:`M`, respectively. We also denote short-time Fourier transforms of source, observed, and separated signals as :math:`\boldsymbol{s}_{ij}`, :math:`\boldsymbol{x}_{ij}`, and :math:`\boldsymbol{y}_{ij}`, respectively. .. 
math:: \boldsymbol{s}_{ij} &= (s_{ij1},\ldots,s_{ijn},\ldots,s_{ijN})^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \boldsymbol{x}_{ij} &= (x_{ij1},\ldots,x_{ijm},\ldots,x_{ijM})^{\mathsf{T}}\in\mathbb{C}^{M}, \\ \boldsymbol{y}_{ij} &= (y_{ij1},\ldots,y_{ijn},\ldots,y_{ijN})^{\mathsf{T}}\in\mathbb{C}^{N}, where :math:`i=1,\ldots,I` and :math:`j=1,\ldots,J` are indices of frequency bins and time frames, respectively. We also define the following vector: .. math:: \vec{\boldsymbol{y}}_{jn} = (y_{1jn},\ldots,y_{ijn},\ldots,y_{Ijn})^{\mathsf{T}}\in\mathbb{C}^{I}. When a mixing system is time-invariant, :math:`\boldsymbol{x}_{ij}` is represented as follows: .. math:: \boldsymbol{x}_{ij} = \boldsymbol{A}_{i}\boldsymbol{s}_{ij}, where :math:`\boldsymbol{A}_{i}=(\boldsymbol{a}_{i1},\ldots,\boldsymbol{a}_{in},\ldots,\boldsymbol{a}_{iN})\in\mathbb{C}^{M\times N}` is a mixing matrix. If :math:`M=N` and :math:`\boldsymbol{A}_{i}` is non-singular, a demixing system is represented as .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij}, where :math:`\boldsymbol{W}_{i}=(\boldsymbol{w}_{i1},\ldots,\boldsymbol{w}_{in},\ldots,\boldsymbol{w}_{iN})^{\mathsf{H}}\in\mathbb{C}^{N\times M}` is a demixing matrix. The negative log-likelihood of observed signals (divided by :math:`J`) is computed as follows: .. math:: \mathcal{L} &= -\frac{1}{J}\log p(\mathcal{X}) \\ &= -\frac{1}{J}\left(\log p(\mathcal{Y}) \ + \sum_{i}\log|\det\boldsymbol{W}_{i}|^{2J} \right) \\ &= -\frac{1}{J}\sum_{j,n}\log p(\vec{\boldsymbol{y}}_{jn}) - 2\sum_{i}\log|\det\boldsymbol{W}_{i}| \\ &= \frac{1}{J}\sum_{j,n}G(\vec{\boldsymbol{y}}_{jn}) - 2\sum_{i}\log|\det\boldsymbol{W}_{i}|, \\ G(\vec{\boldsymbol{y}}_{jn}) &= -\log p(\vec{\boldsymbol{y}}_{jn}), where :math:`G(\vec{\boldsymbol{y}}_{jn})` is a contrast function. The derivative of :math:`G(\vec{\boldsymbol{y}}_{jn})` is called a score function. .. math:: \phi_{i}(\vec{\boldsymbol{y}}_{jn}) = \frac{\partial G(\vec{\boldsymbol{y}}_{jn})}{\partial y_{ijn}^{*}}. 
Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.iva.IVABase :special-members: __call__ :members: separate, update_once, compute_loss, compute_logdet, restore_scale, apply_projection_back .. autoclass:: ssspy.bss.iva.GradIVABase .. autoclass:: ssspy.bss.iva.FastIVABase :members: separate, compute_loss, apply_projection_back .. autoclass:: ssspy.bss.iva.AuxIVABase :special-members: __call__ :members: separate, compute_loss, apply_projection_back .. autoclass:: ssspy.bss.iva.GradIVA :members: update_once .. autoclass:: ssspy.bss.iva.NaturalGradIVA :members: update_once .. autoclass:: ssspy.bss.iva.FastIVA :special-members: __call__ :members: update_once .. autoclass:: ssspy.bss.iva.FasterIVA :special-members: __call__ :members: update_once .. autoclass:: ssspy.bss.iva.AuxIVA :special-members: __call__ :members: update_once, update_once_ip1, update_once_ip2, update_once_iss1, update_once_iss2, update_once_ipa .. autoclass:: ssspy.bss.iva.GradLaplaceIVA :members: update_once, compute_loss .. autoclass:: ssspy.bss.iva.GradGaussIVA :members: update_once, update_source_model .. autoclass:: ssspy.bss.iva.NaturalGradLaplaceIVA :members: update_once, compute_loss .. autoclass:: ssspy.bss.iva.NaturalGradGaussIVA :members: update_once, compute_loss .. autoclass:: ssspy.bss.iva.AuxLaplaceIVA .. autoclass:: ssspy.bss.iva.AuxGaussIVA :members: update_once, update_source_model ================================================ FILE: docs/ssspy.bss.mnmf.rst ================================================ ssspy.bss.mnmf ============== Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.mnmf.FastMNMFBase :special-members: __call__ :members: normalize, normalize_by_power .. 
autoclass:: ssspy.bss.mnmf.FastGaussMNMF :special-members: __call__ :members: separate, compute_loss, compute_logdet, update_once, update_basis, update_activation, update_diagonalizer, update_spatial, update_diagonalizer_ip1, update_diagonalizer_ip2 ================================================ FILE: docs/ssspy.bss.pdsbss.rst ================================================ ssspy.bss.pdsbss ================ In this module, we separate multichannel signals using blind source separation via the primal-dual splitting algorithm. We denote the number of sources and microphones as :math:`N` and :math:`M`, respectively. We also denote short-time Fourier transforms of source, observed, and separated signals as :math:`\boldsymbol{s}_{ij}`, :math:`\boldsymbol{x}_{ij}`, and :math:`\boldsymbol{y}_{ij}`, respectively. .. math:: \boldsymbol{s}_{ij} &= (s_{ij1},\ldots,s_{ijn},\ldots,s_{ijN})^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \boldsymbol{x}_{ij} &= (x_{ij1},\ldots,x_{ijm},\ldots,x_{ijM})^{\mathsf{T}}\in\mathbb{C}^{M}, \\ \boldsymbol{y}_{ij} &= (y_{ij1},\ldots,y_{ijn},\ldots,y_{ijN})^{\mathsf{T}}\in\mathbb{C}^{N}, where :math:`i=1,\ldots,I` and :math:`j=1,\ldots,J` are indices of frequency bins and time frames, respectively. When a mixing system is time-invariant, :math:`\boldsymbol{x}_{ij}` is represented as follows: .. math:: \boldsymbol{x}_{ij} = \boldsymbol{A}_{i}\boldsymbol{s}_{ij}, where :math:`\boldsymbol{A}_{i}=(\boldsymbol{a}_{i1},\ldots,\boldsymbol{a}_{in},\ldots,\boldsymbol{a}_{iN})\in\mathbb{C}^{M\times N}` is a mixing matrix. If :math:`M=N` and :math:`\boldsymbol{A}_{i}` is non-singular, a demixing system is represented as .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij}, where :math:`\boldsymbol{W}_{i}=(\boldsymbol{w}_{i1},\ldots,\boldsymbol{w}_{in},\ldots,\boldsymbol{w}_{iN})^{\mathsf{H}}\in\mathbb{C}^{N\times M}` is a demixing matrix. The negative log-likelihood of observed signals (divided by :math:`2J`) is computed as follows: ..
math:: \mathcal{L} &= \mathcal{P}(\mathcal{V}(\mathcal{Y})) + \sum_{i}\mathcal{I}(\boldsymbol{W}_{i}), \\ \mathcal{V}(\mathcal{Y}) &:= (y_{111},\ldots,y_{11N},\ldots,y_{1JN},\ldots,y_{IJN})^{\mathsf{T}} \in\mathbb{C}^{IJN}, \\ \mathcal{I}(\boldsymbol{W}_{i}) &= - \log|\det\boldsymbol{W}_{i}|, where :math:`\mathcal{P}` is a penalty function that is determined by the source model. Let us consider independent vector analysis. In this case, :math:`\mathcal{P}` can be written as .. math:: \mathcal{P}(\mathcal{V}(\mathcal{Y})) = C\sum_{j,n}\left( \sum_{i}\left|\boldsymbol{w}_{in}^{\mathsf{H}}\boldsymbol{x}_{ij}\right|^{2} \right)^{\frac{1}{2}}, where :math:`C` is a positive constant. We can apply the primal-dual splitting algorithm to the above formulation. On the basis of this algorithm, the demixing filter is updated as follows: .. math:: \tilde{\boldsymbol{W}}_{i} &\leftarrow\mathrm{prox}_{\mu_{1}\mathcal{I}} \left[\boldsymbol{W}_{i} - \mu_{1}\mu_{2}\sum_{j}\boldsymbol{u}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}\right] \\ \boldsymbol{z}_{ij} &\leftarrow\boldsymbol{u}_{ij} + \left(2\tilde{\boldsymbol{W}}_{i} - \boldsymbol{W}_{i}\right)\boldsymbol{x}_{ij} \\ \mathcal{V}(\tilde{\mathcal{U}}) &\leftarrow\mathcal{V}(\mathcal{Z}) - \mathrm{prox}_{\mathcal{P}/\mu_{2}}\left[\mathcal{V}(\mathcal{Z})\right] \\ \boldsymbol{u}_{ij} &\leftarrow\alpha\tilde{\boldsymbol{u}}_{ij} + (1 - \alpha)\boldsymbol{u}_{ij}, \\ \boldsymbol{W}_{i} &\leftarrow\alpha\tilde{\boldsymbol{W}}_{i} + (1 - \alpha)\boldsymbol{W}_{i}. Here, :math:`\boldsymbol{u}_{ij}` is a dual variable, which should be initialized with some value. :math:`\mathrm{prox}_{g}` is a proximal operator defined as .. math:: \mathrm{prox}_{g}[\boldsymbol{z}] = \mathrm{argmin}_{\boldsymbol{y}} ~~g(\boldsymbol{y}) + \frac{1}{2}\|\boldsymbol{z} - \boldsymbol{y}\|_{2}^{2}. For :math:`\mathcal{I}`, we can obtain the following proximal operator: ..
math:: \mathrm{prox}_{\mu\mathcal{I}}[\boldsymbol{W}_{i}] &= \boldsymbol{U}_{i}\tilde{\boldsymbol{\Sigma}}_{i}\boldsymbol{V}_{i}^{\mathsf{H}}, \\ \tilde{\boldsymbol{\Sigma}}_{i} &= \mathrm{diag}(\tilde{\sigma}_{i1},\ldots,\tilde{\sigma}_{iN}), \\ \tilde{\sigma}_{in} &= \frac{\sigma_{in} + \sqrt{\sigma_{in}^{2} + 4\mu}}{2}, where :math:`\boldsymbol{U}_{i}`, :math:`\boldsymbol{V}_{i}`, and :math:`\boldsymbol{\Sigma}_{i}=\mathrm{diag}(\sigma_{i1},\ldots,\sigma_{iN})` are obtained by the singular value decomposition of :math:`\boldsymbol{W}_{i}`: .. math:: \boldsymbol{W}_{i} = \boldsymbol{U}_{i}\boldsymbol{\Sigma}_{i}\boldsymbol{V}_{i}^{\mathsf{H}}. When :math:`\mathcal{P}` is defined as .. math:: \mathcal{P}(\mathcal{V}(\mathcal{Y})) = C\sum_{j,n}\left( \sum_{i}\left|\boldsymbol{w}_{in}^{\mathsf{H}}\boldsymbol{x}_{ij}\right|^{2} \right)^{\frac{1}{2}}, the updates by the proximal operator can be written as .. math:: y_{ijn} \leftarrow\left(1 - \frac{\mu}{\sqrt{\sum_{i'}|y_{i'jn}|^{2}}}\right)_{+}y_{ijn}. Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.pdsbss.PDSBSSBase .. autoclass:: ssspy.bss.pdsbss.PDSBSS :special-members: __call__ :members: update_once ================================================ FILE: docs/ssspy.bss.proxbss.rst ================================================ ssspy.bss.proxbss ================= Algorithms ~~~~~~~~~~ .. autoclass:: ssspy.bss.proxbss.ProxBSSBase :special-members: __call__ :members: separate, compute_loss, compute_logdet, normalize_by_spectral_norm, restore_scale, apply_projection_back, apply_minimal_distortion_principle ================================================ FILE: docs/ssspy.bss.rst ================================================ ssspy.bss ========= ``ssspy.bss`` provides various blind source separation methods. Submodules ~~~~~~~~~~ ..
toctree:: :maxdepth: 1 ssspy.bss.base ssspy.bss.ica ssspy.bss.fdica ssspy.bss.iva ssspy.bss.ilrma ssspy.bss.mnmf ssspy.bss.proxbss ssspy.bss.pdsbss ssspy.bss.admmbss ssspy.bss.hva ssspy.bss.cacgmm ================================================ FILE: docs/ssspy.linalg.rst ================================================ ssspy.linalg ============ ``ssspy.linalg`` is linear algebra module related to source separation. Algorithms ~~~~~~~~~~ .. autofunction:: ssspy.linalg.inv2 .. autofunction:: ssspy.linalg.eigh .. autofunction:: ssspy.linalg.eigh2 .. autofunction:: ssspy.linalg.gmeanmh .. autofunction:: ssspy.linalg.lqpqm2 ================================================ FILE: docs/ssspy.special.rst ================================================ ssspy.special ============= ``ssspy.special`` is a module related to special function. Algorithms ~~~~~~~~~~ .. autofunction:: ssspy.special.logsumexp .. autofunction:: ssspy.special.softmax ================================================ FILE: docs/ssspy.transform.rst ================================================ ssspy.transform =============== ``ssspy.transform`` provides transforms related to source separation. Algorithms ~~~~~~~~~~ .. autofunction:: ssspy.transform.pca .. 
autofunction:: ssspy.transform.whiten ================================================ FILE: notebooks/BSS/ADMMBSS/ADMMBSS.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPQC07HvYHC5TQxnCCmzCwE"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"IuOxOjnRWK-4"},"outputs":[],"source":["!pip install git+https://github.com/tky823/ssspy.git"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"CdZN4hyHWOy8"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"LcS_A4hyWRqL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"O7PTMc5qWTWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"K4qwWkHcWvzv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"x3gWs4LgWxSD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.admmbss import ADMMBSS\n","from ssspy.linalg import 
prox"],"metadata":{"id":"j7Gd1k-8Wy0L"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def l21_fn(y: np.ndarray) -> np.ndarray:\n"," \"\"\"Mixed L21 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n","\n"," Returns:\n"," Sum of mixed L21 norm.\n"," \"\"\"\n"," G = np.linalg.norm(y, axis=1)\n"," loss = np.sum(G, axis=(0, 1))\n","\n"," return loss\n","\n","def prox_l21(y, step_size: float = 1) -> np.ndarray:\n"," \"\"\"Apply proximal operator of mixed L21 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n"," step_size (float):\n"," Step size parameter.\n","\n"," Returns:\n"," Output value computed by proximal operator of mixed L21 norm.\n"," The shape of (n_sources, n_bins, n_frames).\n"," \"\"\"\n"," norm = np.linalg.norm(y, axis=1, keepdims=True)\n","\n"," # to suppress warning RuntimeWarning\n"," norm = np.where(norm < step_size, step_size, norm)\n","\n"," return np.maximum(1 - step_size / norm, 0) * y"],"metadata":{"id":"w87aPEguW1jU"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["admm_bss = ADMMBSS(\n"," rho=0.5,\n"," relaxation=1.75,\n"," penalty_fn=l21_fn,\n"," prox_penalty=prox_l21,\n"," scale_restoration=False,\n",")\n","print(admm_bss)"],"metadata":{"id":"YrwbKut7W9AE"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"xoI6OCVvXAOm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_mix_normalized = admm_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\n","spectrogram_est = admm_bss(spectrogram_mix_normalized, n_iter=1000)\n","spectrogram_est = projection_back(spectrogram_est, 
reference=spectrogram_mix)"],"metadata":{"id":"L6B1UOU0XCJf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EoeJuHniXDiG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"7c0RlLEFXGgM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(admm_bss.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"XnPQAUxGXH2i"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"y9ydVrlzXaQH"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ADMMBSS/ADMMBSS_multi-penalty.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOoGT5woLRACpymG/HuTm0y"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"2A1FpkQ-h3ks"},"outputs":[],"source":["!pip install git+https://github.com/tky823/ssspy.git"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"T-AFlnULiPGj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"SLUkWHVGiQN4"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 
2048"],"metadata":{"id":"UHG42oI4iRPU"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"1jdtdFu9iSRX"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"vCzucZlBiTps"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import functools"],"metadata":{"id":"wqekYV1aiU0n"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.admmbss import ADMMBSS"],"metadata":{"id":"-78PZk8WiV2z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def l21_fn(y: np.ndarray) -> np.ndarray:\n"," \"\"\"Compute sum of mixed L21 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n","\n"," Returns:\n"," Sum of mixed L21 norm.\n"," \"\"\"\n"," G = np.linalg.norm(y, axis=1)\n"," loss = np.sum(G, axis=(0, 1))\n","\n"," return loss\n","\n","def lamb_l1_fn(y: np.ndarray, lamb: float = 1) -> np.ndarray:\n"," \"\"\"Compute sum of L1 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n","\n"," Returns:\n"," Sum of L1 norm.\n"," \"\"\"\n"," G = np.abs(y)\n"," loss = lamb * np.sum(G, axis=(0, 1, 2))\n","\n"," return loss\n","\n","def prox_l21(y: np.ndarray, step_size: float = 1) -> np.ndarray:\n"," \"\"\"Apply proximal operator of mixed L21 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, 
n_bins, n_frames).\n"," step_size (float):\n"," Step size parameter.\n","\n"," Returns:\n"," Output value computed by proximal operator of mixed L21 norm.\n"," The shape of (n_sources, n_bins, n_frames).\n"," \"\"\"\n"," norm = np.linalg.norm(y, axis=1, keepdims=True)\n","\n"," # to suppress warning RuntimeWarning\n"," norm = np.where(norm < step_size, step_size, norm)\n","\n"," return np.maximum(1 - step_size / norm, 0) * y\n","\n","def prox_lamb_l1(y: np.ndarray, step_size: float=1, lamb: float = 1) -> np.ndarray:\n"," \"\"\"Apply proximal operator of L1 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n"," step_size (float):\n"," Step size parameter.\n","\n"," Returns:\n"," Output value computed by proximal operator of L1 norm.\n"," The shape of (n_sources, n_bins, n_frames).\n"," \"\"\"\n"," norm = np.abs(y)\n","\n"," # to suppress warning RuntimeWarning\n"," norm = np.where(norm < step_size * lamb, step_size * lamb, norm)\n","\n"," return np.maximum(1 - (step_size * lamb) / norm, 0) * y"],"metadata":{"id":"FrDLnrzOiYYM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# SparseIVA without masking\n","penalty_fn = [\n"," l21_fn,\n"," functools.partial(lamb_l1_fn, lamb=1e-4),\n","]\n","prox_penalty = [\n"," prox_l21,\n"," functools.partial(prox_lamb_l1, lamb=1e-4),\n","]"],"metadata":{"id":"wk0H7Lj9iaOQ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["admm_bss = ADMMBSS(\n"," rho=0.5,\n"," relaxation=1.75,\n"," penalty_fn=penalty_fn,\n"," prox_penalty=prox_penalty,\n"," scale_restoration=False\n",")\n","print(admm_bss)"],"metadata":{"id":"lguDznJaihz_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"T4aYnKTXirf4"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = 
whiten(spectrogram_mix)\n","spectrogram_mix_normalized = admm_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\n","spectrogram_est = admm_bss(spectrogram_mix_normalized, n_iter=1000)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"w9DwgYhOitYv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"TuUenwuBivge"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"8HjLaHqnixOp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(admm_bss.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"ENGBhkjHiyTi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"dssD7pzZkZ0H"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/CACGMM/CACGMM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"mPK5sQpmunbL"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"gF-CVqpZuq7y"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"LJto4YUHusBK"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 
2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"RypuaK8GutWj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"F8UVIE67uwNv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"M6rHkLzcuzZL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.cacgmm import CACGMM as CACGMMBase"],"metadata":{"id":"mBxB5tm2u0X7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class CACGMM(CACGMMBase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(\n"," self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n"," ) -> np.ndarray:\n"," self.n_iter = n_iter\n","\n"," return super().__call__(input, n_iter=n_iter, initial_call=initial_call, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"12peq3Xyu2WM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["cacgmm = CACGMM(\n"," permutation_alignment=\"posterior_score\",\n"," global_iter=100,\n"," local_iter=100,\n"," rng=np.random.default_rng(42)\n",")\n","print(cacgmm)"],"metadata":{"id":"P7QNo371u3YP"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = 
ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"g-vqiYCUu6id"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = cacgmm(spectrogram_mix, n_iter=200)"],"metadata":{"id":"QQ3MmBwpu6f3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"tMz1cFCKu9Nk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"La9fImOkvAE5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(cacgmm.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"ySSBSypivAAN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"AyyuPUvlyTxN"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/FDICA/AuxFDICA-IP1.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"tLYLOhzujeH1"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"3WPpY4HOjmvT"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"AdivKHWiDs5L"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = 
\"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"Xve1xr0Ijn-O"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"ist3AxTsjtLi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"LFWwc0y9kJDq"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.fdica import AuxFDICA"],"metadata":{"id":"ArUmiPttkKYH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.abs(y)\n","\n","def d_contrast_fn(y):\n"," return 2 * np.ones_like(y)"],"metadata":{"id":"6lTfgEW9kNN0"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["fdica = AuxFDICA(\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," contrast_fn=contrast_fn,\n"," d_contrast_fn=d_contrast_fn,\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"y0T8jywSkRyb"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jHa8QSFtkaEM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=20)\n","spectrogram_est = projection_back(spectrogram_est, 
reference=spectrogram_mix)"],"metadata":{"id":"tHNkZgPakcAc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"ZHugIyaRkdID"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"NYCtIWSbke2a"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"O7vW3o-ykhCN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"DSAFgg91kjsb"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/FDICA/AuxFDICA-IP2.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"LbrobIqgDyVY"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," 
n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.fdica import AuxFDICA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.abs(y)\n","\n","def d_contrast_fn(y):\n"," return 2 * np.ones_like(y)"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["fdica = AuxFDICA(\n"," spatial_algorithm=\"IP2\",\n"," contrast_fn=contrast_fn,\n"," d_contrast_fn=d_contrast_fn,\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=50)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/FDICA/AuxLaplaceFDICA-IP1.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"tLYLOhzujeH1"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"3WPpY4HOjmvT"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"AdivKHWiDs5L"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"Xve1xr0Ijn-O"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, 
n_samples)"],"metadata":{"id":"ist3AxTsjtLi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"LFWwc0y9kJDq"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.fdica import AuxLaplaceFDICA"],"metadata":{"id":"ArUmiPttkKYH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["fdica = AuxLaplaceFDICA(\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"y0T8jywSkRyb"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jHa8QSFtkaEM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=20)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"tHNkZgPakcAc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"ZHugIyaRkdID"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"NYCtIWSbke2a"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"O7vW3o-ykhCN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"DSAFgg91kjsb"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/FDICA/AuxLaplaceFDICA-IP2.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"LbrobIqgDyVY"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.fdica import AuxLaplaceFDICA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["fdica = AuxLaplaceFDICA(\n"," spatial_algorithm=\"IP2\",\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=50)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/FDICA/GradFDICA.ipynb ================================================ 
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"9mP-wlwN_imY"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"LvF_rAusCHdf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.fdica import GradFDICA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.abs(y)\n","\n","def score_fn(y):\n"," denom = np.maximum(np.abs(y), 1e-10)\n"," return y / 
denom"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"Gaz1FrqQfSY_"}},{"cell_type":"code","source":["fdica = GradFDICA(\n"," step_size=1e-1,\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," is_holonomic=True,\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"vBVuX7V_frpo"}},{"cell_type":"code","source":["fdica = GradFDICA(\n"," step_size=1e-1,\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," is_holonomic=False,\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"ONLA5XqSfooH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, 
spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"PkLxGIrcfw5W"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"1ZK8uv99fzU-"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"5TmNEew4f00R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"qEmFDyaLf23E"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"Jvh7YJVdf7QR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/FDICA/GradLaplaceFDICA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import 
download_sample_speech_data"],"metadata":{"id":"9mP-wlwN_imY"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"LvF_rAusCHdf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.fdica import GradLaplaceFDICA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"Gaz1FrqQfSY_"}},{"cell_type":"code","source":["fdica = GradLaplaceFDICA(\n"," step_size=1e-1,\n"," is_holonomic=True,\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, 
reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"vBVuX7V_frpo"}},{"cell_type":"code","source":["fdica = GradLaplaceFDICA(\n"," step_size=1e-1,\n"," is_holonomic=False,\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"ONLA5XqSfooH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"PkLxGIrcfw5W"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"1ZK8uv99fzU-"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"5TmNEew4f00R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"qEmFDyaLf23E"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"Jvh7YJVdf7QR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/FDICA/NaturalGradFDICA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"2z5vwPlhedKU"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"8We0lvX8gHYO"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"otDru1KgDYlM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"hdwn-vr6gK7N"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"dsjZqOE3gMJt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"z4VECyzxgNzM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.fdica import NaturalGradFDICA"],"metadata":{"id":"sjX2hNDEgPVk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.abs(y)\n","\n","def score_fn(y):\n"," denom = np.maximum(np.abs(y), 1e-10)\n"," return y / denom"],"metadata":{"id":"ATe35A28gRQp"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"r7fJ65ougw56"}},{"cell_type":"code","source":["fdica = NaturalGradFDICA(\n"," step_size=1e-1,\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," is_holonomic=True,\n",")\n","print(fdica)"],"metadata":{"id":"wL3nVUeagSsv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"CpKeyzimghGg"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = fdica(spectrogram_mix, n_iter=500)"],"metadata":{"id":"uk0geR7pgkY5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"Y5SHhlEQgmac"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"XxQYDDkegolW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"EEI1EfyVgsrF"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic 
type"],"metadata":{"id":"JkE7CwScg6He"}},{"cell_type":"code","source":["fdica = NaturalGradFDICA(\n"," step_size=1e-1,\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," is_holonomic=False,\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"qjAwSFQJg2mc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"QS_zwah4g8Cf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"cPMDNHZtg-Cu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jsgG54MGhBAi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"9c4OLGoWhCkg"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"_5PVRsCdhEFr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"Z7YaF_48hGcB"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/FDICA/NaturalGradLaplaceFDICA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 
3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"2z5vwPlhedKU"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"8We0lvX8gHYO"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"otDru1KgDYlM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"hdwn-vr6gK7N"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"dsjZqOE3gMJt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"z4VECyzxgNzM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.fdica import NaturalGradLaplaceFDICA"],"metadata":{"id":"sjX2hNDEgPVk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"r7fJ65ougw56"}},{"cell_type":"code","source":["fdica = NaturalGradLaplaceFDICA(\n"," step_size=1e-1,\n"," is_holonomic=True,\n",")\n","print(fdica)"],"metadata":{"id":"wL3nVUeagSsv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, 
spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"CpKeyzimghGg"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = fdica(spectrogram_mix, n_iter=500)"],"metadata":{"id":"uk0geR7pgkY5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"Y5SHhlEQgmac"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"XxQYDDkegolW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"EEI1EfyVgsrF"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"JkE7CwScg6He"}},{"cell_type":"code","source":["fdica = NaturalGradLaplaceFDICA(\n"," step_size=1e-1,\n"," is_holonomic=False,\n"," scale_restoration=False\n",")\n","print(fdica)"],"metadata":{"id":"qjAwSFQJg2mc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"QS_zwah4g8Cf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"cPMDNHZtg-Cu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"jsgG54MGhBAi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"9c4OLGoWhCkg"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(fdica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"_5PVRsCdhEFr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"Z7YaF_48hGcB"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/HVA/ADMM-HVA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOmTvOlLp1H2ygKo0T05BU4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"eADaj8ajw-Hk"},"outputs":[],"source":["!pip install git+https://github.com/tky823/ssspy.git"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import IPython.display as ipd"],"metadata":{"id":"-vUkn20rxHOV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"niAQu6AxxRmF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"Rf3G-ViRxaRx"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = 
np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"ivZ5VHBixbnN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"SzYu9qPTxgfc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.hva import MaskingADMMHVA"],"metadata":{"id":"niQwCmEkxhtQ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["hva = MaskingADMMHVA(\n"," rho=0.5,\n"," relaxation=1.75,\n"," attenuation=0.2,\n"," scale_restoration=False,\n",")\n","print(hva)"],"metadata":{"id":"HIZcUll6xvXY"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"Q8bfISfwxwfd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_mix_normalized = hva.normalize_by_spectral_norm(spectrogram_mix_whitened)\n","spectrogram_est = hva(spectrogram_mix_normalized, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"B7jVoSBCxyPu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"etxpTfiTx1nz"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"I2ZhRdS1x3b6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"EC6aKaZpx5Zd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/HVA/HVA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyMJFxbBzvXESFspRTOKnHrj"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"N9ubIAwVRTVT"},"outputs":[],"source":["!pip install git+https://github.com/tky823/ssspy.git"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import IPython.display as ipd"],"metadata":{"id":"O9gHwzJNRmbD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"qdeqalLKRoQS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"zODACQ83RpUy"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"o-csyI_2RqeI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"-ORaaTs-RrgL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm 
import projection_back\n","from ssspy.bss.hva import HVA"],"metadata":{"id":"wtsPoH2CR2fl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["hva = HVA(\n"," mu1=1,\n"," mu2=1,\n"," relaxation=1.75,\n"," attenuation=0.2,\n"," scale_restoration=False,\n",")\n","print(hva)"],"metadata":{"id":"JJphe0rSR6PG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"u-SuTb4NR974"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_mix_normalized = hva.normalize_by_spectral_norm(spectrogram_mix_whitened)\n","spectrogram_est = hva(spectrogram_mix_normalized, n_iter=200)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"BJY7JxF9R_F_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"kxZrXVL7SBb7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"vs8RFvEuSCqI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"7CwmQlMpSDwn"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ICA/FastICA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install 
ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\""],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=False,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ica import FastICA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(x):\n"," return np.log(1 + np.exp(x))\n","\n","def score_fn(x):\n"," return 1 / (1 + np.exp(-x))\n","\n","def d_score_fn(x):\n"," sigma = 1 / (1 + np.exp(-x))\n"," return sigma * (1 - sigma)"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["ica = FastICA(\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," d_score_fn=d_score_fn,\n",")\n","print(ica)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_est = ica(waveform_mix, 
n_iter=10)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"PAIdooVaylih"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ICA/GradICA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\""],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=False,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, 
n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.bss.ica import GradICA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(x):\n"," return np.log(1 + np.exp(x))\n","\n","def score_fn(x):\n"," return 1 / (1 + np.exp(-x))"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"XifuX1ebyENl"}},{"cell_type":"code","source":["ica = GradICA(\n"," contrast_fn=contrast_fn, score_fn=score_fn, is_holonomic=True\n",")\n","print(ica)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_mix_whitened = whiten(waveform_mix)\n","waveform_est = ica(waveform_mix_whitened, n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"L3-HTRTWyhAl"}},{"cell_type":"code","source":["ica = GradICA(\n"," step_size=1e+0,\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," 
is_holonomic=False\n",")\n","print(ica)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_mix_whitened = whiten(waveform_mix)\n","waveform_est = ica(waveform_mix_whitened, n_iter=100)"],"metadata":{"id":"Dfp7ncInygoA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Ch8wdSFGyjrl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"EeTia-LOykrm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"PAIdooVaylih"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ICA/NaturalGradICA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\""],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," 
sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=False,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.bss.ica import NaturalGradICA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(x):\n"," return np.log(1 + np.exp(x))\n","\n","def score_fn(x):\n"," return 1 / (1 + np.exp(-x))"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"XifuX1ebyENl"}},{"cell_type":"code","source":["ica = NaturalGradICA(\n"," contrast_fn=contrast_fn, score_fn=score_fn, is_holonomic=True\n",")\n","print(ica)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_est = ica(waveform_mix, n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"L3-HTRTWyhAl"}},{"cell_type":"code","source":["ica = NaturalGradICA(\n"," step_size=1e+0,\n"," 
contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," is_holonomic=False\n",")\n","print(ica)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_mix_whitened = whiten(waveform_mix)\n","waveform_est = ica(waveform_mix_whitened, n_iter=100)"],"metadata":{"id":"Dfp7ncInygoA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Ch8wdSFGyjrl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ica.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"EeTia-LOykrm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"PAIdooVaylih"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GGDILRMA-IP1-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 
2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GGDILRMA as GGDILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GGDILRMA(GGDILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"kyZFMq8BQJ-H"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function\n","For small $\\beta$, the combination of `normalization` and `flooring_fn` sometimes prevents numerical stability, so we set `normalization=False` in this notebook."],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GGDILRMA(\n"," n_basis=8,\n"," beta=1.95,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," 
normalization=False,\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GGDILRMA(\n"," n_basis=2,\n"," beta=1.8,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, 
n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GGDILRMA-IP2-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = 
download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GGDILRMA as GGDILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GGDILRMA(GGDILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"hxhAhSwcQNuw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GGDILRMA(\n"," n_basis=8,\n"," beta=1.95,\n"," spatial_algorithm=\"IP2\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," normalization=False,\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GGDILRMA(\n"," n_basis=2,\n"," beta=1.9,\n"," spatial_algorithm=\"IP2\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: 
{}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GGDILRMA-ISS1-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," 
print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GGDILRMA as GGDILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GGDILRMA(GGDILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"Y602oZNqQPBS"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GGDILRMA(\n"," n_basis=8,\n"," beta=1.95,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," normalization=False,\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GGDILRMA(\n"," n_basis=2,\n"," beta=1.95,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GGDILRMA-ISS2-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, 
rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GGDILRMA as GGDILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GGDILRMA(GGDILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"tXuJbQQ1QPvH"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GGDILRMA(\n"," n_basis=8,\n"," beta=1.95,\n"," spatial_algorithm=\"ISS2\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," normalization=False,\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 
1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GGDILRMA(\n"," n_basis=3,\n"," beta=1.95,\n"," spatial_algorithm=\"ISS2\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-IP1-ME.ipynb ================================================ 
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," 
self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"B4E4ILrjIlcT"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=8,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," 
n_basis=2,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-IP1-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import 
IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"B4E4ILrjIlcT"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning 
function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=8,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=2,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", 
nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-IP2-ME.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = 
\"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"nrIom3XyJlB9"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=8,\n"," spatial_algorithm=\"IP2\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, 
_, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=2,\n"," spatial_algorithm=\"IP2\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-IP2-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = 
np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"nrIom3XyJlB9"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=8,\n"," spatial_algorithm=\"IP2\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = 
ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=2,\n"," spatial_algorithm=\"IP2\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-IPA-ME.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, 
rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"Q2dt5YXZJl2T"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=8,\n"," spatial_algorithm=\"IPA\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," 
display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=2,\n"," spatial_algorithm=\"IPA\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-IPA-MM.ipynb 
================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, 
*args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"Q2dt5YXZJl2T"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=8,\n"," spatial_algorithm=\"IPA\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," 
n_basis=2,\n"," spatial_algorithm=\"IPA\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-ISS1-ME.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from 
tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"Q2dt5YXZJl2T"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning 
function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=8,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=2,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", 
nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-ISS1-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = 
\"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"Q2dt5YXZJl2T"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=8,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," 
rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=2,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, 
n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-ISS2-ME.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = 
download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"82kfbjHkJml-"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=5,\n"," spatial_algorithm=\"ISS2\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=2,\n"," spatial_algorithm=\"ISS2\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," 
print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/GaussILRMA-ISS2-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in 
enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussILRMA(GaussILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"82kfbjHkJml-"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=5,\n"," spatial_algorithm=\"ISS2\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = GaussILRMA(\n"," n_basis=2,\n"," spatial_algorithm=\"ISS2\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/TILRMA-IP1-ME.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, 
rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import TILRMA as TILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TILRMA(TILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"AYaZ5WZFa7Mw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=8,\n"," dof=1000,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," 
print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=2,\n"," dof=1000,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} 
================================================ FILE: notebooks/BSS/ILRMA/TILRMA-IP1-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import TILRMA as TILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TILRMA(TILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," 
super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"AYaZ5WZFa7Mw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=8,\n"," dof=1000,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=2,\n"," dof=1000,\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/TILRMA-IP2-ME.ipynb ================================================ 
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import TILRMA as TILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TILRMA(TILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = 
n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"F5nHP_t0a9zc"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=8,\n"," dof=1000,\n"," spatial_algorithm=\"IP2\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=2,\n"," dof=1000,\n"," 
spatial_algorithm=\"IP2\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/TILRMA-IP2-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import 
tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import TILRMA as TILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TILRMA(TILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"F5nHP_t0a9zc"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = 
TILRMA(\n"," n_basis=8,\n"," dof=1000,\n"," spatial_algorithm=\"IP2\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=2,\n"," dof=1000,\n"," spatial_algorithm=\"IP2\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/TILRMA-ISS1-ME.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 
4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import TILRMA as TILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TILRMA(TILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"mctxmyYha-Wl"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=8,\n"," dof=1000,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, 
_, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=2,\n"," dof=1000,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/TILRMA-ISS1-MM.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = 
np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import TILRMA as TILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TILRMA(TILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"mctxmyYha-Wl"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=8,\n"," dof=1000,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, 
n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=2,\n"," dof=1000,\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/TILRMA-ISS2-ME.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, 
rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import TILRMA as TILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TILRMA(TILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"F_vG0afTa_F9"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=8,\n"," dof=1000,\n"," spatial_algorithm=\"ISS2\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," 
display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=3,\n"," dof=1000,\n"," spatial_algorithm=\"ISS2\",\n"," source_algorithm=\"ME\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"HEf1-l6w4zon"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/ILRMA/TILRMA-ISS2-MM.ipynb 
================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ilrma import TILRMA as TILRMABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TILRMA(TILRMABase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int 
= 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"F_vG0afTa_F9"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"8ErS0NZ12Gyq"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=8,\n"," dof=1000,\n"," spatial_algorithm=\"ISS2\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=True, # w/ partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"3iVGgX3F2M8Z"}},{"cell_type":"code","source":["ilrma = TILRMA(\n"," n_basis=3,\n"," 
dof=1000,\n"," spatial_algorithm=\"ISS2\",\n"," source_algorithm=\"MM\",\n"," domain=2,\n"," partitioning=False, # w/o partitioning function\n"," rng=np.random.default_rng(42),\n",")\n","print(ilrma)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=200)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ilrma.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"HEf1-l6w4zon"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IPSDTA/GaussIPSDTA-VCD.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from 
tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ipsdta import GaussIPSDTA as GaussIPSDTABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussIPSDTA(GaussIPSDTABase):\n"," def __init__(self, *args, source_steps=1, spatial_steps=1, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n"," self.source_steps = source_steps\n"," self.spatial_steps = spatial_steps\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," for _ in range(self.source_steps):\n"," self.update_source_model()\n","\n"," for _ in range(self.spatial_steps):\n"," 
self.update_spatial_model()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"KChk_QhJ8Xhg"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["ipsdta = GaussIPSDTA(\n"," n_basis=2,\n"," n_blocks=1024, # block 1: {1, 2}, ..., block 1023: {2045, 2046}, block 1024: {2047, 2048, 2049}\n"," spatial_algorithm=\"VCD\",\n"," source_steps=1,\n"," spatial_steps=10,\n"," rng=np.random.default_rng(42),\n",")\n","print(ipsdta)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ipsdta(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ipsdta.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IPSDTA/TIPSDTA-VCD.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 
3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.ipsdta import TIPSDTA as TIPSDTABase"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class TIPSDTA(TIPSDTABase):\n"," def __init__(self, *args, source_steps=1, spatial_steps=1, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n"," self.source_steps = source_steps\n"," self.spatial_steps = spatial_steps\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," 
return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," for _ in range(self.source_steps):\n"," self.update_source_model()\n","\n"," for _ in range(self.spatial_steps):\n"," self.update_spatial_model()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"KChk_QhJ8Xhg"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["ipsdta = TIPSDTA(\n"," n_basis=2,\n"," n_blocks=1024, # block 1: {1, 2}, ..., block 1023: {2045, 2046}, block 1024: {2047, 2048, 2049}\n"," dof=1000,\n"," spatial_algorithm=\"VCD\",\n"," source_steps=1,\n"," spatial_steps=10,\n"," rng=np.random.default_rng(42),\n",")\n","print(ipsdta)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"X6OsBzSZ2Q8Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ipsdta(spectrogram_mix, n_iter=100)"],"metadata":{"id":"Ccmek_Ek2Q6D"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"7VASi3lZ2Q39"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"ct-qKgs42ZK9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(ipsdta.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"SIpKCQTo2ZHl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"r6q7VzoOzt-T"},"execution_count":null,"outputs":[]}]} 
================================================ FILE: notebooks/BSS/IVA/AuxGaussIVA-IP1.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxGaussIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxGaussIVA(\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of 
\"IP1\".\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxGaussIVA-IP2.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import 
download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxGaussIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxGaussIVA(\n"," spatial_algorithm=\"IP2\",\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," 
display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxGaussIVA-IPA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, 
rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxGaussIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxGaussIVA(\n"," spatial_algorithm=\"IPA\",\n"," newton_iter=10,\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxGaussIVA-ISS1.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install 
ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxGaussIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxGaussIVA(\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, 
n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxGaussIVA-ISS2.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," 
sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxGaussIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxGaussIVA(\n"," spatial_algorithm=\"ISS2\",\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} 
================================================ FILE: notebooks/BSS/IVA/AuxIVA-IP1.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.linalg.norm(y, axis=1)\n","\n","def d_contrast_fn(y):\n"," return 2 * 
np.ones_like(y)"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxIVA(\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," contrast_fn=contrast_fn,\n"," d_contrast_fn=d_contrast_fn\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxIVA-IP2.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import 
matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.linalg.norm(y, axis=1)\n","\n","def d_contrast_fn(y):\n"," return 2 * np.ones_like(y)"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxIVA(\n"," spatial_algorithm=\"IP2\",\n"," contrast_fn=contrast_fn,\n"," d_contrast_fn=d_contrast_fn\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxIVA-IPA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 
2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.linalg.norm(y, axis=1)\n","\n","def d_contrast_fn(y):\n"," return 2 * np.ones_like(y)"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxIVA(\n"," spatial_algorithm=\"IPA\",\n"," contrast_fn=contrast_fn,\n"," d_contrast_fn=d_contrast_fn,\n"," newton_iter=10,\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," 
print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"OoJoi32YK5PQ"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxIVA-ISS1.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 
1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.linalg.norm(y, axis=1)\n","\n","def d_contrast_fn(y):\n"," return 2 * np.ones_like(y)"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxIVA(\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n"," contrast_fn=contrast_fn,\n"," d_contrast_fn=d_contrast_fn\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxIVA-ISS2.ipynb 
================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.linalg.norm(y, axis=1)\n","\n","def d_contrast_fn(y):\n"," return 2 * np.ones_like(y)"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = 
AuxIVA(\n"," spatial_algorithm=\"ISS2\",\n"," contrast_fn=contrast_fn,\n"," d_contrast_fn=d_contrast_fn\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxLaplaceIVA-IP1.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as 
ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxLaplaceIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxLaplaceIVA(\n"," spatial_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxLaplaceIVA-IP2.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, 
n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxLaplaceIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxLaplaceIVA(\n"," spatial_algorithm=\"IP2\",\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxLaplaceIVA-IPA.ipynb ================================================ 
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxLaplaceIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxLaplaceIVA(\n"," spatial_algorithm=\"IPA\",\n"," newton_iter=10,\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxLaplaceIVA-ISS1.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 
2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxLaplaceIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxLaplaceIVA(\n"," spatial_algorithm=\"ISS1\", # You can set \"ISS\" instead of \"ISS1\".\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/AuxLaplaceIVA-ISS2.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 4\n","max_duration = 10\n","sisec2010_tag = \"dev1_female4\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import AuxLaplaceIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = AuxLaplaceIVA(\n"," spatial_algorithm=\"ISS2\",\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/FastIVA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as 
np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import FastIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.linalg.norm(y, axis=1)\n","\n","def d_contrast_fn(y):\n"," return 2 * np.ones_like(y)\n","\n","def dd_contrast_fn(y):\n"," return 2 * np.zeros_like(y)"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = FastIVA(\n"," contrast_fn=contrast_fn,\n"," d_contrast_fn=d_contrast_fn,\n"," dd_contrast_fn=dd_contrast_fn\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=100)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"YQdFA8YAv7NK"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/FasterIVA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 
2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.iva import FasterIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.linalg.norm(y, axis=1)\n","\n","def d_contrast_fn(y):\n"," return 2 * np.ones_like(y)"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["iva = FasterIVA(\n"," contrast_fn=contrast_fn,\n"," d_contrast_fn=d_contrast_fn\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=50)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 
1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"YQdFA8YAv7NK"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/GradGaussIVA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"6qelLYKGEye2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, 
rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.iva import GradGaussIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"9XmmBfq1p32-"}},{"cell_type":"code","source":["iva = GradGaussIVA(\n"," step_size=1e-1,\n"," is_holonomic=True,\n"," scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, n_iter=100)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"YNtBangeqF3m"}},{"cell_type":"code","source":["iva = GradGaussIVA(\n"," step_size=1e-1,\n"," is_holonomic=False,\n"," 
scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"wFA0WatXqG3C"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, n_iter=100)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"HEbtWD5nqJpD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"9LYvRj20qLW2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"_pvjv8PDqNBh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"Dfue8hDuqOVf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"UsEuDyYjqPx0"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/GradIVA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as 
ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"6qelLYKGEye2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.iva import GradIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.linalg.norm(y, axis=1)\n","\n","def score_fn(y):\n"," norm = np.linalg.norm(y, axis=1, keepdims=True)\n"," norm = np.maximum(norm, 1e-10)\n"," return y / norm"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"9XmmBfq1p32-"}},{"cell_type":"code","source":["iva = GradIVA(\n"," step_size=1e+2,\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," is_holonomic=True,\n"," 
scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"YNtBangeqF3m"}},{"cell_type":"code","source":["iva = GradIVA(\n"," step_size=1e+1,\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," is_holonomic=False,\n"," scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"wFA0WatXqG3C"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, 
n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"HEbtWD5nqJpD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"9LYvRj20qLW2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"_pvjv8PDqNBh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"Dfue8hDuqOVf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"UsEuDyYjqPx0"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/GradLaplaceIVA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"6qelLYKGEye2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, 
sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.iva import GradLaplaceIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"9XmmBfq1p32-"}},{"cell_type":"code","source":["iva = GradLaplaceIVA(\n"," step_size=1e+2,\n"," is_holonomic=True,\n"," scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," 
display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"YNtBangeqF3m"}},{"cell_type":"code","source":["iva = GradLaplaceIVA(\n"," step_size=1e+1,\n"," is_holonomic=False,\n"," scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"wFA0WatXqG3C"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"HEbtWD5nqJpD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"9LYvRj20qLW2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"_pvjv8PDqNBh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"Dfue8hDuqOVf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"UsEuDyYjqPx0"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/NaturalGradGaussIVA.ipynb 
================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"4VTo4zewE2uu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.iva import NaturalGradGaussIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"9XmmBfq1p32-"}},{"cell_type":"code","source":["iva = NaturalGradGaussIVA(\n"," step_size=1e-2,\n"," 
is_holonomic=True,\n"," scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"YNtBangeqF3m"}},{"cell_type":"code","source":["iva = NaturalGradGaussIVA(\n"," step_size=1e-1,\n"," is_holonomic=False,\n"," scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"wFA0WatXqG3C"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = 
projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"HEbtWD5nqJpD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"9LYvRj20qLW2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"_pvjv8PDqNBh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"Dfue8hDuqOVf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"UsEuDyYjqPx0"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/NaturalGradIVA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"4VTo4zewE2uu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = 
download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.iva import NaturalGradIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def contrast_fn(y):\n"," return 2 * np.linalg.norm(y, axis=1)\n","\n","def score_fn(y):\n"," norm = np.linalg.norm(y, axis=1, keepdims=True)\n"," norm = np.maximum(norm, 1e-10)\n"," return y / norm"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"9XmmBfq1p32-"}},{"cell_type":"code","source":["iva = NaturalGradIVA(\n"," step_size=1e-1,\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," is_holonomic=True\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"YNtBangeqF3m"}},{"cell_type":"code","source":["iva = NaturalGradIVA(\n"," step_size=1e+1,\n"," contrast_fn=contrast_fn,\n"," score_fn=score_fn,\n"," is_holonomic=False,\n"," scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"wFA0WatXqG3C"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"HEbtWD5nqJpD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"9LYvRj20qLW2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"_pvjv8PDqNBh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"Dfue8hDuqOVf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"UsEuDyYjqPx0"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/IVA/NaturalGradLaplaceIVA.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"4VTo4zewE2uu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 3\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," 
print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.iva import NaturalGradLaplaceIVA"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Holonomic type"],"metadata":{"id":"9XmmBfq1p32-"}},{"cell_type":"code","source":["iva = NaturalGradLaplaceIVA(\n"," step_size=1e-1,\n"," is_holonomic=True,\n",")\n","print(iva)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = iva(spectrogram_mix, n_iter=500)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Non-holonomic type"],"metadata":{"id":"YNtBangeqF3m"}},{"cell_type":"code","source":["iva = NaturalGradLaplaceIVA(\n"," step_size=1e+1,\n"," is_holonomic=False,\n"," scale_restoration=False\n",")\n","print(iva)"],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = 
ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"wFA0WatXqG3C"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"HEbtWD5nqJpD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"9LYvRj20qLW2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"_pvjv8PDqNBh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(iva.loss)\n","plt.show()\n","plt.close()"],"metadata":{"id":"Dfue8hDuqOVf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"UsEuDyYjqPx0"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/MNMF/FastGaussMNMF-IP1.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"8WeTZarWPKFL"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"EJnsybEqPOuX"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import 
download_sample_speech_data"],"metadata":{"id":"z9eNi9hHPQjM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","reverb_duration = 0.36\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"eLlZRFo8PRrL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," reverb_duration=reverb_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"KxmdFO-mPUPr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"HqsfDVe5PV3j"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.mnmf import FastGaussMNMF as FastGaussMNMFBase"],"metadata":{"id":"pD9uMM6ZPaqZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class FastGaussMNMF(FastGaussMNMFBase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"zqa964qWKKhN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["mnmf = FastGaussMNMF(\n"," n_basis=16,\n"," n_sources=2,\n"," diagonalizer_algorithm=\"IP1\", # You can set \"IP\" instead of \"IP1\".\n"," partitioning=False,\n"," 
rng=np.random.default_rng(42),\n",")\n","print(mnmf)"],"metadata":{"id":"vF2lhWRfPXDs"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"JD5ndId6Pe6l"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = mnmf(spectrogram_mix, n_iter=200)"],"metadata":{"id":"G3LxT-USPgB6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"luvhp8QRPg_1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"YddaALjBPiA_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(mnmf.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"zIFWvEW0PjDI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"UqwiQjodPkiH"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/MNMF/FastGaussMNMF-IP2.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"8WeTZarWPKFL"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"EJnsybEqPOuX"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset 
import download_sample_speech_data"],"metadata":{"id":"z9eNi9hHPQjM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","reverb_duration = 0.36\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"eLlZRFo8PRrL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," reverb_duration=reverb_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"KxmdFO-mPUPr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"HqsfDVe5PV3j"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.mnmf import FastGaussMNMF as FastGaussMNMFBase"],"metadata":{"id":"pD9uMM6ZPaqZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class FastGaussMNMF(FastGaussMNMFBase):\n"," def __init__(self, *args, **kwargs):\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs):\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"TrIDa8g-KQwx"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["mnmf = FastGaussMNMF(\n"," n_basis=8,\n"," n_sources=2,\n"," diagonalizer_algorithm=\"IP2\",\n"," partitioning=False,\n"," 
rng=np.random.default_rng(42),\n",")\n","print(mnmf)"],"metadata":{"id":"vF2lhWRfPXDs"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"JD5ndId6Pe6l"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = mnmf(spectrogram_mix, n_iter=200)"],"metadata":{"id":"G3LxT-USPgB6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"luvhp8QRPg_1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"YddaALjBPiA_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(mnmf.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"zIFWvEW0PjDI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"UqwiQjodPkiH"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/MNMF/GaussMNMF.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"ab7xaF2Brwdn"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd\n","from tqdm.notebook import tqdm"],"metadata":{"id":"DXzG4Q9pr0bb"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import 
download_sample_speech_data"],"metadata":{"id":"-Dv4Pr4lr1h5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","reverb_duration = 0.36\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"zIcbmVidr2da"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," reverb_duration=reverb_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7qf18ULsr3Ra"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"eqGUE2Lxr4kE"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.bss.mnmf import GaussMNMF as GaussMNMFBase"],"metadata":{"id":"AE2qyrBar5mu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["class GaussMNMF(GaussMNMFBase):\n"," def __init__(self, *args, **kwargs) -> None:\n"," super().__init__(*args, **kwargs)\n","\n"," self.progress_bar = None\n","\n"," def __call__(self, *args, n_iter: int = 100, **kwargs) -> np.ndarray:\n"," self.n_iter = n_iter\n","\n"," return super().__call__(*args, n_iter=n_iter, **kwargs)\n","\n"," def update_once(self) -> None:\n"," if self.progress_bar is None:\n"," self.progress_bar = tqdm(total=self.n_iter)\n","\n"," super().update_once()\n","\n"," self.progress_bar.update(1)"],"metadata":{"id":"7WV07guJr6ql"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/ partitioning function"],"metadata":{"id":"e-LL4HxC1B0o"}},{"cell_type":"code","source":["mnmf = GaussMNMF(\n"," n_basis=30,\n"," n_sources=2,\n"," 
partitioning=True,\n"," rng=np.random.default_rng(42),\n",")\n","print(mnmf)"],"metadata":{"id":"AhCv-7wsr7qW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"ZrzY92tor-KR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = mnmf(spectrogram_mix, n_iter=200)"],"metadata":{"id":"1q6-YqiRr_bR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"3RrMJzK7sAjB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"x5NrR6GYsBhp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(mnmf.loss[10:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"gai3R1cUsCpz"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## w/o partitioning function"],"metadata":{"id":"QRlsyJNh1G3S"}},{"cell_type":"code","source":["mnmf = GaussMNMF(\n"," n_basis=10,\n"," n_sources=2,\n"," partitioning=False,\n"," rng=np.random.default_rng(42),\n",")\n","print(mnmf)"],"metadata":{"id":"ulNwZ-TusD3F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"CJzGmTKY1N2p"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = mnmf(spectrogram_mix, n_iter=500)"],"metadata":{"id":"E5zfaLRl1Ou5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"YMgxFM0-1Pzb"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"n4bek3rj1Q34"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(mnmf.loss[10:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"ymg1yCBn1Ruj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"UzRVEMZo1Uxv"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/PDSBSS/PDSBSS.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, 
n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.pdsbss import PDSBSS"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def l21_fn(y: np.ndarray) -> np.ndarray:\n"," \"\"\"Mixed L21 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n","\n"," Returns:\n"," Sum of mixed L21 norm.\n"," \"\"\"\n"," G = np.linalg.norm(y, axis=1)\n"," loss = np.sum(G, axis=(0, 1))\n","\n"," return loss\n","\n","def prox_l21(y, step_size: float = 1) -> np.ndarray:\n"," \"\"\"Apply proximal operator of mixed L21 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n"," step_size (float):\n"," Step size parameter.\n","\n"," Returns:\n"," Output value computed by proximal operator of mixed L21 norm.\n"," The shape of (n_sources, n_bins, n_frames).\n"," \"\"\"\n"," norm = np.linalg.norm(y, axis=1, keepdims=True)\n","\n"," return np.maximum(1 - step_size / norm, 0) * y"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["pds_bss = PDSBSS(\n"," mu1=1,\n"," mu2=1,\n"," alpha=1.75,\n"," penalty_fn=l21_fn,\n"," prox_penalty=prox_l21,\n"," scale_restoration=False,\n",")\n","print(pds_bss)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, 
noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_mix_normalized = pds_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\n","spectrogram_est = pds_bss(spectrogram_mix_normalized, n_iter=1000)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(pds_bss.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/PDSBSS/PDSBSS_masking.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyM/ed6diEsnD5EoCV6u0+FD"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"nr1INZyXeXuR"},"outputs":[],"source":["!pip install git+https://github.com/tky823/ssspy.git"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import IPython.display as ipd"],"metadata":{"id":"r9-0QuFfeaut"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from 
ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"bYCpxjc6ecsj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"psnOiF2Tedv_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"9DGQmwafeeoR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"FVoIVFKPefZB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import functools"],"metadata":{"id":"NFL0jjzSevDf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.pdsbss import MaskingPDSBSS"],"metadata":{"id":"YFJ_fx8gegRT"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def l21_mask(y, step_size: float = 1) -> np.ndarray:\n"," \"\"\"Apply proximal operator of mixed L21 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n"," step_size (float):\n"," Step size parameter.\n","\n"," Returns:\n"," Mask based on mixed L21 norm.\n"," The shape of (n_sources, n_bins, n_frames).\n","\n"," \"\"\"\n"," norm = np.linalg.norm(y, axis=1, keepdims=True)\n"," mask = np.maximum(1 - step_size / norm, 0)\n"," mask = np.tile(mask, (1, y.shape[1], 1))\n","\n"," return 
mask"],"metadata":{"id":"JyLNkfqKepOH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["pds_bss = MaskingPDSBSS(\n"," mu1=1,\n"," mu2=1,\n"," relaxation=1.75,\n"," mask_fn=functools.partial(l21_mask, step_size=1),\n"," scale_restoration=False,\n",")\n","print(pds_bss)"],"metadata":{"id":"3pvUgw5vesw1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"ii11j5XrexNu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_mix_normalized = pds_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\n","spectrogram_est = pds_bss(spectrogram_mix_normalized, n_iter=1000)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"XpScpxUQe1cT"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"HY8VE2ARe3mW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"v9vi88kKe48R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"008c-U04e6Dw"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/BSS/PDSBSS/PDSBSS_multi-penalty.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"iLMRz5h_I_U_"},"outputs":[],"source":["!pip install 
ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import matplotlib.pyplot as plt\n","import IPython.display as ipd"],"metadata":{"id":"k4vmcNb5em_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.utils.dataset import download_sample_speech_data"],"metadata":{"id":"rOvxrG-sfp02"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","max_duration = 10\n","sisec2010_tag = \"dev1_female3\"\n","n_fft, hop_length = 4096, 2048"],"metadata":{"id":"faUwi6X9esWR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_speech_data(\n"," n_sources=n_sources,\n"," sisec2010_tag=sisec2010_tag,\n"," max_duration=max_duration,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"7jwyi2wReuRR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"Pa28gsTce8yt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import functools"],"metadata":{"id":"ga-3Noc9sr5p"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from ssspy.transform import whiten\n","from ssspy.algorithm import projection_back\n","from ssspy.bss.pdsbss import PDSBSS"],"metadata":{"id":"tixTvLybe-w7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def l21_fn(y: np.ndarray) -> np.ndarray:\n"," \"\"\"Compute sum of mixed L21 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n","\n"," Returns:\n"," Sum of mixed L21 norm.\n"," \"\"\"\n"," G = np.linalg.norm(y, axis=1)\n"," loss = np.sum(G, axis=(0, 1))\n","\n"," return loss\n","\n","def lamb_l1_fn(y: 
np.ndarray, lamb: float = 1) -> np.ndarray:\n"," \"\"\"Compute sum of L1 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n","\n"," Returns:\n"," Sum of L1 norm.\n"," \"\"\"\n"," G = np.abs(y)\n"," loss = lamb * np.sum(G, axis=(0, 1, 2))\n","\n"," return loss\n","\n","def prox_l21(y: np.ndarray, step_size: float = 1) -> np.ndarray:\n"," \"\"\"Apply proximal operator of mixed L21 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n"," step_size (float):\n"," Step size parameter.\n","\n"," Returns:\n"," Output value computed by proximal operator of mixed L21 norm.\n"," The shape of (n_sources, n_bins, n_frames).\n"," \"\"\"\n"," norm = np.linalg.norm(y, axis=1, keepdims=True)\n","\n"," return np.maximum(1 - step_size / norm, 0) * y\n","\n","def prox_lamb_l1(y: np.ndarray, step_size: float=1, lamb: float = 1) -> np.ndarray:\n"," \"\"\"Apply proximal operator of L1 norm.\n","\n"," Args:\n"," y (np.ndarray):\n"," Input vector with shape of (n_sources, n_bins, n_frames).\n"," step_size (float):\n"," Step size parameter.\n","\n"," Returns:\n"," Output value computed by proximal operator of L1 norm.\n"," The shape of (n_sources, n_bins, n_frames).\n"," \"\"\"\n"," norm = np.abs(y)\n","\n"," return np.maximum(1 - (step_size * lamb) / norm, 0) * y"],"metadata":{"id":"6AvhtkrdfAeZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# SparseIVA without masking\n","penalty_fn = [\n"," l21_fn,\n"," functools.partial(lamb_l1_fn, lamb=2e-3),\n","]\n","prox_penalty = [\n"," prox_l21,\n"," functools.partial(prox_lamb_l1, lamb=2e-3),\n","]"],"metadata":{"id":"hdi1FlWTs0Oc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["pds_bss = PDSBSS(\n"," mu1=1,\n"," mu2=1,\n"," alpha=1.75,\n"," penalty_fn=penalty_fn,\n"," prox_penalty=prox_penalty,\n"," 
scale_restoration=False\n",")\n","print(pds_bss)"],"metadata":{"id":"h5GghnnMfP1F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"EecGuY-JfSBB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_mix_whitened = whiten(spectrogram_mix)\n","spectrogram_mix_normalized = pds_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\n","spectrogram_est = pds_bss(spectrogram_mix_normalized, n_iter=1000)\n","spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)"],"metadata":{"id":"O3D2eA8HfVs2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=\"hann\", nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"jdCqrAdPfXkk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"xXTnlid-fZb3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plt.figure()\n","plt.plot(pds_bss.loss[1:])\n","plt.show()\n","plt.close()"],"metadata":{"id":"OHjwNcZIfe0K"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"qsBes0ysf9Sd"},"execution_count":null,"outputs":[]}]} ================================================ FILE: notebooks/Examples/Getting-Started.ipynb ================================================ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyO78AF8O5/QHlr4M5G1igtQ"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Quick Example of Blind Source Separation\n","In this notebook, we will show you a quick example of blind source 
separation by Gauss-ILRMA."],"metadata":{"id":"RvVmZjNHDlDm"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"A5OAL3QxCtDz"},"outputs":[],"source":["!pip install ssspy"]},{"cell_type":"code","source":["import numpy as np\n","import scipy.signal as ss\n","import IPython.display as ipd\n","import matplotlib.pyplot as plt"],"metadata":{"id":"h78m9LcCDVnJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Here, we use sample music data.\n","\n"],"metadata":{"id":"uIkscttUxNDy"}},{"cell_type":"code","source":["import os\n","import urllib.request\n","from typing import Tuple\n","\n","import ssspy\n","\n","def download_sample_music_data(\n"," n_sources: int = 2,\n"," conv: bool = True,\n"," branch: str = \"main\"\n",") -> Tuple[np.ndarray, int]:\n"," instruments = [\"violin\", \"piano\"]\n"," root = \".data\"\n"," url_template = \"https://github.com/tky823/ssspy-data/raw/{branch}/{path}\"\n"," path_template = \"audio/{instrument}_8k_reverbed.wav\"\n"," waveforms = []\n","\n"," for src_idx in range(1, n_sources + 1):\n"," instrument = instruments[src_idx - 1]\n"," path = path_template.format(instrument=instrument)\n"," download_path = os.path.join(root, path)\n"," url = url_template.format(branch=branch, path=path)\n","\n"," os.makedirs(os.path.dirname(download_path), exist_ok=True)\n","\n"," if not os.path.exists(download_path):\n"," urllib.request.urlretrieve(url, download_path)\n","\n"," waveform, sample_rate = ssspy.wavread(download_path, channels_first=True)\n","\n"," assert sample_rate == 8000, f\"Sampling rate is expected 8000, but {sample_rate} is given.\"\n","\n"," waveforms.append(waveform)\n","\n"," waveforms = np.stack(waveforms, axis=1)\n","\n"," return waveforms[: n_sources], 8000"],"metadata":{"id":"FRgyrAJQJ6qy"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["n_sources = 2\n","n_fft, hop_length = 2048, 512\n","window = 
\"hann\""],"metadata":{"id":"kQlKWdixDW7f"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["waveform_src_img, sample_rate = download_sample_music_data(\n"," n_sources=n_sources,\n"," conv=True,\n",") # (n_channels, n_sources, n_samples)\n","waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples)"],"metadata":{"id":"kmz01toPNq0b"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Let's listen to mixtures."],"metadata":{"id":"nr7q1nIEeJkS"}},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_mix):\n"," print(\"Mixture: {}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"fbc32pgfDZzO"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["`ssspy.bss.ilrma.GaussILRMA` class provides Gauss-ILRMA."],"metadata":{"id":"uGd-LoXHxmcA"}},{"cell_type":"code","source":["from ssspy.bss.ilrma import GaussILRMA"],"metadata":{"id":"dw0_3iy2DbKW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["ilrma = GaussILRMA(n_basis=3, rng=np.random.default_rng(0))\n","print(ilrma)"],"metadata":{"id":"f_lnJNFuDdAj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, _, spectrogram_mix = ss.stft(waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"4_URDBncDr1o"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["spectrogram_est = ilrma(spectrogram_mix, n_iter=500)"],"metadata":{"id":"q2ueWnTdDtcY"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["_, waveform_est = ss.istft(spectrogram_est, window=window, nperseg=n_fft, noverlap=n_fft-hop_length)"],"metadata":{"id":"YARLJx37DuwV"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Let's listen to estimated sources."],"metadata":{"id":"nxlG0Qa2eFFX"}},{"cell_type":"code","source":["for idx, waveform in enumerate(waveform_est):\n"," print(\"Estimated source: 
{}\".format(idx + 1))\n"," display(ipd.Audio(waveform, rate=sample_rate))\n"," print()"],"metadata":{"id":"jsom6b5YDv_0"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Display the cost at each iteration."],"metadata":{"id":"wwh56-cGeAQ_"}},{"cell_type":"code","source":["plt.figure()\n","plt.plot(range(len(ilrma.loss)), ilrma.loss)\n","plt.xlabel(\"Iteration\")\n","plt.ylabel(\"Cost\")\n","plt.show()\n","plt.close()"],"metadata":{"id":"n-z8DgHUDxLI"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Since the cost drops sharply at the first update, let's plot only the costs after the first iteration."],"metadata":{"id":"9rtftb4vU_rw"}},{"cell_type":"code","source":["plt.figure()\n","plt.plot(range(1, len(ilrma.loss)), ilrma.loss[1:])\n","plt.xlabel(\"Iteration\")\n","plt.ylabel(\"Cost\")\n","plt.show()\n","plt.close()"],"metadata":{"id":"tV2nAF-xKmZw"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"v2r5VTF2VeW5"},"execution_count":null,"outputs":[]}]} ================================================ FILE: pyproject.toml ================================================ [build-system] # ref: https://github.com/pypa/setuptools_scm requires = [ "setuptools>=45", "setuptools_scm[toml]>=6.2", ] build-backend = "setuptools.build_meta" [project] name = "ssspy" authors = [ {name = "Takuya Hasumi"}, ] description = "A Python toolkit for sound source separation."
readme = "README.md" license = {file = "LICENSE"} urls = {url = "https://github.com/tky823/ssspy"} requires-python = ">=3.8, <4" dependencies = [ "numpy", ] dynamic = [ "version", ] [project.optional-dependencies] dev = [ "flake8", "black", "isort" ] docs = [ "sphinx", "sphinx-autodoc-typehints", "sphinx-autobuild", "nbsphinx", "furo", ] notebooks = [ "ipykernel", "matplotlib", "scipy", # for STFT in notebooks ] tests = [ "pytest", "pytest-cov", "pytest-xdist", "scipy", ] [tool.setuptools.dynamic] version = {attr = "ssspy.__version__"} [tool.setuptools.packages.find] # TODO: redundancy with MANIFEST.in # see https://github.com/tky823/ssspy/issues/256 include = [ "ssspy", ] [tool.setuptools_scm] write_to = "ssspy/_version.py" version_scheme = "guess-next-dev" local_scheme = "no-local-version" [tool.black] line-length = 100 exclude = "ssspy/_version.py" [tools.flake8] max-line-length = 100 exclude = "ssspy/_version.py" [tool.isort] profile = "black" line_length = 100 [tool.pytest.ini_options] # to import relative paths pythonpath = [ "tests", ] ================================================ FILE: ssspy/__init__.py ================================================ try: from .io import wavread, wavwrite except ModuleNotFoundError: # to avoid module not found error during installation # e.g. numpy is not found in io.py pass try: from ._version import __version__ except ModuleNotFoundError: __version__ = "0.2.0" __all__ = ["__version__", "wavread", "wavwrite"] ================================================ FILE: ssspy/algorithm/__init__.py ================================================ from . 
import permutation_alignment from .minimal_distortion_principle import minimal_distortion_principle from .projection_back import projection_back __all__ = ["permutation_alignment", "minimal_distortion_principle", "projection_back"] PROJECTION_BACK_KEYWORDS = ["projection_back", "projection-back", "PB"] MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS = [ "minimal_distortion_principle", "minimal-distortion-principle", "MDP", ] ================================================ FILE: ssspy/algorithm/minimal_distortion_principle.py ================================================ from typing import Optional import numpy as np def minimal_distortion_principle( estimated: np.ndarray, reference: Optional[np.ndarray] = None, reference_id: Optional[int] = 0, ) -> np.ndarray: r"""Minimal distortion principle to restore scale ambiguity. The implementation is based on [#matsuoka2002minimal]_. Args: estimated (numpy.ndarray): Estimated spectrograms with shape of (n_channels, n_bins, n_frames). reference (numpy.ndarray, optional): Reference spectrogram with shape of (n_sources, n_bins, n_frames). reference_id (int, optional): Reference microphone index. Default: ``0``. Returns: numpy.ndarray of rescaled estimated spectrograms or demixing filters. .. [#matsuoka2002minimal] N. Murata, S. Ikeda, and A. Ziehe, "Minimal distortion principle for blind source separation," in *Proc. ICA*, 2001, pp. 722-727. 
""" Y = estimated X_conj = reference.conj() if reference_id is None: num = np.sum(Y * X_conj[:, np.newaxis, :, :], axis=-1, keepdims=True) else: num = np.sum(Y * X_conj[reference_id], axis=-1, keepdims=True) denom = np.sum(np.abs(Y) ** 2, axis=-1, keepdims=True) Z = num / denom output_scaled = Z.conj() * Y return output_scaled ================================================ FILE: ssspy/algorithm/permutation_alignment.py ================================================ import functools import itertools from typing import Callable, Optional import numpy as np from ..special.flooring import identity, max_flooring EPS = 1e-10 def correlation_based_permutation_solver( sequence: np.ndarray, *args, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), overwrite: bool = True, ) -> np.ndarray: r"""Solve permutation of estimated spectrograms. Group channels at each frequency bin according to correlations between frequencies [#murata2001approach]_. Args: sequence (numpy.ndarray): Array-like sequence of shape (n_bins, n_sources, n_frames). args (tuple of numpy.ndarray, optional): Positional arguments each of which is ``numpy.ndarray``. The shapes of each item should be (n_bins, n_sources, \*). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. overwrite (bool): Overwrite ``sequence`` and ``args`` if ``overwrite=True``. Default: ``True``. Returns: - If ``args`` is not given, ``numpy.ndarray`` of permutated separated spectrograms with shape of (n_sources, n_bins, n_frames) are returned. - If one positional argument is given, ``numpy.ndarray``s of permutated separated spectrograms and the permutated positional argument are returned. 
- If more than two positional arguments are given, ``numpy.ndarray``s of permutated separated spectrograms and the permutated positional arguments are returned. .. [#murata2001approach] N. Murata, S. Ikeda, and A. Ziehe, "An approach to blind source separation based on temporal structure of speech signals," in *Neurocomputing*, vol. 41, no. 1, pp. 1-24, 2001. .. note:: In this function, the shape of ``separated`` is expected ``(n_bins, n_sources, ...)``, which is different from other functions. """ assert sequence.ndim == 3, "Dimension of sequence is expected to be 3." for pos_idx, arg in enumerate(args): if arg.shape[:2] != sequence.shape[:2]: raise ValueError("The shape of {}th argument is invalid.".format(pos_idx + 1)) if overwrite: Y = sequence permutable = args else: Y = sequence.copy() permutable = [] for arg in args: permutable.append(arg.copy()) permutable = tuple(permutable) if flooring_fn is None: flooring_fn = identity else: flooring_fn = flooring_fn n_bins, n_sources, _ = Y.shape permutations = list(itertools.permutations(range(n_sources))) P = np.abs(Y) norm = np.sqrt(np.sum(P**2, axis=1, keepdims=True)) norm = flooring_fn(norm) P = P / norm correlation = np.sum(P @ P.transpose(0, 2, 1), axis=(1, 2)) indices = np.argsort(correlation) min_idx = indices[0] P_criteria = P[min_idx] for bin_idx in range(1, n_bins): min_idx = indices[bin_idx] P_max = None perm_max = None for perm in permutations: P_perm = np.sum(P_criteria * P[min_idx, perm, :]) if P_max is None or P_perm > P_max: P_max = P_perm perm_max = perm P_criteria = P_criteria + P[min_idx, perm_max, :] Y[min_idx, :] = Y[min_idx, perm_max] for idx in range(len(permutable)): permutable[idx][min_idx, :] = permutable[idx][min_idx, perm_max] if len(permutable) == 0: return Y elif len(permutable) == 1: return Y, permutable[0] else: return Y, permutable def score_based_permutation_solver( sequence: np.ndarray, *args, global_iter: int = 1, local_iter: int = 1, flooring_fn: Optional[Callable[[np.ndarray], 
np.ndarray]] = functools.partial( max_flooring, eps=EPS ), multi_centroids: bool = False, overwrite: bool = True, ) -> np.ndarray: r"""Align permutations between frequencies based on score value [#sawada2010underdetermined]_. Args: sequence (numpy.ndarray): Array-like sequence of shape (n_bins, n_sources, n_frames). args (tuple of numpy.ndarray, optional): Positional arguments each of which is ``numpy.ndarray``. The shapes of each item should be (n_bins, n_sources, \*). global_iter (int): Number of iterations in global optimization. local_iter (int): Number of iterations in local optimization. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. multi_centroids (bool): If ``multi_centroids=True``, multiple centroids are used in global optimization. However, this is not supported now. Default: ``False``. overwrite (bool): Overwrite ``sequence`` and ``args`` if ``overwrite=True``. Default: ``True``. .. [#sawada2010underdetermined] H. Sawada, S. Araki, and S. Makino, "Underdetermined convolutive blind source separation \ via frequency bin-wise clustering and permutation alignment," in *IEEE Trans. ASLP*, vol. 19, no. 3, pp. 516-527, 2010. """ assert sequence.ndim == 3, "Dimension of sequence is expected to be 3." assert not multi_centroids, "multi_centroids version is not supported." 
for pos_idx, arg in enumerate(args): if arg.shape[:2] != sequence.shape[:2]: raise ValueError("The shape of {}th argument is invalid.".format(pos_idx + 1)) if overwrite: permutable = args else: sequence = sequence.copy() permutable = [] for arg in args: permutable.append(arg.copy()) permutable = tuple(permutable) if flooring_fn is None: flooring_fn = identity else: flooring_fn = flooring_fn n_bins, n_sources = sequence.shape[:2] na = np.newaxis eye = np.eye(n_sources) permutations = np.array(list(itertools.permutations(range(n_sources)))) sequence_mean = sequence.mean(axis=-1, keepdims=True) sequence_std = sequence.std(axis=-1, keepdims=True) sequence_normalized = (sequence - sequence_mean) / sequence_std for _ in range(global_iter): centroid = sequence_normalized.mean(axis=0) centroid_std = centroid.std(axis=-1, keepdims=True) scores = [] for perm in permutations: num = np.mean(sequence_normalized[:, perm, na] * centroid[na, :], axis=-1) denom = flooring_fn(centroid_std) corr = num / denom score = np.sum(eye * corr - (1 - eye) * corr, axis=(1, 2)) scores.append(score) scores = np.stack(scores, axis=1) perm_max = np.argmax(scores, axis=1) perm_max = permutations[perm_max] sequence_normalized = _parallel_sort(sequence_normalized, perm_max) sequence = _parallel_sort(sequence, perm_max) for idx in range(len(permutable)): permutable[idx][:] = _parallel_sort(permutable[idx], perm_max) # local optimization for _ in range(local_iter): for bin_idx in range(n_bins): min_idx = max(0, bin_idx - 3) max_idx = min(n_bins - 1, bin_idx + 3) covariant_indices = set(range(min_idx, bin_idx)) | set(range(bin_idx + 1, max_idx + 1)) min_idx = max(0, bin_idx // 2 - 1) max_idx = min(n_bins - 1, bin_idx // 2 + 1) covariant_indices |= set(range(min_idx, max_idx + 1)) min_idx = max(0, 2 * bin_idx - 1) max_idx = min(n_bins - 1, 2 * bin_idx + 1) covariant_indices |= set(range(min_idx, max_idx + 1)) # deterministic covariant_indices = sorted(list(covariant_indices)) covariant_sequence = 
sequence_normalized[covariant_indices] scores = [] for perm in permutations: num = np.mean( sequence_normalized[bin_idx, perm, na] * covariant_sequence[:, na], axis=-1, ) denom = flooring_fn(centroid_std) corr = num / denom score = np.sum(eye * corr - (1 - eye) * corr, axis=(1, 2)) score = score.sum(axis=0) scores.append(score) scores = np.stack(scores, axis=0) perm_max = np.argmax(scores, axis=0) perm_max = permutations[perm_max] sequence_normalized[bin_idx] = sequence_normalized[bin_idx, perm_max] sequence[bin_idx] = sequence[bin_idx, perm_max] for idx in range(len(permutable)): permutable[idx][bin_idx] = permutable[idx][bin_idx, perm_max] if len(permutable) == 0: return sequence elif len(permutable) == 1: return sequence, permutable[0] else: return sequence, permutable def _parallel_sort(X: np.ndarray, indices: np.ndarray) -> np.ndarray: shape = X.shape idx = np.repeat(indices, repeats=np.prod(shape[2:]), axis=-1).reshape(shape) X = np.take_along_axis(X, idx, axis=1) return X ================================================ FILE: ssspy/algorithm/projection_back.py ================================================ from typing import Optional import numpy as np def projection_back( data_or_filter: np.ndarray, reference: Optional[np.ndarray] = None, reference_id: Optional[int] = 0, ) -> np.ndarray: r"""Projection back technique to restore scale ambiguity. The implementation is based on [#murata2001approach]_. Args: data_or_filter (numpy.ndarray): Estimated spectrograms or demixing filters. reference (numpy.ndarray, optional): Reference spectrogram. reference_id (int, optional): Reference microphone index. Default: ``0``. Returns: numpy.ndarray of rescaled estimated spectrograms or demixing filters. Examples: When you give estimated spectrograms, .. 
code-block:: python >>> import numpy as np >>> from ssspy.algorithm import projection_back >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> n_sources = n_channels >>> rng = np.random.default_rng(42) >>> spectrogram_mix = \ ... rng.standard_normal((n_channels, n_bins, n_frames)) \ ... + 1j * rng.standard_normal((n_channels, n_bins, n_frames)) >>> demix_filter = \ ... rng.standard_normal((n_sources, n_channels)) \ ... + 1j * rng.standard_normal((n_sources, n_channels)) >>> spectrogram_est = demix_filter @ spectrogram_mix.transpose(1, 0, 2) >>> # (n_bins, n_sources, n_frames) -> (n_sources, n_bins, n_frames) >>> spectrogram_est = spectrogram_est.transpose(1, 0, 2) >>> spectrogram_est_scaled = \ ... projection_back(spectrogram_est, reference=spectrogram_mix, reference_id=0) >>> spectrogram_est_scaled.shape (2, 2049, 128) When you give demixing filters, .. code-block:: python >>> import numpy as np >>> from ssspy.algorithm import projection_back >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> n_sources = n_channels >>> rng = np.random.default_rng(42) >>> spectrogram_mix = \ ... rng.standard_normal((n_channels, n_bins, n_frames)) \ ... + 1j * rng.standard_normal((n_channels, n_bins, n_frames)) >>> demix_filter = \ ... rng.standard_normal((n_sources, n_channels)) \ ... + 1j * rng.standard_normal((n_sources, n_channels)) >>> demix_filter_scaled = projection_back(demix_filter, reference_id=0) >>> spectrogram_est_scaled = demix_filter_scaled @ spectrogram_mix.transpose(1, 0, 2) >>> # (n_bins, n_sources, n_frames) -> (n_sources, n_bins, n_frames) >>> spectrogram_est_scaled = spectrogram_est_scaled.transpose(1, 0, 2) >>> spectrogram_est_scaled.shape (2, 2049, 128) .. [#murata2001approach] N. Murata, S. Ikeda, and A. Ziehe, "An approach to blind source separation based on temporal structure of speech signals," *Neurocomputing*, vol. 41, no. 1-4, pp. 1-24, 2001. 
""" if reference is None: W = data_or_filter # (*, n_sources, n_channels) scale = np.linalg.inv(W) # (*, n_channels, n_sources) if reference_id is None: scale = scale[..., np.newaxis] # (*, n_channels, n_sources, 1) scale = np.rollaxis(scale, -3, 0) # (n_channels, *, n_sources, 1) demix_filter_scaled = W * scale # (n_channels, *, n_sources, n_channels) else: scale = scale[..., reference_id, :] # (*, n_sources) demix_filter_scaled = W * scale[..., np.newaxis] # (*, n_sources, n_channels) return demix_filter_scaled else: Y = data_or_filter # (n_sources, n_bins, n_frames) X = reference # (n_channels, n_bins, n_frames) Y = Y.transpose(1, 0, 2) # (n_bins, n_sources, n_frames) X = X.transpose(1, 0, 2) # (n_bins, n_channels, n_frames) Y_Hermite = Y.transpose(0, 2, 1).conj() # (n_bins, n_frames, n_sources) XY_Hermite = X @ Y_Hermite # (n_bins, n_channels, n_sources) YY_Hermite = Y @ Y_Hermite # (n_bins, n_sources, n_sources) scale = XY_Hermite @ np.linalg.inv(YY_Hermite) # (n_bins, n_channels, n_sources) if reference_id is None: scale = scale.transpose(1, 0, 2) # (n_channels, n_bins, n_sources) Y_scaled = Y * scale[..., np.newaxis] # (n_channels, n_bins, n_sources, n_frames) output_scaled = Y_scaled.swapaxes(-3, -2) # (n_channels, n_sources, n_bins, n_frames) else: scale = scale[..., reference_id, :] # (n_bins, n_sources) Y_scaled = Y * scale[..., np.newaxis] # (n_bins, n_sources, n_frames) output_scaled = Y_scaled.swapaxes(-3, -2) # (n_sources, n_bins, n_frames) return output_scaled ================================================ FILE: ssspy/bss/__init__.py ================================================ from . 
import fdica, ica, ilrma, iva, mnmf __all__ = ["ica", "fdica", "iva", "ilrma", "mnmf"] ================================================ FILE: ssspy/bss/_flooring.py ================================================ import warnings import numpy as np EPS = 1e-10 def identity(input: np.ndarray) -> np.ndarray: r"""Identity function.""" warnings.warn("Use ssspy.special.identity instead.", FutureWarning) return input def max_flooring(input: np.ndarray, eps: float = EPS) -> np.ndarray: r"""Max flooring operation.""" warnings.warn("Use ssspy.special.max_flooring instead.", FutureWarning) return np.maximum(input, eps) def add_flooring(input: np.ndarray, eps: float = EPS) -> np.ndarray: r"""Add flooring operation.""" warnings.warn("Use ssspy.special.add_flooring instead.", FutureWarning) return input + eps ================================================ FILE: ssspy/bss/_psd.py ================================================ import functools import warnings from typing import Callable, Optional import numpy as np from ..special.flooring import max_flooring from ..special.psd import to_psd as _to_psd EPS = 1e-10 def to_psd( X: np.ndarray, axis1: int = -2, axis2: int = -1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), ) -> np.ndarray: r"""Ensure matrix to be positive semidefinite. Args: X (np.ndarray): A complex Hermitian matrix. axis1 (int): Axis to be used as first axis of 2D sub-arrays. axis2 (int): Axis to be used as second axis of 2D sub-arrays. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. Returns: Positive semidefinite matrix. 
""" warnings.warn("Use ssspy.special.to_psd instead.", FutureWarning) return _to_psd(X, axis1=axis1, axis2=axis2, flooring_fn=flooring_fn) ================================================ FILE: ssspy/bss/_select_pair.py ================================================ import warnings from typing import Iterable, Optional, Tuple from ..utils.select_pair import combination_pair_selector as combination_pair_selector_base from ..utils.select_pair import sequential_pair_selector as sequential_pair_selector_base def sequential_pair_selector( n_sources: int, stop: Optional[int] = None, step: int = 1, sort: bool = False ) -> Iterable[Tuple[int, int]]: r"""Select pair in pairwise update. Args: n_sources (int): Number of sources. step (int): This parameter determines step size. For instance, if ``sequential_pair_selector(n_sources=6, step=2, sort=False)``, this function yields ``0, 1``, ``2, 3``, ``4, 5``, ``0, 1``, ``2, 3``, ``4, 5``. Default: ``1``. sort (bool): Sort pair to ensure :math:`m>> for m, n in combination_pair_selector(4): ... print(m, n) 0 1 1 2 2 3 3 0 """ warnings.warn("Use ssspy.utils.select_pair.sequential_pair_selector instead.", UserWarning) yield from sequential_pair_selector_base(n_sources, stop=stop, step=step, sort=sort) def combination_pair_selector(n_sources: int, sort: bool = False) -> Iterable[Tuple[int, int]]: r"""Select pair in pairwise update. Args: n_sources (int): Number of sources. sort (bool): Sort pair to ensure :math:`m>> for m, n in combination_pair_selector(4): ... 
print(m, n) 0 1 0 2 0 3 1 2 1 3 2 3 """ warnings.warn("Use ssspy.utils.select_pair.combination_pair_selector instead.", UserWarning) yield from combination_pair_selector_base(n_sources, sort=sort) ================================================ FILE: ssspy/bss/_solve_permutation.py ================================================ import functools import warnings from typing import Callable, Optional import numpy as np from ..algorithm.permutation_alignment import ( correlation_based_permutation_solver as correlation_based_permutation_solver_base, ) from ..special.flooring import max_flooring EPS = 1e-10 def correlation_based_permutation_solver( separated: np.ndarray, *args, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), overwrite: bool = True, ) -> np.ndarray: r"""Solve permutaion of estimated spectrograms.""" warnings.warn( "Use ssspy.algorithm.permutation_alignment.correlation_based_permutation_solver instead.", UserWarning, ) return correlation_based_permutation_solver_base( separated, *args, flooring_fn=flooring_fn, overwrite=overwrite ) ================================================ FILE: ssspy/bss/_update_spatial_model.py ================================================ import functools from typing import Callable, Iterable, Optional, Tuple import numpy as np from ..linalg._solve import solve from ..linalg.eigh import eigh2 from ..linalg.inv import inv2 from ..linalg.lqpqm import lqpqm2 from ..special.flooring import identity, max_flooring from ..special.psd import to_psd from ..utils.select_pair import sequential_pair_selector EPS = 1e-10 def update_by_ip1( demix_filter: np.ndarray, weighted_covariance: np.ndarray, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), overwrite: bool = True, ) -> np.ndarray: r"""Update demixing filters by iterative projection. Args: demix_filter (numpy.ndarray): Demixing filters to be updated. 
The shape is (n_bins, n_sources, n_channels). weighted_covariance (numpy.ndarray): Weighted covariance matrix. The shape is (n_bins, n_sources, n_channels, n_channels). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. overwrite (bool): Overwrite ``demix_filter`` if ``overwrite=True``. Default: ``True``. Returns: numpy.ndarray of updated demixing filters. The shape is (n_bins, n_sources, n_channels). """ if flooring_fn is None: flooring_fn = identity if overwrite: W = demix_filter else: W = demix_filter.copy() U = weighted_covariance n_bins, n_sources, n_channels = W.shape E = np.eye(n_sources, n_channels) # (n_sources, n_channels) E = np.tile(E, reps=(n_bins, 1, 1)) # (n_bins, n_sources, n_channels) for src_idx in range(n_sources): w_n_Hermite = W[:, src_idx, :] # (n_bins, n_channels) U_n = U[:, src_idx, :, :] e_n = E[:, src_idx, :] # (n_bins, n_n_channels) WU = W @ U_n w_n = solve(WU, e_n) # (n_bins, n_channels) wUw = w_n[:, np.newaxis, :].conj() @ U_n @ w_n[:, :, np.newaxis] wUw = np.real(wUw[..., 0]) wUw = np.maximum(wUw, 0) denom = np.sqrt(wUw) denom = flooring_fn(denom) w_n_Hermite = w_n.conj() / denom W[:, src_idx, :] = w_n_Hermite return W def update_by_ip2( demix_filter: np.ndarray, weighted_covariance: np.ndarray, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, overwrite: bool = True, ) -> np.ndarray: r"""Update demixing filters by pairwise iterative projection [#ono2018fast]_. Args: demix_filter (numpy.ndarray): Demixing filters to be updated. The shape is (n_bins, n_sources, n_channels). weighted_covariance (numpy.ndarray): Weighted covariance matrix. 
The shape is (n_bins, n_sources, n_channels, n_channels). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. pair_selector (callable, optional): Selector to choose updaing pair. If ``None`` is given, ``sequential_pair_selector`` is used. Default: ``None``. overwrite (bool): Overwrite ``demix_filter`` if ``overwrite=True``. Default: ``True``. Returns: numpy.ndarray of updated demixing filters. The shape is (n_bins, n_sources, n_channels). .. [#ono2018fast] N. Ono, \ "Fast algorithm for independent component/vector/low-rank matrix analysis \ with three or more sources," \ in *Proc. ASJ Spring meeting*, 2018 (in Japanese). """ if flooring_fn is None: flooring_fn = identity if pair_selector is None: pair_selector = sequential_pair_selector if overwrite: W = demix_filter else: W = demix_filter.copy() U = weighted_covariance _, n_sources, _ = W.shape for m, n in pair_selector(n_sources): pair = (m, n) W[:, pair, :] = update_by_ip2_one_pair( W, U[:, pair, :, :], pair=pair, flooring_fn=flooring_fn ) return W def update_by_iss1( separated: np.ndarray, weight: np.ndarray, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), ) -> np.ndarray: r"""Update estimated spectrogram by iterative source steering. Args: separated (numpy.ndarray): Estimated spectrograms to be updated. The shape is (n_sources, n_bins, n_frames). weight (numpy.ndarray): Weights for estimated spectrogram. The shape is (n_sources, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. 
Default: ``functools.partial(max_flooring, eps=1e-10)``. Returns: numpy.ndarray of updated spectrograms. The shape is (n_sources, n_bins, n_frames). """ if flooring_fn is None: flooring_fn = identity Y = separated varphi = weight n_sources = Y.shape[0] for src_idx in range(n_sources): Y_n = Y[src_idx] # (n_bins, n_frames) YY_n_conj = Y * Y_n.conj() YY_n = np.abs(Y_n) ** 2 num = np.mean(varphi * YY_n_conj, axis=-1) denom = np.mean(varphi * YY_n, axis=-1) denom = flooring_fn(denom) v_n = num / denom v_n[src_idx] = 1 - 1 / np.sqrt(denom[src_idx]) Y = Y - v_n[:, :, np.newaxis] * Y_n return Y def update_by_iss2( separated: np.ndarray, weight: np.ndarray, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, ) -> np.ndarray: r"""Update estimated spectrogram by pairwise iterative source steering. Args: separated (numpy.ndarray): Estimated spectrograms to be updated. The shape is (n_sources, n_bins, n_frames). weight (numpy.ndarray): Weights for estimated spectrogram. The shape is (n_sources, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. pair_selector (callable, optional): Selector to choose updaing pair. If ``None`` is given, ``sequential_pair_selector`` is used. Default: ``None``. Returns: numpy.ndarray of updated spectrograms. The shape is (n_sources, n_bins, n_frames). 
""" Y = separated varphi = weight n_sources = Y.shape[0] if flooring_fn is None: flooring_fn = identity if pair_selector is None: pair_selector = functools.partial(sequential_pair_selector, stop=n_sources, step=2) for m, n in pair_selector(n_sources): if m < 0: m = n_sources + m if n < 0: n = n_sources + n if m > n: ascend = False m, n = n, m else: ascend = True # Split into main and sub Y_1, Y_m, Y_2, Y_n, Y_3 = np.split(Y, [m, m + 1, n, n + 1], axis=0) Y_sub = np.concatenate([Y_1, Y_2, Y_3], axis=0) # (n_sources - 2, n_bins, n_frames) varphi_1, varphi_m, varphi_2, varphi_n, varphi_3 = np.split( varphi, [m, m + 1, n, n + 1], axis=0 ) varphi_sub = np.concatenate([varphi_1, varphi_2, varphi_3], axis=0) if ascend: Y_main = np.concatenate([Y_m, Y_n], axis=0) # (2, n_bins, n_frames) varphi_main = np.concatenate([varphi_m, varphi_n], axis=0) else: Y_main = np.concatenate([Y_n, Y_m], axis=0) # (2, n_bins, n_frames) varphi_main = np.concatenate([varphi_n, varphi_m], axis=0) YY_main = Y_main[:, np.newaxis, :, :] * Y_main[np.newaxis, :, :, :].conj() YY_sub = Y_main[:, np.newaxis, :, :] * Y_sub[np.newaxis, :, :, :].conj() YY_main = YY_main.transpose(2, 0, 1, 3) YY_sub = YY_sub.transpose(1, 2, 0, 3) Y_main = Y_main.transpose(1, 0, 2) # Sub G_sub = np.mean( varphi_sub[:, :, np.newaxis, np.newaxis, :] * YY_main[np.newaxis, :, :, :, :], axis=-1, ) F = np.mean(varphi_sub[:, :, np.newaxis, :] * YY_sub, axis=-1) Q = -inv2(G_sub) @ F[:, :, :, np.newaxis] Q = Q.squeeze(axis=-1) Q = Q.transpose(1, 0, 2) QY = Q.conj() @ Y_main Y_sub = Y_sub + QY.transpose(1, 0, 2) # Main G_main = np.mean( varphi_main[:, :, np.newaxis, np.newaxis, :] * YY_main[np.newaxis, :, :, :, :], axis=-1, ) G_m, G_n = G_main _, H_mn = eigh2(G_m, G_n) h_mn = H_mn.transpose(2, 0, 1) hGh_mn = h_mn[:, :, np.newaxis, :].conj() @ G_main @ h_mn[:, :, :, np.newaxis] hGh_mn = np.squeeze(hGh_mn, axis=-1) hGh_mn = np.real(hGh_mn) hGh_mn = np.maximum(hGh_mn, 0) denom_mn = np.sqrt(hGh_mn) denom_mn = flooring_fn(denom_mn) P = 
h_mn / denom_mn P = P.transpose(1, 0, 2) Y_main = P.conj() @ Y_main Y_main = Y_main.transpose(1, 0, 2) # Concat Y_m, Y_n = np.split(Y_main, [1], axis=0) Y1, Y2, Y3 = np.split(Y_sub, [m, n - 1], axis=0) if ascend: Y = np.concatenate([Y1, Y_m, Y2, Y_n, Y3], axis=0) else: Y = np.concatenate([Y1, Y_n, Y2, Y_m, Y3], axis=0) return Y def update_by_ip2_one_pair( demix_filter: np.ndarray, weighted_covariance_pair: np.ndarray, pair: Tuple[int], flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), ) -> np.ndarray: r"""Update demixing filters by pairwise iterative projection. Args: demix_filter (numpy.ndarray): Demixing filters. The shape is (n_bins, n_sources, n_channels). weighted_covariance_pair (numpy.ndarray): Weighted covariance matrix. The shape is (n_bins, 2, n_channels, n_channels). pair (tuple): Pair of source index to be updated. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. Returns: numpy.ndarray of updated demixing filter pair. The shape is (n_bins, 2, n_channels). 
""" if flooring_fn is None: flooring_fn = identity m, n = pair W = demix_filter U_m, U_n = weighted_covariance_pair.transpose(1, 0, 2, 3) n_bins, n_sources, n_channels = W.shape E = np.eye(n_channels, n_sources) E_mn = E[:, (m, n)] E_mn = np.tile(E_mn, reps=(n_bins, 1, 1)) WU_m = W @ U_m WU_n = W @ U_n P_m = solve(WU_m, E_mn) P_n = solve(WU_n, E_mn) PUP_m = P_m.transpose(0, 2, 1).conj() @ U_m @ P_m PUP_n = P_n.transpose(0, 2, 1).conj() @ U_n @ P_n _, H_mn = eigh2(PUP_m, PUP_n) H_mn = H_mn[..., ::-1] H_mn = H_mn.transpose(2, 0, 1) h_m, h_n = H_mn hUh_m = h_m[:, np.newaxis, :].conj() @ PUP_m @ h_m[:, :, np.newaxis] hUh_m = np.real(hUh_m[..., 0]) hUh_m = np.maximum(hUh_m, 0) denom = np.sqrt(hUh_m) denom = flooring_fn(denom) h_m = h_m / denom hUh_n = h_n[:, np.newaxis, :].conj() @ PUP_n @ h_n[:, :, np.newaxis] hUh_n = np.real(hUh_n[..., 0]) hUh_n = np.maximum(hUh_n, 0) denom = np.sqrt(hUh_n) denom = flooring_fn(denom) h_n = h_n / denom w_m = P_m @ h_m[..., np.newaxis] w_n = P_n @ h_n[..., np.newaxis] W_mn_conj = np.concatenate([w_m, w_n], axis=-1) W_mn = W_mn_conj.transpose(0, 2, 1).conj() return W_mn def update_by_ipa( separated: np.ndarray, weight: np.ndarray, normalization: bool = True, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), max_iter: int = 1, ) -> np.ndarray: r"""Update estimated spectrogram by iterative projection with adjustment (IPA). Args: separated (numpy.ndarray): Estimated spectrograms to be updated. The shape is (n_sources, n_bins, n_frames). weight (numpy.ndarray): Weights for estimated spectrogram. The shape is (n_sources, n_bins, n_frames). normalization (bool): If ``normalization=True``, normalization is applied to LQPQM. Default: ``True``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. 
Default: ``functools.partial(max_flooring, eps=1e-10)``. max_iter (int): Maximum number of Newton-Raphson method. Default: ``1``. Returns: numpy.ndarray of estimated spectrograms of shape (n_sources, n_bins, n_frames). """ if flooring_fn is None: flooring_fn = identity Y = separated varphi = weight n_sources = Y.shape[0] E = np.eye(n_sources) for source_idx in range(n_sources): YY_conj = Y[:, np.newaxis] * Y[np.newaxis, :].conj() U_tilde = np.mean(varphi[:, np.newaxis, np.newaxis] * YY_conj, axis=-1) U_tilde = U_tilde.transpose(3, 0, 1, 2) U_tilde = to_psd(U_tilde, axis1=-2, axis2=-1, flooring_fn=flooring_fn) E_n_left, e_n, E_n_right = np.split(E, [source_idx, source_idx + 1], axis=-1) E_n = np.concatenate([E_n_left, E_n_right], axis=-1) U_tilde_n = U_tilde[:, source_idx, :, :] U_tilde_n_inverse = _psd_inv(U_tilde_n, flooring_fn=flooring_fn) a_n = U_tilde[:, :, source_idx, source_idx] a_n = np.real(a_n) a_n = a_n @ E_n b_n = np.diagonal(U_tilde[:, :, source_idx, :], axis1=-2, axis2=-1) b_n = b_n @ E_n d_n = E_n.transpose(1, 0) @ U_tilde_n_inverse.conj() C_n = d_n @ E_n d_n = d_n[:, :, source_idx] Cd_n = solve(C_n, d_n) dCd_n = np.sum(d_n.conj() * Cd_n, axis=-1) dCd_n = np.real(dCd_n) eUe_n = U_tilde_n_inverse[:, source_idx, source_idx] eUe_n = np.real(eUe_n) z_n = eUe_n - dCd_n a_sqrt_n = np.sqrt(a_n) aa_n = a_sqrt_n[:, :, np.newaxis] * a_sqrt_n[:, np.newaxis, :] H_n = C_n / aa_n v_n = -b_n / a_sqrt_n - a_sqrt_n * Cd_n if normalization: trace = np.trace(H_n, axis1=-2, axis2=-1) trace = np.real(trace) H_n = H_n / trace[..., np.newaxis, np.newaxis] z_n = z_n / trace q_check_n = lqpqm2( H_n, v_n, z_n, flooring_fn=flooring_fn, singular_fn=lambda x: x < flooring_fn(0), max_iter=max_iter, ) q_n = q_check_n / a_sqrt_n - b_n / a_n Eq_n = q_n.conj() @ E_n.transpose(1, 0) q_tilde_n = e_n.transpose(1, 0) - Eq_n Uq_n = solve(U_tilde_n, q_tilde_n) qUq_n = np.sum(q_tilde_n.conj() * Uq_n, axis=-1, keepdims=True) qUq_n = np.real(qUq_n) qUq_n = np.maximum(qUq_n, 0) denom = 
np.sqrt(qUq_n) denom = flooring_fn(denom) p_n = Uq_n / denom Y_n = Y[source_idx] p_n_conj = p_n.transpose(1, 0).conj() PY_n = np.sum(p_n_conj[..., np.newaxis] * Y, axis=0) PY_n = e_n[:, np.newaxis] * (PY_n - Y_n) Eq_n = Eq_n.transpose(1, 0) QY_n = Eq_n[:, :, np.newaxis] * Y_n Y = Y + PY_n + QY_n return Y def update_by_block_decomposition_vcd( demix_filter: np.ndarray, weighted_covariance: np.ndarray, singular_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None, overwrite: bool = True, ) -> np.ndarray: r""" Args: demix_filter (numpy.ndarray): Demixing filters to be updated. The shape is (n_blocks, n_neighbors, n_sources, n_channels). weighted_covariance (numpy.ndarray): Weighted covariance matrix. The shape is (n_blocks, n_neighbors, n_neighbors, n_sources, n_channels, n_channels). singular_fn (callable, optional): A flooring function to return singular condition. This function is expected to return the same shape bool tensor as the input. If ``singular_fn=None``,``lambda x: x == 0`` is used. overwrite (bool): Overwrite ``demix_filter`` if ``overwrite=True``. Default: ``True``. Returns: numpy.ndarray of updated demixing filters. The shape is (n_blocks, n_neighbors, n_sources, n_channels). 
""" na = np.newaxis if singular_fn is None: def _is_zero(x: np.ndarray) -> np.ndarray: return x == 0 singular_fn = _is_zero if overwrite: W = demix_filter else: W = demix_filter.copy() RXX = weighted_covariance U = np.diagonal(RXX, axis1=1, axis2=2) n_blocks, n_neighbors, n_sources, n_channels = W.shape E_i = np.eye(n_neighbors) E_n = np.eye(n_sources) E_n = np.tile(E_n, reps=(n_blocks, 1, 1)) for neighbor_idx in range(n_neighbors): pad_mask_i = 1 - E_i[neighbor_idx] U_i = U[:, :, :, :, neighbor_idx] RXX_i = RXX[:, neighbor_idx] for source_idx in range(n_sources): e_n = E_n[:, source_idx, :] U_in = U_i[:, source_idx, :, :] RXX_in = RXX_i[:, :, source_idx] w_n_conj = W[:, :, source_idx, :].conj() RXY_in = RXX_in @ w_n_conj[:, :, :, na] gamma_in = np.sum(pad_mask_i[:, na] * RXY_in[..., 0], axis=1) WU_in = W[:, neighbor_idx, :, :] @ U_in eta_in = solve(WU_in, e_n) eta_hat_in = solve(U_in, gamma_in) eta_U_in = eta_in[:, na, :].conj() @ U_in xi_in = eta_U_in @ eta_in[:, :, na] xi_hat_in = eta_U_in @ eta_hat_in[:, :, na] xi_in = np.real(xi_in[..., 0]) xi_in = np.maximum(xi_in, 0) xi_hat_in = xi_hat_in[..., 0] singular_condition = singular_fn(xi_hat_in) # to avoid zero division, but these will be ignored. xi_hat_in[singular_condition] = 1 coeff = (xi_hat_in / (2 * xi_in)) * ( 1 - np.sqrt(1 + 4 * xi_in / (np.abs(xi_hat_in) ** 2)) ) coeff_singular = 1 / np.sqrt(xi_in) coeff = np.where(singular_condition, coeff_singular, coeff) w_in = coeff * eta_in - eta_hat_in W[:, neighbor_idx, source_idx, :] = w_in.conj() return W def _psd_inv( X: np.ndarray, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), ) -> np.ndarray: """Compute inversion of positive semidefinite matrix. Args: X (np.ndarray): Positive semidefinite matrix of shape (*, N, N). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. 
If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. Returns: np.ndarray: Inversion of input matrix. """ if flooring_fn is None: flooring_fn = identity Lamb, P = np.linalg.eigh(X) P_Hermite = P.swapaxes(-2, -1) if np.iscomplexobj(X): P_Hermite = P_Hermite.conj() Lamb_inv = 1 / flooring_fn(Lamb) Lamb_inv = Lamb_inv[..., np.newaxis] * np.eye(Lamb.shape[-1]) return P @ Lamb_inv @ P_Hermite ================================================ FILE: ssspy/bss/admmbss.py ================================================ import warnings from typing import Callable, List, Optional, Union import numpy as np from ..linalg import prox from ..linalg._solve import solve from .proxbss import ProxBSSBase EPS = 1e-10 __all__ = ["ADMMBSS", "MaskingADMMBSS"] class ADMMBSSBase(ProxBSSBase): """Base class of blind source separation via alternative direction method of multiplier. Args: penalty_fn (callable): Penalty function that determines source model. prox_penalty (callable): Proximal operator of penalty function. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. 
""" def __repr__(self) -> str: s = "ADMMBSS(" s += "n_penalties={n_penalties}".format(n_penalties=self.n_penalties) s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) class ADMMBSS(ADMMBSSBase): """Base class of blind source separation via alternative direction method of multiplier. Args: rho (float): Penalty parameter. Default: ``1``. alpha (float): Relaxation parameter (deprecated). Set ``relaxation`` instead. relaxation (float): Relaxation parameter. Default: ``1``. penalty_fn (callable): Penalty function that determines source model. prox_penalty (callable): Proximal operator of penalty function. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. 
""" def __init__( self, rho: float = 1, alpha: float = None, relaxation: float = 1, penalty_fn: Callable[[np.ndarray, np.ndarray], float] = None, prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None, callbacks: Optional[ Union[Callable[["ADMMBSS"], None], List[Callable[["ADMMBSS"], None]]] ] = None, scale_restoration: bool = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( penalty_fn=penalty_fn, prox_penalty=prox_penalty, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.rho = rho if alpha is None: self.relaxation = relaxation else: assert relaxation == 1, "You cannot specify relaxation and alpha simultaneously." warnings.warn("alpha is deprecated. Set relaxation instead.", DeprecationWarning) self.relaxation = alpha def __call__(self, input, n_iter=100, initial_call: bool = True, **kwargs) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): Mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): Number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) # Call __call__ of ADMMBSSBase's parent, i.e. 
__call__ of IterativeMethodBase super(ADMMBSSBase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "ADMMBSS(" s += "rho={rho}" s += ", relaxation={relaxation}" s += ", n_penalties={n_penalties}".format(n_penalties=self.n_penalties) s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of ADMMBSS. """ if "aux1" in kwargs.keys(): warnings.warn("aux1 is deprecated. Use auxiliary1 instead.", DeprecationWarning) kwargs["auxiliary1"] = kwargs.pop("aux1") if "aux2" in kwargs.keys(): warnings.warn("aux2 is deprecated. Use auxiliary2 instead.", DeprecationWarning) kwargs["auxiliary2"] = kwargs.pop("aux2") super()._reset(**kwargs) n_penalties = self.n_penalties n_sources, n_channels = self.n_sources, self.n_channels n_bins, n_frames = self.n_bins, self.n_frames if not hasattr(self, "auxiliary1"): auxiliary1 = np.zeros((n_bins, n_sources, n_channels), dtype=np.complex128) else: # To avoid overwriting ``auxiliary1`` given by keyword arguments. auxiliary1 = self.auxiliary1.copy() if not hasattr(self, "auxiliary2"): auxiliary2 = np.zeros((n_penalties, n_sources, n_bins, n_frames), dtype=np.complex128) else: # To avoid overwriting ``auxiliary2`` given by keyword arguments. auxiliary2 = self.auxiliary2.copy() if not hasattr(self, "dual1"): dual1 = np.zeros((n_bins, n_sources, n_channels), dtype=np.complex128) else: # To avoid overwriting ``dual1`` given by keyword arguments. 
dual1 = self.dual1.copy() if not hasattr(self, "dual2"): dual2 = np.zeros((n_penalties, n_sources, n_bins, n_frames), dtype=np.complex128) else: # To avoid overwriting ``dual2`` given by keyword arguments. dual2 = self.dual2.copy() self.auxiliary1 = auxiliary1 self.auxiliary2 = auxiliary2 self.dual1 = dual1 self.dual2 = dual2 def update_once(self) -> None: r"""Update demixing filters, auxiliary parameters, and dual parameters once.""" n_penalties = self.n_penalties n_channels = self.n_channels rho, alpha = self.rho, self.relaxation V, V_tilde = self.auxiliary1, self.auxiliary2 Y, Y_tilde = self.dual1, self.dual2 X, W = self.input, self.demix_filter XX = X.transpose(1, 0, 2).conj() @ X.transpose(1, 2, 0) E = np.eye(n_channels) VY = V - Y VY_tilde = np.sum(V_tilde - Y_tilde, axis=0) XVY_tilde = X.transpose(1, 0, 2).conj() @ VY_tilde.transpose(1, 2, 0) W = solve(n_penalties * XX + E, VY + XVY_tilde.transpose(0, 2, 1)) XW = self.separate(X, demix_filter=W) U = alpha * W + (1 - alpha) * V U_tilde = alpha * XW + (1 - alpha) * V_tilde V = prox.neg_logdet(U + Y, step_size=1 / rho) V_tilde = [] for U_tilde_q, Y_tilde_q, prox_penalty in zip(U_tilde, Y_tilde, self.prox_penalty): V_tilde_q = prox_penalty(U_tilde_q + Y_tilde_q, step_size=1 / rho) V_tilde.append(V_tilde_q) V_tilde = np.stack(V_tilde, axis=0) Y = Y + U - V Y_tilde = Y_tilde + U_tilde - V_tilde self.auxiliary1, self.auxiliary2 = V, V_tilde self.dual1, self.dual2 = Y, Y_tilde self.demix_filter = W class MaskingADMMBSS(ADMMBSSBase): """Blind source separation via alternative direction method of multiplier with masking function. Args: rho (float): Penalty parameter. Default: ``1``. alpha (float): Relaxation parameter (deprecated). Set ``relaxation`` instead. relaxation (float): Relaxation parameter. Default: ``1``. penalty_fn (callable): Penalty function that determines source model. mask_fn (callable): Proximal operator of penalty function. Default: ``None``. 
callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool, optional): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``None``. reference_id (int): Reference channel for projection back. Default: ``0``. """ def __init__( self, rho: float = 1, alpha: float = None, relaxation: float = 1, penalty_fn: Callable[[np.ndarray, np.ndarray], float] = None, mask_fn: Callable[[np.ndarray], float] = None, callbacks: Optional[ Union[Callable[["MaskingADMMBSS"], None], List[Callable[["MaskingADMMBSS"], None]]] ] = None, scale_restoration: bool = True, record_loss: Optional[bool] = None, reference_id: int = 0, ) -> None: super(ProxBSSBase, self).__init__( callbacks=callbacks, record_loss=record_loss, ) if penalty_fn is None: # Since penalty_fn is not necessarily written in closed form, # None is acceptable. if record_loss is None: record_loss = False assert not record_loss, "To record loss, set penalty_fn." else: assert callable(penalty_fn), "penalty_fn should be callable." if record_loss is None: record_loss = True if mask_fn is None: raise ValueError("Specify masking function.") else: assert callable(mask_fn), "mask_fn should be callable." self.penalty_fn = penalty_fn self.mask_fn = mask_fn self.input = None self.scale_restoration = scale_restoration if reference_id is None and scale_restoration: raise ValueError("Specify 'reference_id' if scale_restoration=True.") else: self.reference_id = reference_id self.rho = rho if alpha is None: self.relaxation = relaxation else: assert relaxation == 1, "You cannot specify relaxation and alpha simultaneously." warnings.warn("alpha is deprecated. 
Set relaxation instead.", DeprecationWarning) self.relaxation = alpha def __call__(self, input, n_iter=100, initial_call: bool = True, **kwargs) -> np.ndarray: self.input = input.copy() self._reset(**kwargs) # Call __call__ of ADMMBSSBase's parent, i.e. __call__ of IterativeMethodBase super(ADMMBSSBase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of ADMMBSS. """ if "aux1" in kwargs.keys(): warnings.warn("aux1 is deprecated. Use auxiliary1 instead.", DeprecationWarning) kwargs["auxiliary1"] = kwargs.pop("aux1") if "aux2" in kwargs.keys(): warnings.warn("aux2 is deprecated. Use auxiliary2 instead.", DeprecationWarning) kwargs["auxiliary2"] = kwargs.pop("aux2") super()._reset(**kwargs) assert self.n_penalties == 1, "Number of penalty function should be one." n_sources, n_channels = self.n_sources, self.n_channels n_bins, n_frames = self.n_bins, self.n_frames if not hasattr(self, "auxiliary1"): auxiliary1 = np.zeros((n_bins, n_sources, n_channels), dtype=np.complex128) else: # To avoid overwriting ``auxiliary1`` given by keyword arguments. auxiliary1 = self.auxiliary1.copy() if not hasattr(self, "auxiliary2"): auxiliary2 = np.zeros((n_sources, n_bins, n_frames), dtype=np.complex128) else: # To avoid overwriting ``auxiliary2`` given by keyword arguments. auxiliary2 = self.auxiliary2.copy() if not hasattr(self, "dual1"): dual1 = np.zeros((n_bins, n_sources, n_channels), dtype=np.complex128) else: # To avoid overwriting ``dual1`` given by keyword arguments. dual1 = self.dual1.copy() if not hasattr(self, "dual2"): dual2 = np.zeros((n_sources, n_bins, n_frames), dtype=np.complex128) else: # To avoid overwriting ``dual2`` given by keyword arguments. 
dual2 = self.dual2.copy() self.auxiliary1 = auxiliary1 self.auxiliary2 = auxiliary2 self.dual1 = dual1 self.dual2 = dual2 @property def n_penalties(self) -> int: r"""Return number of penalty terms.""" return 1 def update_once(self) -> None: r"""Update demixing filters, auxiliary parameters, and dual parameters once.""" n_channels = self.n_channels rho, alpha = self.rho, self.relaxation V, V_tilde = self.auxiliary1, self.auxiliary2 Y, Y_tilde = self.dual1, self.dual2 X, W = self.input, self.demix_filter XX = X.transpose(1, 0, 2).conj() @ X.transpose(1, 2, 0) E = np.eye(n_channels) VY = V - Y VY_tilde = V_tilde - Y_tilde XVY_tilde = X.transpose(1, 0, 2).conj() @ VY_tilde.transpose(1, 2, 0) W = solve(XX + E, VY + XVY_tilde.transpose(0, 2, 1)) XW = self.separate(X, demix_filter=W) U = alpha * W + (1 - alpha) * V U_tilde = alpha * XW + (1 - alpha) * V_tilde V = prox.neg_logdet(U + Y, step_size=1 / rho) V_tilde = self.mask_fn(U_tilde + Y_tilde) * (U_tilde + Y_tilde) Y = Y + U - V Y_tilde = Y_tilde + U_tilde - V_tilde self.auxiliary1, self.auxiliary2 = V, V_tilde self.dual1, self.dual2 = Y, Y_tilde self.demix_filter = W ================================================ FILE: ssspy/bss/base.py ================================================ from typing import Callable, List, Optional, Union import numpy as np __all__ = [ "IterativeMethodBase", ] class IterativeMethodBase: r"""Base class of iterative method. This class provides prototype of iterative updates. Args: callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. 
""" def __init__( self, callbacks: Optional[ Union[ Callable[["IterativeMethodBase"], None], List[Callable[["IterativeMethodBase"], None]], ] ] = None, record_loss: bool = True, ) -> None: if callbacks is not None: if callable(callbacks): callbacks = [callbacks] self.callbacks = callbacks else: self.callbacks = None self.record_loss = record_loss if self.record_loss: self.loss = [] else: self.loss = None def __call__(self, *args, n_iter: int = 100, initial_call: bool = True, **kwargs) -> np.ndarray: r"""Iteratively call ``update_once``. Args: n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. """ if initial_call: if self.record_loss: loss = self.compute_loss() self.loss.append(loss) if self.callbacks is not None: for callback in self.callbacks: callback(self) for _ in range(n_iter): self.update_once() if self.record_loss: loss = self.compute_loss() self.loss.append(loss) if self.callbacks is not None: for callback in self.callbacks: callback(self) def update_once(self) -> None: r"""Update parameters once.""" raise NotImplementedError("Implement 'update_once' method.") def compute_loss(self) -> float: r"""Compute loss. Returns: Computed loss. The type is expected ``float``. 
""" raise NotImplementedError("Implement 'compute_loss' method.") ================================================ FILE: ssspy/bss/cacgmm.py ================================================ import functools from typing import Callable, List, Optional, Union import numpy as np from ..algorithm.permutation_alignment import ( correlation_based_permutation_solver, score_based_permutation_solver, ) from ..linalg.quadratic import quadratic from ..special.flooring import identity, max_flooring from ..special.logsumexp import logsumexp from ..special.psd import to_psd from ..special.softmax import softmax from ..utils.flooring import choose_flooring_fn from .base import IterativeMethodBase EPS = 1e-10 class CACGMMBase(IterativeMethodBase): r"""Base class of complex angular central Gaussian mixture model (cACGMM). Args: n_sources (int, optional): Number of sources to be separated. If ``None`` is given, ``n_sources`` is determined by number of channels in input spectrogram. Default: ``None``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize parameters of cACGMM. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. 
""" def __init__( self, n_sources: Optional[int] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[ Callable[["CACGMMBase"], None], List[Callable[["CACGMMBase"], None]], ] ] = None, record_loss: bool = True, rng: Optional[np.random.Generator] = None, ) -> None: self.normalization: bool self.permutation_alignment: bool super().__init__(callbacks=callbacks, record_loss=record_loss) self.n_sources = n_sources if flooring_fn is None: self.flooring_fn = identity else: self.flooring_fn = flooring_fn if rng is None: rng = np.random.default_rng() self.rng = rng def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) raise NotImplementedError("Implement '__call__' method.") def __repr__(self) -> str: s = "CACGMM(" if self.n_sources is not None: s += "n_sources={n_sources}, " s += "record_loss={record_loss}" s += ")" return s.format(**self.__dict__) def _reset( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", **kwargs, ) -> None: r"""Reset attributes by given keyword arguments. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. 
If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. kwargs: Keyword arguments to set as attributes of CACGMM. """ assert self.input is not None, "Specify data!" flooring_fn = choose_flooring_fn(flooring_fn, method=self) for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input norm = np.linalg.norm(X, axis=0) Z = X / flooring_fn(norm) self.unit_input = Z n_sources = self.n_sources n_channels, n_bins, n_frames = X.shape if n_sources is None: n_sources = n_channels self.n_sources, self.n_channels = n_sources, n_channels self.n_bins, self.n_frames = n_bins, n_frames self._init_parameters(rng=self.rng) def _init_parameters(self, rng: Optional[np.random.Generator] = None) -> None: r"""Initialize parameters of cACGMM. Args: rng (numpy.random.Generator, optional): Random number generator. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. .. note:: Custom initialization is not supported now. """ n_sources, n_channels = self.n_sources, self.n_channels n_bins = self.n_bins if rng is None: rng = np.random.default_rng() alpha = rng.random((n_sources, n_bins)) alpha = alpha / alpha.sum(axis=0) eye = np.eye(n_channels, dtype=np.complex128) B_diag = self.rng.random((n_sources, n_bins, n_channels)) B_diag = B_diag / B_diag.sum(axis=-1, keepdims=True) B = B_diag[:, :, :, np.newaxis] * eye self.mixing = alpha self.covariance = B # The shape of posterior is (n_sources, n_bins, n_frames). # This is always required to satisfy posterior.sum(axis=0) = 1 self.posterior = None def separate(self, input: np.ndarray) -> np.ndarray: r"""Separate ``input``. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_sources, n_bins, n_frames). """ raise NotImplementedError("Implement 'separate' method.") def normalize_covariance(self) -> None: r"""Normalize covariance of cACG. .. 
math:: \boldsymbol{B}_{in} \leftarrow\frac{\boldsymbol{B}_{in}}{\mathrm{tr}(\boldsymbol{B}_{in})} """ assert self.normalization, "Set normalization." B = self.covariance trace = np.trace(B, axis1=-2, axis2=-1) trace = np.real(trace) B = B / trace[..., np.newaxis, np.newaxis] self.covariance = B def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. Returns: Computed loss. """ raise NotImplementedError("Implement 'compute_loss' method.") def compute_logdet(self, covariance: np.ndarray) -> np.ndarray: r"""Compute log-determinant of input. Args: covariance (numpy.ndarray): Covariance matrix with shape of (n_sources, n_bins, n_channels, n_channels). Returns: numpy.ndarray of log-determinant. """ _, logdet = np.linalg.slogdet(covariance) return logdet def solve_permutation( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Align posteriors and separated spectrograms. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ permutation_alignment = self.permutation_alignment flooring_fn = choose_flooring_fn(flooring_fn, method=self) assert permutation_alignment, "Set permutation_alignment=True." 
if type(permutation_alignment) is bool: # when permutation_alignment is True permutation_alignment = "posterior_score" if permutation_alignment in ["posterior_score", "posterior_correlation"]: target = "posterior" elif permutation_alignment in ["amplitude_score", "amplitude_correlation"]: target = "amplitude" else: raise NotImplementedError( "permutation_alignment {} is not implemented.".format(permutation_alignment) ) if permutation_alignment in ["posterior_score", "amplitude_score"]: self.solve_permutation_by_score(target=target, flooring_fn=flooring_fn) elif permutation_alignment in ["posterior_correlation", "amplitude_correlation"]: self.solve_permutation_by_correlation(target=target, flooring_fn=flooring_fn) else: raise NotImplementedError( "permutation_alignment {} is not implemented.".format(permutation_alignment) ) def solve_permutation_by_score( self, target: str = "posterior", flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Align posteriors and amplitudes of separated spectrograms by score value. Args: target (str): Target to compute score values. Choose ``posterior`` or ``amplitude``. Default: ``posterior``. flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" assert target in ["posterior", "amplitude"], "Invalid target {} is specified.".format( target ) flooring_fn = choose_flooring_fn(flooring_fn, method=self) X = self.input alpha = self.mixing B = self.covariance gamma = self.posterior if hasattr(self, "global_iter"): global_iter = self.global_iter else: global_iter = 1 if hasattr(self, "local_iter"): local_iter = self.local_iter else: local_iter = 1 Y = self.separate(X, posterior=gamma) alpha = alpha.transpose(1, 0) B = B.transpose(1, 0, 2, 3) gamma = gamma.transpose(1, 0, 2) if target == "posterior": gamma, (alpha, B) = score_based_permutation_solver( gamma, alpha, B, global_iter=global_iter, local_iter=local_iter, flooring_fn=flooring_fn, ) elif target == "amplitude": Y = Y.transpose(1, 0, 2) amplitude = np.abs(Y) _, (alpha, B, gamma) = score_based_permutation_solver( amplitude, alpha, B, gamma, global_iter=global_iter, local_iter=local_iter, flooring_fn=flooring_fn, ) else: raise ValueError("Invalid target {} is specified.".format(target)) alpha = alpha.transpose(1, 0) B = B.transpose(1, 0, 2, 3) gamma = gamma.transpose(1, 0, 2) Y = self.separate(X, posterior=gamma) self.mixing = alpha self.covariance = B self.posterior = gamma self.output = Y def solve_permutation_by_correlation( self, target: str = "amplitude", flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Align posteriors and amplitudes of separated spectrograms by correlation. Args: target (str): Target to compute correlations. Choose ``posterior`` or ``amplitude``. Default: ``amplitude``. flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ assert target == "amplitude", "Only amplitude is supported as target." 
flooring_fn = choose_flooring_fn(flooring_fn, method=self) X = self.input alpha = self.mixing B = self.covariance gamma = self.posterior Y = self.separate(X, posterior=self.posterior) alpha = alpha.transpose(1, 0) B = B.transpose(1, 0, 2, 3) gamma = gamma.transpose(1, 0, 2) Y = Y.transpose(1, 0, 2) Y, (alpha, B, gamma) = correlation_based_permutation_solver( Y, alpha, B, gamma, flooring_fn=flooring_fn ) alpha = alpha.transpose(1, 0) B = B.transpose(1, 0, 2, 3) gamma = gamma.transpose(1, 0, 2) Y = Y.transpose(1, 0, 2) self.mixing = alpha self.covariance = B self.posterior = gamma self.output = Y class CACGMM(CACGMMBase): r"""Complex angular central Gaussian mixture model (cACGMM) [#ito2016complex]_. Args: n_sources (int, optional): Number of sources to be separated. If ``None`` is given, ``n_sources`` is determined by number of channels in input spectrogram. Default: ``None``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. normalization (bool): If ``True`` is given, normalization is applied to covariance in cACG. permutation_alignment (bool): If ``permutation_alignment=True``, a permutation solver is used to align estimated spectrograms. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel to extract separated signals. Default: ``0``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize parameters of cACGMM. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. .. [#ito2016complex] N. Ito, S. Araki, and T. Nakatani. 
\ "Complex angular central Gaussian mixture model for directional statistics \ in mask-based microphone array signal processing," in *Proc. EUSIPCO*, 2016, pp. 1153-1157. """ def __init__( self, n_sources: Optional[int] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[ Callable[["CACGMM"], None], List[Callable[["CACGMM"], None]], ] ] = None, normalization: bool = True, permutation_alignment: bool = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, **kwargs, ) -> None: super().__init__( n_sources=n_sources, flooring_fn=flooring_fn, callbacks=callbacks, record_loss=record_loss, rng=rng, ) self.normalization = normalization self.permutation_alignment = permutation_alignment self.reference_id = reference_id if type(permutation_alignment) is bool and permutation_alignment: valid_keys = {"global_iter", "local_iter"} elif type(permutation_alignment) is str and permutation_alignment in [ "posterior_score", "amplitude_score", ]: valid_keys = {"global_iter", "local_iter"} else: valid_keys = set() invalid_keys = set(kwargs) - valid_keys assert invalid_keys == set(), "Invalid keywords {} are given.".format(invalid_keys) for key, value in kwargs.items(): setattr(self, key, value) def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). 
""" self.input = input.copy() self._reset(flooring_fn=self.flooring_fn, **kwargs) # Call __call__ of CACGMMBase's parent, i.e. __call__ of IterativeMethodBase super(CACGMMBase, self).__call__(n_iter=n_iter, initial_call=initial_call) # posterior should be updated self.update_posterior(flooring_fn=self.flooring_fn) if self.permutation_alignment: self.solve_permutation(flooring_fn=self.flooring_fn) X = self.input self.output = self.separate(X, posterior=self.posterior) return self.output def __repr__(self) -> str: s = "CACGMM(" if self.n_sources is not None: s += "n_sources={n_sources}, " s += "record_loss={record_loss}" s += ", normalization={normalization}" s += ", permutation_alignment={permutation_alignment}" s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def separate(self, input: np.ndarray, posterior: Optional[np.ndarray] = None) -> np.ndarray: r"""Separate ``input`` using posterior probabilities. In this method, ``self.posterior`` is not updated. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). posterior (numpy.ndarray, optional): Posterior probability. If not specified, ``posterior`` is computed by current parameters. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_sources, n_bins, n_frames). 
""" X = input if posterior is None: alpha = self.mixing Z = self.unit_input B = self.covariance Z = Z.transpose(1, 2, 0) B_inverse = np.linalg.inv(B) ZBZ = quadratic(Z, B_inverse[:, :, np.newaxis]) ZBZ = np.real(ZBZ) ZBZ = np.maximum(ZBZ, 0) ZBZ = self.flooring_fn(ZBZ) log_alpha = np.log(alpha) _, logdet = np.linalg.slogdet(B) log_prob = log_alpha - logdet log_gamma = log_prob[:, :, np.newaxis] - self.n_channels * np.log(ZBZ) gamma = softmax(log_gamma, axis=0) else: gamma = posterior return gamma * X[self.reference_id] def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self" ) -> None: r"""Perform E and M step once. In ``update_posterior``, posterior probabilities are updated, which corresponds to E step. In ``update_parameters``, parameters of cACGMM are updated, which corresponds to M step. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) self.update_posterior(flooring_fn=flooring_fn) self.update_parameters(flooring_fn=flooring_fn) if self.normalization: self.normalize_covariance() def update_posterior( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self" ) -> None: r"""Update posteriors. This method corresponds to E step in EM algorithm for cACGMM. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) alpha = self.mixing Z = self.unit_input B = self.covariance Z = Z.transpose(1, 2, 0) B_inverse = np.linalg.inv(B) ZBZ = quadratic(Z, B_inverse[:, :, np.newaxis]) ZBZ = np.real(ZBZ) ZBZ = np.maximum(ZBZ, 0) ZBZ = flooring_fn(ZBZ) log_prob = np.log(alpha) - self.compute_logdet(B) log_gamma = log_prob[:, :, np.newaxis] - self.n_channels * np.log(ZBZ) gamma = softmax(log_gamma, axis=0) self.posterior = gamma def update_parameters( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self" ) -> None: r"""Update parameters of mixture of complex angular central Gaussian distributions. This method corresponds to M step in EM algorithm for cACGMM. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) Z = self.unit_input B = self.covariance gamma = self.posterior Z = Z.transpose(1, 2, 0) B_inverse = np.linalg.inv(B) ZBZ = quadratic(Z, B_inverse[:, :, np.newaxis]) ZBZ = np.real(ZBZ) ZBZ = np.maximum(ZBZ, 0) ZBZ = flooring_fn(ZBZ) ZZ = Z[:, :, :, np.newaxis] * Z[:, :, np.newaxis, :].conj() alpha = np.mean(gamma, axis=-1) GZBZ = gamma / ZBZ num = np.sum(GZBZ[:, :, :, np.newaxis, np.newaxis] * ZZ, axis=2) denom = np.sum(gamma, axis=2) B = self.n_channels * (num / denom[:, :, np.newaxis, np.newaxis]) B = to_psd(B, flooring_fn=flooring_fn) self.mixing = alpha self.covariance = B def compute_loss(self) -> float: r"""Compute loss of cACGMM :math:`\mathcal{L}`. :math:`\mathcal{L}` is defined as follows: .. 
math:: \mathcal{L} = -\frac{1}{J}\sum_{i,j}\log\left( \sum_{n}\frac{\alpha_{in}}{\det\boldsymbol{B}_{in}} \frac{1}{(\boldsymbol{z}_{ij}^{\mathsf{H}}\boldsymbol{B}_{in}^{-1}\boldsymbol{z}_{ij})^{M}} \right). """ alpha = self.mixing Z = self.unit_input B = self.covariance Z = Z.transpose(1, 2, 0) B_inverse = np.linalg.inv(B) ZBZ = quadratic(Z, B_inverse[:, :, np.newaxis]) ZBZ = np.real(ZBZ) ZBZ = np.maximum(ZBZ, 0) ZBZ = self.flooring_fn(ZBZ) log_prob = np.log(alpha) - self.compute_logdet(B) log_gamma = log_prob[:, :, np.newaxis] - self.n_channels * np.log(ZBZ) loss = -logsumexp(log_gamma, axis=0) loss = np.mean(loss, axis=-1) loss = loss.sum(axis=0) loss = loss.item() return loss ================================================ FILE: ssspy/bss/fdica.py ================================================ import functools from typing import Callable, Iterable, List, Optional, Tuple, Union import numpy as np from ..algorithm import ( MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS, PROJECTION_BACK_KEYWORDS, minimal_distortion_principle, projection_back, ) from ..algorithm.permutation_alignment import correlation_based_permutation_solver from ..special.flooring import identity, max_flooring from ..utils.flooring import choose_flooring_fn from ..utils.select_pair import sequential_pair_selector from ._update_spatial_model import update_by_ip1, update_by_ip2_one_pair from .base import IterativeMethodBase __all__ = [ "GradFDICA", "NaturalGradFDICA", "AuxFDICA", "GradLaplaceFDICA", "NaturalGradLaplaceFDICA", "AuxLaplaceFDICA", ] spatial_algorithms = ["IP", "IP1", "IP2"] EPS = 1e-10 class FDICABase(IterativeMethodBase): r"""Base class of frequency-domain independent component analysis (FDICA). Args: contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(y_{ijn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. 
This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. permutation_alignment (bool): If ``permutation_alignment=True``, a permutation solver is used to align estimated spectrograms. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. 
""" def __init__( self, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["FDICABase"], None], List[Callable[["FDICABase"], None]]] ] = None, permutation_alignment: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__(callbacks=callbacks, record_loss=record_loss) if contrast_fn is None: raise ValueError("Specify contrast function.") else: self.contrast_fn = contrast_fn if flooring_fn is None: self.flooring_fn = identity else: self.flooring_fn = flooring_fn self.input = None self.permutation_alignment = permutation_alignment self.scale_restoration = scale_restoration if reference_id is None and scale_restoration: raise ValueError("Specify 'reference_id' if scale_restoration=True.") else: self.reference_id = reference_id def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): Mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): Number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). 
""" self.input = input.copy() self._reset(**kwargs) super().__call__(n_iter=n_iter, initial_call=initial_call) raise NotImplementedError("Implement '__call__' method.") def __repr__(self) -> str: s = "FDICA(" s += ", permutation_alignment={permutation_alignment}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of FDICA. """ assert self.input is not None, "Specify data!" for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input n_channels, n_bins, n_frames = X.shape n_sources = n_channels # n_channels == n_sources self.n_sources, self.n_channels = n_sources, n_channels self.n_bins, self.n_frames = n_bins, n_frames if not hasattr(self, "demix_filter"): W = np.eye(n_sources, n_channels, dtype=np.complex128) W = np.tile(W, reps=(n_bins, 1, 1)) else: if self.demix_filter is None: W = None else: # To avoid overwriting ``demix_filter`` given by keyword arguments. W = self.demix_filter.copy() self.demix_filter = W self.output = self.separate(X, demix_filter=W) def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray: r"""Separate ``input`` using ``demixing_filter``. .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij} Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). demix_filter (numpy.ndarray): The demixing filters to separate ``input``. The shape is (n_bins, n_sources, n_channels). Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_sources, n_bins, n_frames). """ X, W = input, demix_filter Y = W @ X.transpose(1, 0, 2) output = Y.transpose(1, 0, 2) return output def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. 
:math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} &= \sum_{i}\mathcal{L}^{[i]}, \\ \mathcal{L}^{[i]} &= \frac{1}{J}\sum_{j,n}G(y_{ijn}) - 2\log|\det\boldsymbol{W}_{i}|, \\ G(y_{ijn}) \ &= - \log p(y_{ijn}) Returns: Computed loss. """ X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) # (n_sources, n_bins, n_frames) logdet = self.compute_logdet(W) # (n_bins,) G = self.contrast_fn(Y) # (n_sources, n_bins, n_frames) loss = np.sum(np.mean(G, axis=2), axis=0) - 2 * logdet loss = loss.sum(axis=0).item() return loss def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray: r"""Compute log-determinant of demixing filter. Args: demix_filter (numpy.ndarray): Demixing filters with shape of (n_bins, n_sources, n_channels). Returns: numpy.ndarray of computed log-determinant values. """ _, logdet = np.linalg.slogdet(demix_filter) # (n_bins,) return logdet def solve_permutation(self) -> None: r"""Align demixing filters and separated spectrograms""" permutation_alignment = self.permutation_alignment assert permutation_alignment, "Set permutation_alignment=True." if type(permutation_alignment) is bool: # when permutation_alignment is True permutation_alignment = "spectrogram_correlation" if permutation_alignment == "spectrogram_correlation": self.solve_permutation_by_correlation() else: raise NotImplementedError( "permutation_alignment {} is not implemented.".format(permutation_alignment) ) def solve_permutation_by_correlation( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Align posteriors and separated spectrograms by correlation. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y = Y.transpose(1, 0, 2) Y, W = correlation_based_permutation_solver(Y, W, flooring_fn=flooring_fn) Y = Y.transpose(1, 0, 2) self.output, self.demix_filter = Y, W def restore_scale(self) -> None: r"""Restore scale ambiguity. If ``self.scale_restoration=projection_back``, we use projection back technique. If ``self.scale_restoration=minimal_distortion_principle``, we use minimal distortion principle. """ scale_restoration = self.scale_restoration assert scale_restoration, "Set self.scale_restoration=True." if type(scale_restoration) is bool: scale_restoration = PROJECTION_BACK_KEYWORDS[0] if scale_restoration in PROJECTION_BACK_KEYWORDS: self.apply_projection_back() elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS: self.apply_minimal_distortion_principle() else: raise ValueError("{} is not supported for scale restoration.".format(scale_restoration)) def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." X, W = self.input, self.demix_filter W_scaled = projection_back(W, reference_id=self.reference_id) Y_scaled = self.separate(X, demix_filter=W_scaled) self.output, self.demix_filter = Y_scaled, W_scaled def apply_minimal_distortion_principle(self) -> None: r"""Apply minimal distortion principle to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." 
X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id) X = X.transpose(1, 0, 2) Y = Y_scaled.transpose(1, 0, 2) X_Hermite = X.transpose(0, 2, 1).conj() W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite) self.output, self.demix_filter = Y_scaled, W_scaled class GradFDICABase(FDICABase): r"""Base class of frequency-domain independent component analysis (FDICA) \ using the gradient descent. Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(y_{ijn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). score_fn (callable): A score function which corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. permutation_alignment (bool): If ``permutation_alignment=True``, a permutation solver is used to align estimated spectrograms. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. 
record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. """ def __init__( self, step_size: float = 1e-1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["GradFDICABase"], None], List[Callable[["GradFDICABase"], None]]] ] = None, permutation_alignment: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( contrast_fn=contrast_fn, flooring_fn=flooring_fn, callbacks=callbacks, permutation_alignment=permutation_alignment, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.step_size = step_size if score_fn is None: raise ValueError("Specify score function.") else: self.score_fn = score_fn def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) # Call __call__ of FDICABase's parent, i.e. 
__call__ of IterativeMethodBase super(FDICABase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.permutation_alignment: self.solve_permutation() if self.scale_restoration: self.restore_scale() self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "GradFDICA(" s += "step_size={step_size}" s += ", permutation_alignment={permutation_alignment}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def update_once(self) -> None: r"""Update demixing filters once.""" raise NotImplementedError("Implement 'update_once' method.") class GradFDICA(GradFDICABase): r"""Frequency-domain independent component analysis (FDICA) \ using the gradient descent. Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. contrast_fn (callable): A contrast function corresponds to :math:`-\log p(y_{ijn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). score_fn (callable): A score function corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. 
permutation_alignment (bool): If ``permutation_alignment=True``, a permutation solver is used to align estimated spectrograms. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.abs(y) >>> def score_fn(y): ... denom = np.maximum(np.abs(y), 1e-10) ... return y / denom >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = GradFDICA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=True, ... ) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.abs(y) >>> def score_fn(y): ... denom = np.maximum(np.abs(y), 1e-10) ... return y / denom >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = GradFDICA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=False, ... 
) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, step_size: float = 1e-1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["GradFDICA"], None], List[Callable[["GradFDICA"], None]]] ] = None, is_holonomic: bool = False, permutation_alignment: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, permutation_alignment=permutation_alignment, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.is_holonomic = is_holonomic def __repr__(self) -> str: s = "GradFDICA(" s += "step_size={step_size}" s += ", is_holonomic={is_holonomic}" s += ", permutation_alignment={permutation_alignment}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def update_once(self) -> None: r"""Update demixing filters once using the gradient descent. If ``is_holonomic=True``, demixing filters are updated as follows: .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\left(\frac{1}{J}\sum_{j} \boldsymbol{\phi}(\boldsymbol{y}_{ij})\boldsymbol{y}_{ij}^{\mathsf{H}} -\boldsymbol{I}\right)\boldsymbol{W}_{i}^{-\mathsf{H}}, where .. math:: \boldsymbol{\phi}(\boldsymbol{y}_{ij}) &= \left(\phi(y_{ij1}),\ldots,\phi(y_{ijn}),\ldots,\phi(y_{ijN}) \right)^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \phi(y_{ijn}) &= \frac{\partial G(y_{ijn})}{\partial y_{ijn}^{*}}, \\ G(y_{ijn}) &= -\log p(y_{ijn}). 
Otherwise (``is_holonomic=False``), .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{J}\sum_{j} \boldsymbol{\phi}(\boldsymbol{y}_{ij})\boldsymbol{y}_{ij}^{\mathsf{H}}\right) \boldsymbol{W}_{i}^{-\mathsf{H}}. """ X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Phi = self.score_fn(Y) Y_conj = Y.conj() PhiY = np.mean(Phi[:, np.newaxis, :, :] * Y_conj[np.newaxis, :, :, :], axis=-1) PhiY = PhiY.transpose(2, 0, 1) # (n_bins, n_sources, n_sources) W_inv = np.linalg.inv(W) W_inv_Hermite = W_inv.transpose(0, 2, 1).conj() eye = np.eye(self.n_sources) if self.is_holonomic: delta = (PhiY - eye) @ W_inv_Hermite else: delta = ((1 - eye) * PhiY) @ W_inv_Hermite W = W - self.step_size * delta Y = self.separate(X, demix_filter=W) self.demix_filter = W self.output = Y class NaturalGradFDICA(GradFDICABase): r"""Frequency-domain independent component analysis (FDICA) \ using the natural gradient descent. Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. contrast_fn (callable): A contrast function corresponds to :math:`-\log p(y_{ijn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). score_fn (callable): A score function corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. 
is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. permutation_alignment (bool): If ``permutation_alignment=True``, a permutation solver is used to align estimated spectrograms. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.abs(y) >>> def score_fn(y): ... denom = np.maximum(np.abs(y), 1e-10) ... return y / denom >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = NaturalGradFDICA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=True, ... ) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.abs(y) >>> def score_fn(y): ... denom = np.maximum(np.abs(y), 1e-10) ... return y / denom >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = NaturalGradFDICA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=False, ... 
) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, step_size: float = 1e-1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["NaturalGradFDICA"], None], List[Callable[["NaturalGradFDICA"], None]]] ] = None, is_holonomic: bool = False, permutation_alignment: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, permutation_alignment=permutation_alignment, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.is_holonomic = is_holonomic def __repr__(self) -> str: s = "NaturalGradFDICA(" s += "step_size={step_size}" s += ", is_holonomic={is_holonomic}" s += ", permutation_alignment={permutation_alignment}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def update_once(self) -> None: r"""Update demixing filters once using the gradient descent. If ``is_holonomic=True``, demixing filters are updated as follows: .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\left(\frac{1}{J}\sum_{j} \boldsymbol{\phi}(\boldsymbol{y}_{ij})\boldsymbol{y}_{ij}^{\mathsf{H}} -\boldsymbol{I}\right)\boldsymbol{W}_{i}, where .. math:: \boldsymbol{\phi}(\boldsymbol{y}_{ij}) &= \left(\phi(y_{ij1}),\ldots,\phi(y_{ijn}),\ldots,\phi(y_{ijN}) \right)^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \phi(y_{ijn}) &= \frac{\partial G(y_{ijn})}{\partial y_{ijn}^{*}}, \\ G(y_{ijn}) &= -\log p(y_{ijn}). 
Otherwise (``is_holonomic=False``), .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{J}\sum_{j} \boldsymbol{\phi}(\boldsymbol{y}_{ij})\boldsymbol{y}_{ij}^{\mathsf{H}}\right) \boldsymbol{W}_{i}. """ X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Phi = self.score_fn(Y) Y_conj = Y.conj() PhiY = np.mean(Phi[:, np.newaxis, :, :] * Y_conj[np.newaxis, :, :, :], axis=-1) PhiY = PhiY.transpose(2, 0, 1) # (n_bins, n_sources, n_sources) eye = np.eye(self.n_sources) if self.is_holonomic: delta = (PhiY - eye) @ W else: delta = ((1 - eye) * PhiY) @ W W = W - self.step_size * delta Y = self.separate(X, demix_filter=W) self.demix_filter = W self.output = Y class AuxFDICA(FDICABase): r"""Auxiliary-function-based frequency-domain independent component analysis \ (AuxFDICA) [#ono2010auxiliary]_. Args: spatial_algorithm (str): Algorithm to update demixing filters. Choose ``IP``, ``IP1``, or ``IP2``. Default: ``IP``. contrast_fn (callable): A contrast function corresponds to :math:`-\log p(y_{ijn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). d_contrast_fn (callable): A partial derivative of the real contrast function. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. pair_selector (callable, optional): Selector to choose updaing pair in ``IP2`` and ``ISS2``. If ``None`` is given, ``partial(sequential_pair_selector, sort=True)`` is used. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. 
Each function is called before separation and at each iteration. Default: ``None``. permutation_alignment (bool): If ``permutation_alignment=True``, a permutation solver is used to align estimated spectrograms. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the demixing filter update if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters by IP: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.abs(y) >>> def d_contrast_fn(y): ... return 2 * np.ones_like(y) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = AuxFDICA( ... spatial_algorithm="IP", ... contrast_fn=contrast_fn, ... d_contrast_fn=d_contrast_fn, ... ) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by IP2: .. code-block:: python >>> from ssspy.utils.select_pair import sequential_pair_selector >>> def contrast_fn(y): ... return 2 * np.abs(y) >>> def d_contrast_fn(y): ... return 2 * np.ones_like(y) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = AuxFDICA( ... spatial_algorithm="IP2", ... contrast_fn=contrast_fn, ... d_contrast_fn=d_contrast_fn, ... pair_selector=sequential_pair_selector, ... 
) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) .. [#ono2010auxiliary] N. Ono and S. Miyabe, "Auxiliary-function-based independent component analysis for super-Gaussian sources," in *Proc. LVA/ICA*, 2010, pp.165-172. """ def __init__( self, spatial_algorithm: str = "IP", contrast_fn: Callable[[np.ndarray], np.ndarray] = None, d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, callbacks: Optional[ Union[Callable[["AuxFDICA"], None], List[Callable[["AuxFDICA"], None]]] ] = None, permutation_alignment: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( contrast_fn=contrast_fn, flooring_fn=flooring_fn, callbacks=callbacks, permutation_alignment=permutation_alignment, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) assert spatial_algorithm in spatial_algorithms, "Not support {}.".format(spatial_algorithms) self.spatial_algorithm = spatial_algorithm self.d_contrast_fn = d_contrast_fn if pair_selector is None: if spatial_algorithm == "IP2": self.pair_selector = sequential_pair_selector else: self.pair_selector = pair_selector def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. 
The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) # Call __call__ of FDICABase's parent, i.e. __call__ of IterativeMethodBase super(FDICABase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.permutation_alignment: self.solve_permutation() if self.scale_restoration: self.restore_scale() if self.demix_filter is not None: self.output = self.separate(self.input, demix_filter=self.demix_filter) else: # TODO: implement demixing-filter-free algorithms (e.g. ISS, IPA, etc.) pass return self.output def __repr__(self) -> str: s = "AuxFDICA(" s += "spatial_algorithm={spatial_algorithm}" s += ", permutation_alignment={permutation_alignment}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once. - If ``self.spatial_algorithm`` is ``IP`` or ``IP1``, ``update_once_ip1`` is called. - If ``self.spatial_algorithm`` is ``IP2``, ``update_once_ip2`` is called. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.spatial_algorithm in ["IP", "IP1"]: self.update_once_ip1(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["IP2"]: self.update_once_ip2(flooring_fn=flooring_fn) else: raise NotImplementedError("Not support {}.".format(self.spatial_algorithm)) def update_once_ip1( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Demixing filters are updated sequentially for :math:`n=1,\ldots,N` as follows: .. math:: \boldsymbol{w}_{in} &\leftarrow\left(\boldsymbol{W}_{in}^{\mathsf{H}}\boldsymbol{U}_{in}\right)^{-1} \boldsymbol{e}_{n}, \\ \boldsymbol{w}_{in} &\leftarrow\frac{\boldsymbol{w}_{in}} {\sqrt{\boldsymbol{w}_{in}^{\mathsf{H}}\boldsymbol{U}_{in}\boldsymbol{w}_{in}}}, \\ where .. math:: \boldsymbol{U}_{in} &= \frac{1}{J}\sum_{j} \frac{G'_{\mathbb{R}}(|y_{ijn}|)}{2|y_{ijn}|} \boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, \\ G(y_{ijn}) &= -\log p(y_{ijn}), \\ G_{\mathbb{R}}(|y_{ijn}|) &= G(y_{ijn}). 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) # (n_bins, n_channels, n_channels, n_frames) Y_abs = np.abs(Y) denom = flooring_fn(2 * Y_abs) varphi = self.d_contrast_fn(Y_abs) / denom # (n_sources, n_bins, n_frames) varphi = varphi.transpose(1, 0, 2) # (n_bins, n_sources, n_frames) GXX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(GXX, axis=-1) # (n_bins, n_sources, n_channels, n_channels) self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn) def update_once_ip2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using pairwise iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\neq n_{2}`), compute auxiliary variables: .. math:: \bar{r}_{ijn_{1}} &\leftarrow|y_{ijn_{1}}| \\ \bar{r}_{ijn_{2}} &\leftarrow|y_{ijn_{2}}| Then, for :math:`n=n_{1},n_{2}`, compute weighted covariance matrix as follows: .. math:: \boldsymbol{U}_{in_{1}} &= \frac{1}{J}\sum_{j} \frac{G'_{\mathbb{R}}(\bar{r}_{ijn_{1}})}{2\bar{r}_{ijn_{1}}} \boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, \\ \boldsymbol{U}_{in_{2}} &= \frac{1}{J}\sum_{j} \frac{G'_{\mathbb{R}}(\bar{r}_{ijn_{2}})}{2\bar{r}_{ijn_{2}}} \boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, where .. math:: G(y_{ijn}) &= -\log p(y_{ijn}), \\ G_{\mathbb{R}}(|y_{ijn}|) &= G(y_{ijn}). 
Using :math:`\boldsymbol{U}_{in_{1}}` and :math:`\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors. .. math:: \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i} = \lambda_{i} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i}, where .. math:: \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{1}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ), \\ \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{2}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ). After that, we standardize two eigenvectors :math:`\boldsymbol{h}_{in_{1}}` and :math:`\boldsymbol{h}_{in_{2}}`. .. math:: \boldsymbol{h}_{in_{1}} &\leftarrow\frac{\boldsymbol{h}_{in_{1}}} {\sqrt{\boldsymbol{h}_{in_{1}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{1}}}}, \\ \boldsymbol{h}_{in_{2}} &\leftarrow\frac{\boldsymbol{h}_{in_{2}}} {\sqrt{\boldsymbol{h}_{in_{2}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{2}}}}. Then, update :math:`\boldsymbol{w}_{in_{1}}` and :math:`\boldsymbol{w}_{in_{2}}` simultaneously. .. math:: \boldsymbol{w}_{in_{1}} &\leftarrow \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{1}} \\ \boldsymbol{w}_{in_{2}} &\leftarrow \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{2}} At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{1}` for :math:`n_{1}\neq n_{2}`. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) n_sources = self.n_sources X, W = self.input, self.demix_filter XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) for m, n in self.pair_selector(n_sources): W_mn = W[:, (m, n), :] Y_mn = self.separate(X, demix_filter=W_mn) Y_abs_mn = np.abs(Y_mn) denom = flooring_fn(2 * Y_abs_mn) varphi_mn = self.d_contrast_fn(Y_abs_mn) / denom varphi_mn = varphi_mn.transpose(1, 0, 2) GXX_mn = varphi_mn[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U_mn = np.mean(GXX_mn, axis=-1) W[:, (m, n), :] = update_by_ip2_one_pair( W, U_mn, pair=(m, n), flooring_fn=flooring_fn, ) self.demix_filter = W class GradLaplaceFDICA(GradFDICA): r"""Frequency-domain independent component analysis (FDICA) \ using the gradient descent on a Laplace distribution. We assume :math:`y_{ijn}` follows a Laplace distribution. .. math:: p(y_{ijn})\propto\exp(|y_{ijn}|) Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. permutation_alignment (bool): If ``permutation_alignment=True``, a permutation solver is used to align estimated spectrograms. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. 
If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = GradLaplaceFDICA(is_holonomic=True) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = GradLaplaceFDICA(is_holonomic=False) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, step_size: float = 1e-1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["GradLaplaceFDICA"], None], List[Callable[["GradLaplaceFDICA"], None]]] ] = None, is_holonomic: bool = False, permutation_alignment: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: The shape is (n_sources, n_bins, n_frames). 
""" return 2 * np.abs(y) def score_fn(y: np.ndarray) -> np.ndarray: r"""Score function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: The shape is (n_sources, n_bins, n_frames). """ denom = self.flooring_fn(np.abs(y)) return y / denom super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, is_holonomic=is_holonomic, permutation_alignment=permutation_alignment, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def __repr__(self) -> str: s = "GradLaplaceFDICA(" s += "step_size={step_size}" s += ", is_holonomic={is_holonomic}" s += ", permutation_alignment={permutation_alignment}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) class NaturalGradLaplaceFDICA(NaturalGradFDICA): r"""Frequency-domain independent component analysis (FDICA) \ using the natural gradient descent on a Laplace distribution. We assume :math:`y_{ijn}` follows a Laplace distribution. .. math:: p(y_{ijn})\propto\exp(|y_{ijn}|) Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. 
permutation_alignment (bool): If ``permutation_alignment=True``, a permutation solver is used to align estimated spectrograms. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = NaturalGradLaplaceFDICA(is_holonomic=True) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... 
+ 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = NaturalGradLaplaceFDICA(is_holonomic=False) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, step_size: float = 1e-1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[ Callable[["NaturalGradLaplaceFDICA"], None], List[Callable[["NaturalGradLaplaceFDICA"], None]], ] ] = None, is_holonomic: bool = False, permutation_alignment: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: The shape is (n_sources, n_bins, n_frames). """ return 2 * np.abs(y) def score_fn(y: np.ndarray) -> np.ndarray: r"""Score function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: The shape is (n_sources, n_bins, n_frames). """ denom = self.flooring_fn(np.abs(y)) return y / denom super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, is_holonomic=is_holonomic, permutation_alignment=permutation_alignment, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def __repr__(self) -> str: s = "NaturalGradLaplaceFDICA(" s += "step_size={step_size}" s += ", is_holonomic={is_holonomic}" s += ", permutation_alignment={permutation_alignment}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) class AuxLaplaceFDICA(AuxFDICA): r"""Auxiliary-function-based frequency-domain independent component analysis \ on a Laplace distribution. 
We assume :math:`y_{ijn}` follows a Laplace distribution. .. math:: p(y_{ijn})\propto\exp(|y_{ijn}|) Args: spatial_algorithm (str): Algorithm to update demixing filters. Choose ``IP``, ``IP1``, or ``IP2``. Default: ``IP``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. pair_selector (callable, optional): Selector to choose updaing pair in ``IP2`` and ``ISS2``. If ``None`` is given, ``partial(sequential_pair_selector, sort=True)`` is used. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. permutation_alignment (bool): If ``permutation_alignment=True``, a permutation solver is used to align estimated spectrograms. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the demixing filter update if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters by IP: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... 
+ 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = AuxLaplaceFDICA(spatial_algorithm="IP") >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by IP2: .. code-block:: python >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = \ ... np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> fdica = AuxLaplaceFDICA( ... spatial_algorithm="IP2", ... pair_selector=sequential_pair_selector, ... ) >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, spatial_algorithm: str = "IP", flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, callbacks: Optional[ Union[Callable[["AuxLaplaceFDICA"], None], List[Callable[["AuxLaplaceFDICA"], None]]] ] = None, permutation_alignment: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: def contrast_fn(y: np.ndarray): r"""Contrast function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: The shape is (n_sources, n_bins, n_frames). """ return 2 * np.abs(y) def d_contrast_fn(y: np.ndarray): r"""Partial derivative of score function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: The shape is (n_sources, n_bins, n_frames). 
""" return 2 * np.ones_like(y) super().__init__( spatial_algorithm=spatial_algorithm, contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, flooring_fn=flooring_fn, pair_selector=pair_selector, callbacks=callbacks, permutation_alignment=permutation_alignment, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def __repr__(self) -> str: s = "AuxLaplaceFDICA(" s += "spatial_algorithm={spatial_algorithm}" s += ", permutation_alignment={permutation_alignment}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) ================================================ FILE: ssspy/bss/hva.py ================================================ import functools import math from typing import Callable, List, Optional, Union import numpy as np from ..special.flooring import identity, max_flooring from .admmbss import MaskingADMMBSS from .pdsbss import MaskingPDSBSS __all__ = [ "MaskingPDSHVA", "MaskingADMMHVA", "HVA", ] EPS = 1e-10 class MaskingPDSHVA(MaskingPDSBSS): r"""Harmonic vector analysis proposed in [#yatabe2021determined]_. Args: mu1 (float): Step size. Default: ``1``. mu2 (float): Step size. Default: ``1``. alpha (float): Relaxation parameter (deprecated). Set ``relaxation`` instead. relaxation (float): Relaxation parameter. Default: ``1``. attenuation (float, optional): Attenuation parameter in masking. Default: ``1 / n_sources``. mask_iter (int): Number of iterations in application of cosine shrinkage operator. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. 
Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. .. [#yatabe2021determined] K. Yatabe and D. Kitamura, "Determined BSS based on time-frequency masking and its application to \ harmonic vector analysis," *IEEE/ACM Trans. ASLP*, vol. 29, pp. 1609-1625, 2021. """ def __init__( self, mu1: float = 1, mu2: float = 1, alpha: float = None, relaxation: float = 1, attenuation: Optional[float] = None, mask_iter: int = 1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["MaskingPDSHVA"], None], List[Callable[["MaskingPDSHVA"], None]]] ] = None, scale_restoration: bool = True, record_loss: Optional[bool] = None, reference_id: int = 0, ) -> None: def mask_fn(y: np.ndarray) -> np.ndarray: """Masking function to emphasize harmonic components. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray of mask. The shape is (n_sources, n_bins, n_frames). 
""" n_sources, n_bins, _ = y.shape if self.attenuation is None: self.attenuation = 1 / n_sources gamma = self.attenuation y = self.flooring_fn(np.abs(y)) zeta = np.log(y) zeta_mean = zeta.mean(axis=1, keepdims=True) rho = zeta - zeta_mean nu = np.fft.irfft(rho, axis=1, norm="backward") nu = nu[:, :n_bins] varsigma = np.minimum(1, nu) for _ in range(mask_iter): varsigma = (1 - np.cos(math.pi * varsigma)) / 2 xi = np.fft.irfft(varsigma * nu, axis=1, norm="forward") xi = xi[:, :n_bins] varrho = xi + zeta_mean v = np.exp(2 * varrho) mask = (v / v.sum(axis=0)) ** gamma return mask super().__init__( mu1=mu1, mu2=mu2, alpha=alpha, relaxation=relaxation, penalty_fn=None, mask_fn=mask_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.attenuation = attenuation self.mask_iter = mask_iter if flooring_fn is None: self.flooring_fn = identity else: self.flooring_fn = flooring_fn def __repr__(self) -> str: s = "MaskingPDSHVA(" s += "mu1={mu1}, mu2={mu2}" s += ", relaxation={relaxation}" if self.attenuation is not None: s += ", attenuation={attenuation}" s += ", mask_iter={mask_iter}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) class MaskingADMMHVA(MaskingADMMBSS): """Harmonic vector analysis using ADMM with masking. Args: rho (float): Penalty parameter. Default: ``1``. alpha (float): Relaxation parameter (deprecated). Set ``relaxation`` instead. relaxation (float): Relaxation parameter. Default: ``1``. attenuation (float, optional): Attenuation parameter. mask_iter (int): Number of iterations in application of cosine shrinkage operator. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). 
If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool, optional): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``None``. reference_id (int): Reference channel for projection back. Default: ``0``. """ def __init__( self, rho: float = 1, alpha: float = None, relaxation: float = 1, attenuation: Optional[float] = None, mask_iter: int = 1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["MaskingADMMHVA"], None], List[Callable[["MaskingADMMHVA"], None]]] ] = None, scale_restoration: bool = True, record_loss: Optional[bool] = None, reference_id: int = 0, ) -> None: def mask_fn(y: np.ndarray) -> np.ndarray: """Masking function to emphasize harmonic components. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray of mask. The shape is (n_sources, n_bins, n_frames). 
""" n_sources, n_bins, _ = y.shape if self.attenuation is None: self.attenuation = 1 / n_sources gamma = self.attenuation y = self.flooring_fn(np.abs(y)) zeta = np.log(y) zeta_mean = zeta.mean(axis=1, keepdims=True) rho = zeta - zeta_mean nu = np.fft.irfft(rho, axis=1, norm="backward") nu = nu[:, :n_bins] varsigma = np.minimum(1, nu) for _ in range(mask_iter): varsigma = (1 - np.cos(math.pi * varsigma)) / 2 xi = np.fft.irfft(varsigma * nu, axis=1, norm="forward") xi = xi[:, :n_bins] varrho = xi + zeta_mean v = np.exp(2 * varrho) mask = (v / v.sum(axis=0)) ** gamma return mask super().__init__( rho=rho, alpha=alpha, relaxation=relaxation, penalty_fn=None, mask_fn=mask_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.attenuation = attenuation self.mask_iter = mask_iter if flooring_fn is None: self.flooring_fn = identity else: self.flooring_fn = flooring_fn def __repr__(self) -> str: s = "MaskingADMMHVA(" s += "rho={rho}" s += ", relaxation={relaxation}" if self.attenuation is not None: s += ", attenuation={attenuation}" s += ", mask_iter={mask_iter}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) class HVA(MaskingPDSHVA): """Alias of MaskingPDSHVA.""" def __repr__(self) -> str: s = "HVA(" s += "mu1={mu1}, mu2={mu2}" s += ", relaxation={relaxation}" if self.attenuation is not None: s += ", attenuation={attenuation}" s += ", mask_iter={mask_iter}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) ================================================ FILE: ssspy/bss/ica.py ================================================ from typing import Callable, List, Optional, Union import numpy as np from ..transform import whiten from .base 
import IterativeMethodBase __all__ = ["GradICA", "NaturalGradICA", "FastICA", "GradLaplaceICA", "NaturalGradLaplaceICA"] class GradICABase(IterativeMethodBase): r"""Base class of independent component analysis (ICA) using the gradient descent. Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(y_{tn})`. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). score_fn (callable): A score function which corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. """ def __init__( self, step_size: float = 1e-1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, callbacks: Optional[ Union[Callable[["GradICABase"], None], List[Callable[["GradICABase"], None]]] ] = None, record_loss: bool = True, ) -> None: super().__init__(callbacks=callbacks, record_loss=record_loss) self.step_size = step_size if contrast_fn is None: raise ValueError("Specify contrast function.") else: self.contrast_fn = contrast_fn if score_fn is None: raise ValueError("Specify score function.") else: self.score_fn = score_fn self.input = None def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a time-domain multichannel signal. Args: input (numpy.ndarray): Mixture signal in time-domain. The shape is (n_channels, n_samples). n_iter (int): Number of iterations of demixing filter updates. Default: ``100``. 
initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of separated signal in time-domain. The shape is (n_sources, n_samples). """ self.input = input.copy() self._reset(**kwargs) super().__call__(n_iter=n_iter, initial_call=initial_call) self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "GradICA(" s += "step_size={step_size}" s += ", record_loss={record_loss}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of ICA. """ assert self.input is not None, "Specify data!" for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input n_channels, n_samples = X.shape n_sources = n_channels # n_channels == n_sources self.n_sources, self.n_channels = n_sources, n_channels self.n_samples = n_samples if not hasattr(self, "demix_filter"): W = np.eye(n_sources, n_channels, dtype=np.float64) else: if self.demix_filter is None: W = None else: # To avoid overwriting ``demix_filter`` given by keyword arguments. W = self.demix_filter.copy() self.demix_filter = W self.output = self.separate(X, demix_filter=W) def update_once(self) -> None: r"""Update demixing filters once.""" raise NotImplementedError("Implement 'update_once' method.") def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray: r"""Separate ``input`` using ``demixing_filter``. .. math:: \boldsymbol{y}_{t} = \boldsymbol{W}\boldsymbol{x}_{t} Args: input (numpy.ndarray): The mixture signal in time-domain. The shape is (n_channels, n_samples). demix_filter (numpy.ndarray): The demixing filters to separate ``input``. The shape is (n_sources, n_channels). Returns: numpy.ndarray of the separated signal in time-domain. The shape is (n_sources, n_samples). 
""" output = demix_filter @ input return output def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} \ &= \frac{1}{T}\sum_{t,n}G(y_{tn}) \ - \log|\det\boldsymbol{W}| \\ G(y_{tn}) \ &= - \log p(y_{tn}) Returns: Computed loss. """ X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) # (n_channels, n_samples) logdet = self.compute_logdet(W) G = self.contrast_fn(Y) loss = np.sum(np.mean(G, axis=1)) - logdet loss = loss.item() return loss def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray: r"""Compute log-determinant of demixing filter Args: demix_filter (numpy.ndarray): Demixing filter with shape of (n_sources, n_channels). Returns: numpy.ndarray of computed log-determinant value. The shape is (n_bins,). """ _, logdet = np.linalg.slogdet(demix_filter) # (n_bins,) return logdet class FastICABase(IterativeMethodBase): r"""Base class of fast independent component analysis (FastICA). Args: contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(y_{tn})`. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). score_fn (callable): A score function which corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). d_score_fn (callable): A partial derivative of the score function. This function is expected to return the same shape tensor as the input. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. record_loss (bool): Record the loss at each of the fixed-point iteration if ``record_loss=True``. Default: ``True``. 
""" def __init__( self, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, d_score_fn: Callable[[np.ndarray], np.ndarray] = None, callbacks: Optional[ Union[Callable[["FastICABase"], None], List[Callable[["FastICABase"], None]]] ] = None, record_loss: bool = True, ) -> None: super().__init__(callbacks=callbacks, record_loss=record_loss) if contrast_fn is None: raise ValueError("Specify contrast function.") else: self.contrast_fn = contrast_fn if score_fn is None: raise ValueError("Specify score function.") else: self.score_fn = score_fn if d_score_fn is None: raise ValueError("Specify derivative of score function.") else: self.d_score_fn = d_score_fn self.input = None def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a time-domain multichannel signal. Args: input (numpy.ndarray): Mixture signal in time-domain. The shape is (n_channels, n_samples). n_iter (int): Number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in time-domain. The shape is (n_sources, n_samples). """ self.input = input.copy() self._reset(**kwargs) super().__call__(n_iter=n_iter, initial_call=initial_call) self.output = self.separate( self.whitened_input, demix_filter=self.demix_filter, use_whitening=False ) return self.output def __repr__(self) -> str: s = "FastICA(" s += "record_loss={record_loss}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of ICA. """ assert self.input is not None, "Specify data!" 
for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input n_channels, n_samples = X.shape n_sources = n_channels # n_channels == n_sources self.n_sources, self.n_channels = n_sources, n_channels self.n_samples = n_samples if not hasattr(self, "demix_filter"): W = np.eye(n_sources, n_channels, dtype=np.float64) else: if self.demix_filter is None: W = None else: # To avoid overwriting ``demix_filter`` given by keyword arguments. W = self.demix_filter.copy() Z = whiten(X) self.whitened_input = Z self.demix_filter = W self.output = self.separate(Z, demix_filter=W, use_whitening=False) def update_once(self) -> None: r"""Update demixing filters once.""" raise NotImplementedError("Implement 'update_once' method.") def separate( self, input: np.ndarray, demix_filter: np.ndarray, use_whitening: bool = True ) -> np.ndarray: r"""Separate ``input`` using ``demixing_filter``. If ``use_whitening=True``, we apply whitening to input mixture :math:`\boldsymbol{x}_{t}`. .. math:: \boldsymbol{y}_{t} &= \boldsymbol{W}\boldsymbol{z}_{t}, \\ \boldsymbol{z}_{t} &= \boldsymbol{\Lambda}^{-\frac{1}{2}} \ \boldsymbol{\Gamma}^{\mathsf{T}}\boldsymbol{x}_{t}, \\ \boldsymbol{\Lambda} &:= \mathrm{diag}(\lambda_{1},\ldots,\lambda_{m},\ldots,\lambda_{M}) \ \in\mathbb{R}^{M\times M}, \\ \boldsymbol{\Gamma} &:= (\boldsymbol{\gamma}_{1}, \ldots, \boldsymbol{\gamma}_{m}, \ldots, \boldsymbol{\gamma}_{M}) \ \in\mathbb{R}^{M\times M}, where :math:`\lambda_{m}` and :math:`\boldsymbol{\gamma}_{m}` are an eigenvalue and eigenvector of :math:`\sum_{t}\boldsymbol{x}_{t}\boldsymbol{x}_{t}^{\mathsf{T}}`, respectively. Otherwise (``use_whitening=False``), we do not apply whitening. .. math:: \boldsymbol{y}_{t} = \boldsymbol{W}\boldsymbol{x}_{t}. Args: input (numpy.ndarray): The mixture signal in time-domain. The shape is (n_channels, n_samples). demix_filter (numpy.ndarray): The demixing filters to separate ``input``. The shape is (n_sources, n_channels). 
use_whitening (bool): If ``use_whitening=True``, use_whitening (sphering) is applied to ``input``. Default: ``True``. Returns: numpy.ndarray of the separated signal in time-domain. The shape is (n_sources, n_samples). """ if use_whitening: whitened_input = whiten(input) else: whitened_input = input output = demix_filter @ whitened_input return output def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} \ &= \frac{1}{T}\sum_{t,n}G(y_{tn}) \\ G(y_{tn}) \ &= - \log p(y_{tn}) Returns: Computed loss. """ Z, W = self.whitened_input, self.demix_filter Y = self.separate(Z, demix_filter=W, use_whitening=False) loss = np.mean(self.contrast_fn(Y), axis=-1) loss = loss.sum().item() return loss class GradICA(GradICABase): r"""Independent component analysis (ICA) using the gradient descent. Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(y_{tn})`. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). score_fn (callable): A score function which corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return np.abs(y) >>> def score_fn(y): ... 
return np.sign(y) >>> n_channels, n_samples = 2, 160000 >>> waveform_mix = np.random.randn(n_channels, n_samples) >>> ica = GradICA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=True, ... ) >>> waveform_est = ica(waveform_mix, n_iter=1000) >>> print(waveform_mix.shape, waveform_est.shape) (2, 160000), (2, 160000) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return np.abs(y) >>> def score_fn(y): ... return np.sign(y) >>> n_channels, n_samples = 2, 160000 >>> waveform_mix = np.random.randn(n_channels, n_samples) >>> ica = GradICA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=False, ... ) >>> waveform_est = ica(waveform_mix, n_iter=1000) >>> print(waveform_mix.shape, waveform_est.shape) (2, 160000), (2, 160000) """ def __init__( self, step_size: float = 1e-1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, callbacks: Optional[ Union[Callable[["GradICA"], None], List[Callable[["GradICA"], None]]] ] = None, is_holonomic: bool = False, record_loss: bool = True, ) -> None: super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, record_loss=record_loss, ) self.is_holonomic = is_holonomic def __repr__(self) -> str: s = "GradICA(" s += "step_size={step_size}" s += ", is_holonomic={is_holonomic}" s += ", record_loss={record_loss}" s += ")" return s.format(**self.__dict__) def update_once(self) -> None: r"""Update demixing filters once using the gradient descent. If ``is_holonomic=True``, demixing filters are updated as follows: .. math:: \boldsymbol{W} \leftarrow\boldsymbol{W} - \eta\left(\frac{1}{T}\sum_{t} \ \boldsymbol{\phi}(\boldsymbol{y}_{t})\boldsymbol{y}_{t}^{\mathsf{T}} \ -\boldsymbol{I}\right)\boldsymbol{W}^{-\mathsf{T}}, where .. 
math:: \boldsymbol{\phi}(\boldsymbol{y}_{t}) &= \left(\phi(y_{t1}),\ldots,\phi(y_{tN})\right)^{\mathsf{T}}\in\mathbb{R}^{N}, \\ \phi(y_{tn}) &= \frac{\partial G(y_{tn})}{\partial y_{tn}}, \\ G(y_{tn}) &= -\log p(y_{tn}). Otherwise (``is_holonomic=False``), .. math:: \boldsymbol{W} \leftarrow\boldsymbol{W} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{T}\sum_{t} \ \boldsymbol{\phi}(\boldsymbol{y}_{t})\boldsymbol{y}_{t}^{\mathsf{T}}\right) \ \boldsymbol{W}^{-\mathsf{T}}. """ X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Phi = self.score_fn(Y) PhiY = np.mean(Phi[:, np.newaxis, :] * Y[np.newaxis, :, :], axis=-1) W_inv = np.linalg.inv(W) W_inv_trans = W_inv.transpose(1, 0) eye = np.eye(self.n_sources) if self.is_holonomic: delta = (PhiY - eye) @ W_inv_trans else: delta = ((1 - eye) * PhiY) @ W_inv_trans W = W - self.step_size * delta Y = self.separate(X, demix_filter=W) self.demix_filter = W self.output = Y class NaturalGradICA(GradICABase): r"""Independent component analysis (ICA) using the natural gradient descent [#amari1995new]_. Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(y_{tn})`. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). score_fn (callable): A score function which corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. 
Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return np.abs(y) >>> def score_fn(y): ... return np.sign(y) >>> n_channels, n_samples = 2, 160000 >>> waveform_mix = np.random.randn(n_channels, n_samples) >>> ica = NaturalGradICA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=True, ... ) >>> waveform_est = ica(waveform_mix, n_iter=100) >>> print(waveform_mix.shape, waveform_est.shape) (2, 160000), (2, 160000) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return np.abs(y) >>> def score_fn(y): ... return np.sign(y) >>> n_channels, n_samples = 2, 160000 >>> waveform_mix = np.random.randn(n_channels, n_samples) >>> ica = NaturalGradICA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=False, ... ) >>> waveform_est = ica(waveform_mix, n_iter=100) >>> print(waveform_mix.shape, waveform_est.shape) (2, 160000), (2, 160000) .. [#amari1995new] S. Amari, A. Cichocki, and H. H. Yang, "A new learning algorithm for blind signal separation," in *Proc. NIPS.*, pp. 757-763, 1996. """ def __init__( self, step_size: float = 1e-1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, callbacks: Optional[ Union[Callable[["GradICA"], None], List[Callable[["GradICA"], None]]] ] = None, is_holonomic: bool = False, record_loss: bool = True, ) -> None: super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, record_loss=record_loss, ) self.is_holonomic = is_holonomic def __repr__(self) -> str: s = "NaturalGradICA(" s += "step_size={step_size}" s += ", is_holonomic={is_holonomic}" s += ", record_loss={record_loss}" s += ")" return s.format(**self.__dict__) def update_once(self) -> None: r"""Update demixing filters once using the natural gradient descent. If ``is_holonomic=True``, demixing filters are updated as follows: .. 
math:: \boldsymbol{W} \leftarrow\boldsymbol{W} - \eta\left(\frac{1}{T}\sum_{t} \ \boldsymbol{\phi}(\boldsymbol{y}_{t})\boldsymbol{y}_{t}^{\mathsf{T}} \ -\boldsymbol{I}\right)\boldsymbol{W}, where .. math:: \boldsymbol{\phi}(\boldsymbol{y}_{t}) &= \left(\phi(y_{t1}),\ldots,\phi(y_{tN})\right)^{\mathsf{T}}\in\mathbb{R}^{N}, \\ \phi(y_{tn}) &= \frac{\partial G(y_{tn})}{\partial y_{tn}}, \\ G(y_{tn}) &= -\log p(y_{tn}). Otherwise (``is_holonomic=False``), .. math:: \boldsymbol{W} \leftarrow\boldsymbol{W} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{T}\sum_{t} \ \boldsymbol{\phi}(\boldsymbol{y}_{t})\boldsymbol{y}_{t}^{\mathsf{T}}\right) \ \boldsymbol{W}. """ X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Phi = self.score_fn(Y) PhiY = np.mean(Phi[:, np.newaxis, :] * Y[np.newaxis, :, :], axis=-1) eye = np.eye(self.n_sources) if self.is_holonomic: delta = (PhiY - eye) @ W else: delta = ((1 - eye) * PhiY) @ W W = W - self.step_size * delta Y = self.separate(X, demix_filter=W) self.demix_filter = W self.output = Y class FastICA(FastICABase): r"""Fast independent component analysis (FastICA) [#hyvarinen1999fast]_. In FastICA, a whitening (sphering) is applied to input signal. .. math:: \boldsymbol{z}_{t} &= \boldsymbol{\Lambda}^{-\frac{1}{2}} \ \boldsymbol{\Gamma}^{\mathsf{T}}\boldsymbol{x}_{t}, \\ \boldsymbol{\Lambda} &:= \mathrm{diag}(\lambda_{1},\ldots,\lambda_{m},\ldots,\lambda_{M}) \ \in\mathbb{R}^{M\times M}, \\ \boldsymbol{\Gamma} &:= (\boldsymbol{\gamma}_{1}, \ldots, \boldsymbol{\gamma}_{m}, \ldots, \boldsymbol{\gamma}_{M}) \ \in\mathbb{R}^{M\times M}, where :math:`\lambda_{m}` and :math:`\boldsymbol{\gamma}_{m}` are an eigenvalue and eigenvector of :math:`\sum_{t}\boldsymbol{x}_{t}\boldsymbol{x}_{t}^{\mathsf{T}}`, respectively. Furthermore, :math:`\boldsymbol{W}` is constrained to be orthogonal. .. 
math:: \boldsymbol{W}\boldsymbol{W}^{\mathsf{T}} = \boldsymbol{I} Args: contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(y_{tn})`. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). score_fn (callable): A score function which corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_samples) and return (n_channels, n_samples). d_score_fn (callable): A partial derivative of the score function. This function is expected to return the same shape tensor as the input. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. record_loss (bool): Record the loss at each of the fixed-point iteration if ``record_loss=True``. Default: ``True``. Examples: .. code-block:: python >>> def contrast_fn(y): ... return np.log(1 + np.exp(y)) >>> def score_fn(y): ... return 1 / (1 + np.exp(-y)) >>> def d_score_fn(y): ... sigmoid_y = 1 / (1 + np.exp(-y)) ... return sigmoid_y * (1 - sigmoid_y) >>> n_channels, n_samples = 2, 160000 >>> waveform_mix = np.random.randn(n_channels, n_samples) >>> ica = FastICA(contrast_fn=contrast_fn, score_fn=score_fn, d_score_fn=d_score_fn) >>> waveform_est = ica(waveform_mix, n_iter=10) >>> print(waveform_mix.shape, waveform_est.shape) (2, 160000), (2, 160000) .. [#hyvarinen1999fast] A. Hyvärinen, "Fast and robust fixed-point algorithms for independent component analysis," *IEEE Trans. on Neural Netw.*, vol. 10, no. 3, pp. 626-634, 1999. 
""" def __init__( self, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, d_score_fn: Callable[[np.ndarray], np.ndarray] = None, callbacks: Optional[ Union[Callable[["FastICA"], None], List[Callable[["FastICA"], None]]] ] = None, record_loss: bool = True, ) -> None: super().__init__( contrast_fn=contrast_fn, score_fn=score_fn, d_score_fn=d_score_fn, callbacks=callbacks, record_loss=record_loss, ) def update_once(self) -> None: r"""Update demixing filters once using the fixed-point iteration algorithm. For :math:`n=1,\dots,N`, the demixing flter :math:`\boldsymbol{w}_{n}` is updated sequentially, .. math:: y_{tn} &=\boldsymbol{w}_{n}^{\mathsf{T}}\boldsymbol{z}_{t}, \\ \boldsymbol{w}_{n}^{+} &\leftarrow \frac{1}{T}\sum_{t}\phi(y_{tn})\boldsymbol{z}_{tn} \ - \frac{1}{T}\sum_{t}\frac{\partial\phi(y_{tn})}{\partial y_{tn}} \ \boldsymbol{w}_{n}, \\ \boldsymbol{w}_{n}^{+} &\leftarrow\boldsymbol{w}_{n}^{+} \ - \sum_{n'=1}^{n-1}\boldsymbol{w}_{n'}^{\mathsf{T}}\boldsymbol{w}_{n}^{+} \ \boldsymbol{w}_{n}^{+}, \\ \boldsymbol{w}_{n} &\leftarrow \frac{\boldsymbol{w}_{n}^{+}}{\|\boldsymbol{w}_{n}^{+}\|}. """ Z, W = self.whitened_input, self.demix_filter for src_idx in range(self.n_sources): w_n = W[src_idx] # (n_channels,) y_n = w_n @ Z # (n_samples,) Gw_n = np.mean(self.d_score_fn(y_n), axis=-1) * w_n Gz = np.mean(self.score_fn(y_n) * Z, axis=-1) w_n = Gw_n - Gz if src_idx > 0: W_n = W[:src_idx] # (src_idx - 1, n_channels) scale = np.sum(W_n * w_n, axis=-1, keepdims=True) w_n = w_n - np.sum(scale * W_n, axis=0) norm = np.linalg.norm(w_n) W[src_idx] = w_n / norm Y = self.separate(Z, demix_filter=W, use_whitening=False) self.demix_filter = W self.output = Y class GradLaplaceICA(GradICA): r"""Independent component analysis (ICA) using the gradient descent on a Laplace distribution. We assume :math:`y_{ijn}` follows a Laplace distribution. .. 
math:: p(y_{ijn})\propto\exp(|y_{ijn}|) Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. record_loss (bool): Record the loss at each iteration of the gradient descent \ if ``record_loss=True``. Default: ``True``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> n_channels, n_samples = 2, 160000 >>> waveform_mix = np.random.randn(n_channels, n_samples) >>> ica = GradLaplaceICA(is_holonomic=True) >>> waveform_est = ica(waveform_mix, n_iter=1000) >>> print(waveform_mix.shape, waveform_est.shape) (2, 160000), (2, 160000) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> n_channels, n_samples = 2, 160000 >>> waveform_mix = np.random.randn(n_channels, n_samples) >>> ica = GradLaplaceICA(is_holonomic=False) >>> waveform_est = ica(waveform_mix, n_iter=1000) >>> print(waveform_mix.shape, waveform_est.shape) (2, 160000), (2, 160000) """ def __init__( self, step_size: float = 1e-1, callbacks: Optional[ Union[Callable[["GradLaplaceICA"], None], List[Callable[["GradLaplaceICA"], None]]] ] = None, is_holonomic: bool = False, record_loss: bool = True, ) -> None: def contrast_fn(input): return np.abs(input) def score_fn(input): return np.sign(input) super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic, record_loss=record_loss, ) def __repr__(self) -> str: s = "GradLaplaceICA(" s += "step_size={step_size}" s += ", is_holonomic={is_holonomic}" s += ", record_loss={record_loss}" s += ")" return s.format(**self.__dict__) def update_once(self) -> None: r"""Update demixing filters once using the gradient descent. 
If ``is_holonomic=True``, demixing filters are updated as follows: .. math:: \boldsymbol{W} \leftarrow\boldsymbol{W} - \eta\left(\frac{1}{T}\sum_{t} \ \boldsymbol{\phi}(\boldsymbol{y}_{t})\boldsymbol{y}_{t}^{\mathsf{T}} \ -\boldsymbol{I}\right)\boldsymbol{W}^{-\mathsf{T}}, where .. math:: \boldsymbol{\phi}(\boldsymbol{y}_{t}) = \left(\mathrm{sign}(y_{t1}),\ldots,\mathrm{sign}(y_{tN})\right)^{\mathsf{T}} \ \in\mathbb{R}^{N}. Otherwise (``is_holonomic=False``), .. math:: \boldsymbol{W} \leftarrow\boldsymbol{W} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{T}\sum_{t} \ \boldsymbol{\phi}(\boldsymbol{y}_{t})\boldsymbol{y}_{t}^{\mathsf{T}}\right) \ \boldsymbol{W}^{-\mathsf{T}}. """ super().update_once() def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} \ &= \frac{1}{T}\sum_{t,n}|y_{tn}| \ - \log|\det\boldsymbol{W}| \\ Returns: Computed loss. """ return super().compute_loss() class NaturalGradLaplaceICA(NaturalGradICA): r"""Independent component analysis (ICA) using the natural gradient descent \ on a Laplace distribution. We assume :math:`y_{ijn}` follows a Laplace distribution. .. math:: p(y_{ijn})\propto\exp(|y_{ijn}|) Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. record_loss (bool): Record the loss at each iteration of the gradient descent \ if ``record_loss=True``. Default: ``True``. Examples: Update demixing filters using Holonomic-type update: .. 
code-block:: python >>> n_channels, n_samples = 2, 160000 >>> waveform_mix = np.random.randn(n_channels, n_samples) >>> ica = NaturalGradLaplaceICA(is_holonomic=True) >>> waveform_est = ica(waveform_mix, n_iter=100) >>> print(waveform_mix.shape, waveform_est.shape) (2, 160000), (2, 160000) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> n_channels, n_samples = 2, 160000 >>> waveform_mix = np.random.randn(n_channels, n_samples) >>> ica = NaturalGradLaplaceICA(is_holonomic=False) >>> waveform_est = ica(waveform_mix, n_iter=100) >>> print(waveform_mix.shape, waveform_est.shape) (2, 160000), (2, 160000) """ def __init__( self, step_size: float = 1e-1, callbacks: Optional[ Union[ Callable[["NaturalGradLaplaceICA"], None], List[Callable[["NaturalGradLaplaceICA"], None]], ] ] = None, is_holonomic: bool = False, record_loss: bool = True, ) -> None: def contrast_fn(input): return np.abs(input) def score_fn(input): return np.sign(input) super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic, record_loss=record_loss, ) def __repr__(self) -> str: s = "NaturalGradLaplaceICA(" s += "step_size={step_size}" s += ", is_holonomic={is_holonomic}" s += ", record_loss={record_loss}" s += ")" return s.format(**self.__dict__) def update_once(self) -> None: r"""Update demixing filters once using the natural gradient descent. If ``is_holonomic=True``, demixing filters are updated as follows: .. math:: \boldsymbol{W} \leftarrow\boldsymbol{W} - \eta\left(\frac{1}{T}\sum_{t} \ \boldsymbol{\phi}(\boldsymbol{y}_{t})\boldsymbol{y}_{t}^{\mathsf{T}} \ -\boldsymbol{I}\right)\boldsymbol{W}, where .. math:: \boldsymbol{\phi}(\boldsymbol{y}_{t}) = \left(\mathrm{sign}(y_{t1}),\ldots,\mathrm{sign}(y_{tN})\right)^{\mathsf{T}} \ \in\mathbb{R}^{N}. Otherwise (``is_holonomic=False``), .. 
math:: \boldsymbol{W} \leftarrow\boldsymbol{W} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{T}\sum_{t} \ \boldsymbol{\phi}(\boldsymbol{y}_{t})\boldsymbol{y}_{t}^{\mathsf{T}}\right) \ \boldsymbol{W}. """ super().update_once() def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} \ &= \frac{1}{T}\sum_{t,n}|y_{tn}| \ - \log|\det\boldsymbol{W}| \\ Returns: Computed loss. """ return super().compute_loss() ================================================ FILE: ssspy/bss/ilrma.py ================================================ import functools import warnings from typing import Callable, Iterable, List, Optional, Tuple, Union import numpy as np from ..algorithm import ( MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS, PROJECTION_BACK_KEYWORDS, minimal_distortion_principle, projection_back, ) from ..special.flooring import identity, max_flooring from ..utils.flooring import choose_flooring_fn from ..utils.select_pair import sequential_pair_selector from ._update_spatial_model import ( update_by_ip1, update_by_ip2, update_by_ipa, update_by_iss1, update_by_iss2, ) from .base import IterativeMethodBase __all__ = ["GaussILRMA", "TILRMA", "GGDILRMA"] spatial_algorithms = ["IP", "IP1", "IP2", "ISS", "ISS1", "ISS2", "IPA"] source_algorithms = ["MM", "ME"] EPS = 1e-10 class ILRMABase(IterativeMethodBase): r"""Base class of independent low-rank matrix analysis (ILRMA). Args: n_basis (int): Number of NMF bases. partitioning (bool): Whether to use partioning function. Default: ``False``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. 
Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize NMF. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ def __init__( self, n_basis: int, partitioning: bool = False, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["ILRMABase"], None], List[Callable[["ILRMABase"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__(callbacks=callbacks, record_loss=record_loss) self.n_basis = n_basis self.partitioning = partitioning if flooring_fn is None: self.flooring_fn = identity else: self.flooring_fn = flooring_fn self.input = None self.scale_restoration = scale_restoration if reference_id is None and scale_restoration: raise ValueError("Specify 'reference_id' if scale_restoration=True.") else: self.reference_id = reference_id if rng is None: rng = np.random.default_rng() self.rng = rng def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. 
initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(flooring_fn=self.flooring_fn, **kwargs) super().__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "ILRMA(" s += "n_basis={n_basis}" s += ", partitioning={partitioning}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", **kwargs, ) -> None: r"""Reset attributes by given keyword arguments. We also initialize the demixing filters (identity matrices for every frequency bin if ``demix_filter`` has not been set yet), the separated output, and the NMF parameters. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. kwargs: Keyword arguments to set as attributes of ILRMA. """ assert self.input is not None, "Specify data!"
flooring_fn = choose_flooring_fn(flooring_fn, method=self) for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input n_channels, n_bins, n_frames = X.shape n_sources = n_channels # n_channels == n_sources self.n_sources, self.n_channels = n_sources, n_channels self.n_bins, self.n_frames = n_bins, n_frames if not hasattr(self, "demix_filter"): W = np.eye(n_sources, n_channels, dtype=np.complex128) W = np.tile(W, reps=(n_bins, 1, 1)) else: if self.demix_filter is None: W = None else: # To avoid overwriting ``demix_filter`` given by keyword arguments. W = self.demix_filter.copy() self.demix_filter = W self.output = self.separate(X, demix_filter=W) self._init_nmf(flooring_fn=flooring_fn, rng=self.rng) def _init_nmf( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", rng: Optional[np.random.Generator] = None, ) -> None: r"""Initialize NMF. If ``self.partitioning=True``, latent variables with shape (n_sources, n_basis), bases with shape (n_bins, n_basis), and activations with shape (n_basis, n_frames) are initialized. Otherwise, bases with shape (n_sources, n_bins, n_basis) and activations with shape (n_sources, n_basis, n_frames) are initialized. Parameters that already exist as attributes are copied instead of being re-initialized. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. rng (numpy.random.Generator, optional): Random number generator. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ n_basis = self.n_basis n_sources = self.n_sources n_bins, n_frames = self.n_bins, self.n_frames flooring_fn = choose_flooring_fn(flooring_fn, method=self) if rng is None: rng = np.random.default_rng() if self.partitioning: if not hasattr(self, "latent"): Z = rng.random((n_sources, n_basis)) Z = Z / Z.sum(axis=0) Z = flooring_fn(Z) else: # To avoid overwriting. Z = self.latent.copy() if not hasattr(self, "basis"): T = rng.random((n_bins, n_basis)) T = flooring_fn(T) else: # To avoid overwriting.
T = self.basis.copy() if not hasattr(self, "activation"): V = rng.random((n_basis, n_frames)) V = flooring_fn(V) else: # To avoid overwriting. V = self.activation.copy() self.latent = Z self.basis, self.activation = T, V else: if not hasattr(self, "basis"): T = rng.random((n_sources, n_bins, n_basis)) T = flooring_fn(T) else: # To avoid overwriting. T = self.basis.copy() if not hasattr(self, "activation"): V = rng.random((n_sources, n_basis, n_frames)) V = flooring_fn(V) else: # To avoid overwriting. V = self.activation.copy() self.basis, self.activation = T, V def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray: r"""Separate ``input`` using ``demix_filter``. .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij} Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). demix_filter (numpy.ndarray): The demixing filters to separate ``input``. The shape is (n_bins, n_sources, n_channels). Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_sources, n_bins, n_frames). """ X, W = input, demix_filter Y = W @ X.transpose(1, 0, 2) output = Y.transpose(1, 0, 2) return output def reconstruct_nmf( self, basis: np.ndarray, activation: np.ndarray, latent: Optional[np.ndarray] = None ) -> np.ndarray: r"""Reconstruct NMF. Args: basis (numpy.ndarray): Basis matrix. The shape is (n_bins, n_basis) if latent is given. Otherwise, (n_sources, n_bins, n_basis). activation (numpy.ndarray): Activation matrix. The shape is (n_basis, n_frames) if latent is given. Otherwise, (n_sources, n_basis, n_frames). latent (numpy.ndarray, optional): Latent variable that determines number of bases per source. The shape is (n_sources, n_basis). Returns: numpy.ndarray of the reconstructed NMF. The shape is (n_sources, n_bins, n_frames).
""" if latent is None: T, V = basis, activation R = T @ V else: Z = latent T, V = basis, activation TV = T[:, :, np.newaxis] * V[np.newaxis, :, :] R = np.sum(Z[:, np.newaxis, :, np.newaxis] * TV[np.newaxis, :, :, :], axis=2) return R def update_once(self) -> None: r"""Update demixing filters once.""" raise NotImplementedError("Implement 'update_once' method.") def normalize( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Normalize demixing filters and NMF parameters. If ``self.normalization`` is ``True`` or ``"power"``, ``normalize_by_power`` is called. If it is ``"projection_back"``, ``normalize_by_projection_back`` is called. Otherwise, ``NotImplementedError`` is raised. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ normalization = self.normalization flooring_fn = choose_flooring_fn(flooring_fn, method=self) assert normalization, "Set normalization." if type(normalization) is bool: # when normalization is True normalization = "power" if normalization == "power": self.normalize_by_power(flooring_fn=flooring_fn) elif normalization == "projection_back": self.normalize_by_projection_back() else: raise NotImplementedError("Normalization {} is not implemented.".format(normalization)) def normalize_by_power( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Normalize demixing filters and NMF parameters by power. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Demixing filters are normalized by .. math:: \boldsymbol{w}_{in} \leftarrow\frac{\boldsymbol{w}_{in}}{\psi_{n}}, where ..
math:: \psi_{n} = \sqrt{\frac{1}{IJ}\sum_{i,j}|\boldsymbol{w}_{in}^{\mathsf{H}} \boldsymbol{x}_{ij}|^{2}}. For NMF parameters, .. math:: t_{ik} &\leftarrow t_{ik}\sum_{n}\frac{z_{nk}}{\psi_{n}^{p}}, \\ z_{nk} &\leftarrow \frac{\frac{z_{nk}}{\psi_{n}^{p}}} {\sum_{n'}\frac{z_{n'k}}{\psi_{n'}^{p}}}, if ``self.partitioning=True``. Otherwise, .. math:: t_{ikn} \leftarrow\frac{t_{ikn}}{\psi_{n}^{p}}. """ p = self.domain flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.mean(np.abs(Y) ** 2, axis=(-2, -1)) psi = np.sqrt(Y2) psi = flooring_fn(psi) if self.partitioning: Z, T = self.latent, self.basis Z_psi = Z / (psi[:, np.newaxis] ** p) scale = np.sum(Z_psi, axis=0) T = T * scale[np.newaxis, :] Z = Z_psi / scale self.latent, self.basis = Z, T else: T = self.basis T = T / (psi[:, np.newaxis, np.newaxis] ** p) self.basis = T if self.demix_filter is None: Y = Y / psi[:, np.newaxis, np.newaxis] self.output = Y else: W = self.demix_filter W = W / psi[np.newaxis, :, np.newaxis] self.demix_filter = W def normalize_by_projection_back(self) -> None: r"""Normalize demixing filters and NMF parameters by projection back. Demixing filters are normalized by .. math:: \boldsymbol{W}_{i} \leftarrow\mathrm{diag}(\boldsymbol{\psi}_{i})\boldsymbol{W}_{i}, where .. math:: \psi_{in} = \left[\boldsymbol{W}_{i}^{-1}\right]_{m_{\mathrm{ref}}n}. If demixing filters are not held explicitly (``self.demix_filter is None``), the separated signals are rescaled in the same way, where :math:`\psi_{in}` is estimated by least squares from the mixture and the separated signals. For NMF parameters, .. math:: t_{ikn} \leftarrow|\psi_{in}|^{p}t_{ikn}. Projection-back-based normalization is not supported when ``self.partitioning=True``. """ p = self.domain reference_id = self.reference_id X = self.input if reference_id is None: warnings.warn( "channel 0 is used for reference_id \ of projection-back-based normalization.", UserWarning, ) reference_id = 0 if self.partitioning: raise NotImplementedError( "Projection-back-based normalization is not applicable with partitioning function."
) else: T = self.basis if self.demix_filter is None: Y = self.output Y = Y.transpose(1, 0, 2) # (n_bins, n_sources, n_frames) X = X.transpose(1, 0, 2) # (n_bins, n_channels, n_frames) Y_Hermite = Y.transpose(0, 2, 1).conj() # (n_bins, n_frames, n_sources) XY_Hermite = X @ Y_Hermite # (n_bins, n_channels, n_sources) YY_Hermite = Y @ Y_Hermite # (n_bins, n_sources, n_sources) scale = XY_Hermite @ np.linalg.inv(YY_Hermite) # (n_bins, n_channels, n_sources) scale = scale[..., reference_id, :] # (n_bins, n_sources) Y_scaled = Y * scale[..., np.newaxis] # (n_bins, n_sources, n_frames) Y = Y_scaled.swapaxes(-3, -2) # (n_sources, n_bins, n_frames) self.output = Y else: W = self.demix_filter scale = np.linalg.inv(W) scale = scale[:, reference_id, :] W = W * scale[:, :, np.newaxis] self.demix_filter = W scale = scale.transpose(1, 0) scale = np.abs(scale) ** p T = T * scale[:, :, np.newaxis] self.basis = T def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. Returns: Computed loss. """ raise NotImplementedError("Implement 'compute_loss' method.") def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray: r"""Compute log-absolute-determinant of demixing filters, i.e. :math:`\log|\det\boldsymbol{W}_{i}|` for each frequency bin. Args: demix_filter (numpy.ndarray): Demixing filters with shape of (n_bins, n_sources, n_channels). Returns: numpy.ndarray of computed log-determinant values. The shape is (n_bins,). """ _, logdet = np.linalg.slogdet(demix_filter) # (n_bins,) return logdet def restore_scale(self) -> None: r"""Restore scale ambiguity. If ``self.scale_restoration=True`` or it is one of the projection-back keywords (e.g. ``"projection_back"``), the projection back technique is applied. If it is one of the minimal-distortion-principle keywords (e.g. ``"minimal_distortion_principle"``), the minimal distortion principle is applied. Otherwise, ``ValueError`` is raised. """ scale_restoration = self.scale_restoration assert scale_restoration, "Set self.scale_restoration=True."
if type(scale_restoration) is bool: scale_restoration = PROJECTION_BACK_KEYWORDS[0] if scale_restoration in PROJECTION_BACK_KEYWORDS: self.apply_projection_back() elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS: self.apply_minimal_distortion_principle() else: raise ValueError("{} is not supported for scale restoration.".format(scale_restoration)) def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." X, W = self.input, self.demix_filter W_scaled = projection_back(W, reference_id=self.reference_id) Y_scaled = self.separate(X, demix_filter=W_scaled) self.output, self.demix_filter = Y_scaled, W_scaled def apply_minimal_distortion_principle(self) -> None: r"""Apply minimal distortion principle to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id) X = X.transpose(1, 0, 2) Y = Y_scaled.transpose(1, 0, 2) X_Hermite = X.transpose(0, 2, 1).conj() W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite) self.output, self.demix_filter = Y_scaled, W_scaled class GaussILRMA(ILRMABase): r"""Independent low-rank matrix analysis (ILRMA) [#kitamura2016determined]_ \ on Gaussian distribution. We assume :math:`y_{ijn}` follows a Gaussian distribution. .. math:: p(y_{ijn}) = \frac{1}{\pi r_{ijn}}\exp\left(-\frac{|y_{ijn}|^{2}}{r_{ijn}}\right), where .. math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}}, if ``partitioning=True``. Otherwise, .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. Args: n_basis (int): Number of NMF bases. spatial_algorithm (str): Algorithm for demixing filter updates. Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, or ``ISS2``. Default: ``IP``. 
source_algorithm (str): Algorithm for source model updates. Choose ``MM`` or ``ME``. Default: ``MM``. domain (float): Domain parameter. Default: ``2``. partitioning (bool): Whether to use partioning function. Default: ``False``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. pair_selector (callable, optional): Selector to choose updaing pair in ``IP2`` and ``ISS2``. If ``None`` is given, ``sequential_pair_selector`` is used. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. normalization (bool or str, optional): Normalization of demixing filters and NMF parameters. Choose ``power`` or ``projection_back``. Default: ``power``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize NMF. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. Examples: Update demixing filters by IP: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = GaussILRMA( ... n_basis=2, ... 
spatial_algorithm="IP", ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by IP2: .. code-block:: python >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = GaussILRMA( ... n_basis=2, ... spatial_algorithm="IP2", ... pair_selector=sequential_pair_selector, ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = GaussILRMA( ... n_basis=2, ... spatial_algorithm="ISS", ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS2: .. code-block:: python >>> import functools >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = GaussILRMA( ... n_basis=2, ... spatial_algorithm="ISS2", ... pair_selector=functools.partial(sequential_pair_selector, step=2), ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by IPA: .. 
code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = GaussILRMA( ... n_basis=2, ... spatial_algorithm="IPA", ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) .. [#kitamura2016determined] D. Kitamura, N. Ono, H. Sawada, H. Kameoka, and H. Saruwatari, \ "Determined blind source separation unifying independent vector analysis and \ nonnegative matrix factorization," \ *IEEE/ACM Trans. ASLP*, vol. 24, no. 9, pp. 1626-1641, 2016. """ _ipa_default_kwargs = {"lqpqm_normalization": True, "newton_iter": 1} _default_kwargs = _ipa_default_kwargs def __init__( self, n_basis: int, spatial_algorithm: str = "IP", source_algorithm: str = "MM", domain: float = 2, partitioning: bool = False, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, callbacks: Optional[ Union[Callable[["GaussILRMA"], None], List[Callable[["GaussILRMA"], None]]] ] = None, normalization: Optional[Union[bool, str]] = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, **kwargs, ) -> None: super().__init__( n_basis=n_basis, partitioning=partitioning, flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, rng=rng, ) assert spatial_algorithm in spatial_algorithms, "Not support {}.".format(spatial_algorithm) assert source_algorithm in source_algorithms, "Not support {}.".format(source_algorithm) assert 0 < domain <= 2, "domain parameter should be chosen from [0, 2]." 
if source_algorithm == "ME": assert domain == 2, "domain parameter should be 2 when you specify ME algorithm." self.spatial_algorithm = spatial_algorithm self.source_algorithm = source_algorithm self.domain = domain self.normalization = normalization if pair_selector is None: if spatial_algorithm in ["IP2", "ISS2"]: self.pair_selector = sequential_pair_selector else: self.pair_selector = pair_selector if spatial_algorithm == "IPA": valid_keys = set(self.__class__._ipa_default_kwargs.keys()) else: valid_keys = set() invalid_keys = set(kwargs) - valid_keys assert invalid_keys == set(), "Invalid keywords {} are given.".format(invalid_keys) for key, value in kwargs.items(): setattr(self, key, value) # set default values if necessary for key in valid_keys: if not hasattr(self, key): value = self.__class__._default_kwargs[key] setattr(self, key, value) def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(flooring_fn=self.flooring_fn, **kwargs) # Call __call__ of ILRMABase's parent, i.e. 
__call__ of IterativeMethodBase super(ILRMABase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() if self.demix_filter is None: pass else: self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "GaussILRMA(" s += "n_basis={n_basis}" s += ", spatial_algorithm={spatial_algorithm}" s += ", source_algorithm={source_algorithm}" s += ", domain={domain}" s += ", partitioning={partitioning}" s += ", normalization={normalization}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", **kwargs, ) -> None: r"""Reset attributes by given keyword arguments. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. kwargs: Keyword arguments to set as attributes of ILRMA. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) super()._reset(flooring_fn=flooring_fn, **kwargs) if self.spatial_algorithm in ["ISS", "ISS1", "ISS2", "IPA"]: self.demix_filter = None def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF parameters and demixing filters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) self.update_source_model(flooring_fn=flooring_fn) self.update_spatial_model(flooring_fn=flooring_fn) if self.normalization: self.normalize(flooring_fn=flooring_fn) def update_source_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases, activations, and latent variables. - If ``source_algorithm`` is ``MM``, ``update_source_model_mm`` is called. - If ``source_algorithm`` is ``ME``, ``update_source_model_me`` is called. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.source_algorithm == "MM": self.update_source_model_mm(flooring_fn=flooring_fn) elif self.source_algorithm == "ME": self.update_source_model_me(flooring_fn=flooring_fn) else: raise ValueError( "{}-algorithm-based source model updates are not supported.".format( self.source_algorithm ) ) def update_source_model_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases, activations, and latent variables by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.partitioning: self.update_latent_mm() self.update_basis_mm(flooring_fn=flooring_fn) self.update_activation_mm(flooring_fn=flooring_fn) def update_source_model_me( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases, activations, and latent variables by ME algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.domain != 2: raise ValueError("Domain parameter is expected 2, but given {}.".format(self.domain)) if self.partitioning: self.update_latent_me() self.update_basis_me(flooring_fn=flooring_fn) self.update_activation_me(flooring_fn=flooring_fn) def update_latent_mm(self) -> None: r"""Update latent variables in NMF by MM algorithm. Update :math:`z_{nk}` as follows: .. math:: z_{nk} &\leftarrow\left[\frac{\displaystyle\sum_{i,j}\frac{t_{ik}v_{kj}} {(\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\frac{p+2}{p}}} |y_{ijn}|^{2}}{\displaystyle\sum_{i,j}\dfrac{t_{ik}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]^{\frac{p}{p+2}}z_{nk} \\ z_{nk} &\leftarrow\frac{z_{nk}}{\sum_{n'}z_{n'k}}. 
""" p = self.domain if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 p2_p = (p + 2) / p p_p2 = p / (p + 2) Z = self.latent T, V = self.basis, self.activation TV = T[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVp2p = ZTV**p2_p TV_ZTVp2p = TV[np.newaxis, :, :, :] / ZTVp2p[:, :, np.newaxis, :] num = np.sum(TV_ZTVp2p * Y2[:, :, np.newaxis, :], axis=(1, 3)) TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(TV_ZTV, axis=(1, 3)) Z = ((num / denom) ** p_p2) * Z Z = Z / Z.sum(axis=0) self.latent = Z def update_basis_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`t_{ikn}` as follows: .. math:: t_{ik} \leftarrow\left[ \frac{\displaystyle\sum_{j,n}\frac{z_{nk}v_{kj}} {(\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\frac{p+2}{p}}} |y_{ijn}|^{2}}{\displaystyle\sum_{j,n} \dfrac{z_{nk}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]^{\frac{p}{p+2}}t_{ik}, if ``partitioning=True``. Otherwise .. math:: t_{ikn} \leftarrow \left[\frac{\displaystyle\sum_{j} \dfrac{v_{kjn}}{(\sum_{k'}t_{ik'n}v_{k'jn})^{\frac{p+2}{p}}}|y_{ijn}|^{2}} {\displaystyle\sum_{j}\frac{v_{kjn}}{\sum_{k'}t_{ik'n}v_{k'jn}}}\right] ^{\frac{p}{p+2}}t_{ikn}. 
""" p = self.domain flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 p2_p = (p + 2) / p p_p2 = p / (p + 2) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVp2p = ZTV**p2_p ZV_ZTVp2p = ZV[:, np.newaxis, :, :] / ZTVp2p[:, :, np.newaxis, :] num = np.sum(ZV_ZTVp2p * Y2[:, :, np.newaxis, :], axis=(0, 3)) ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZV_ZTV, axis=(0, 3)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TVp2p = TV**p2_p V_TVp2p = V[:, np.newaxis, :, :] / TVp2p[:, :, np.newaxis, :] num = np.sum(V_TVp2p * Y2[:, :, np.newaxis, :], axis=3) V_TV = V[:, np.newaxis, :, :] / TV[:, :, np.newaxis, :] denom = np.sum(V_TV, axis=3) T = ((num / denom) ** p_p2) * T T = flooring_fn(T) self.basis = T def update_activation_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF activations by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`v_{kjn}` as follows: .. math:: v_{kj} \leftarrow\left[\frac{\displaystyle\sum_{i,n}\frac{z_{nk}t_{ik}} {(\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\frac{p+2}{p}}} |y_{ijn}|^{2}}{\displaystyle\sum_{i,n}\dfrac{z_{nk}t_{ik}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]^{\frac{p}{p+2}}v_{kj}, if ``partitioning=True``. Otherwise .. 
math:: v_{kjn} \leftarrow \left[\frac{\displaystyle\sum_{i} \dfrac{t_{ikn}}{(\sum_{k'}t_{ik'n}v_{k'jn})^{\frac{p+2}{p}}}|y_{ijn}|^{2}} {\displaystyle\sum_{i}\frac{t_{ikn}}{\sum_{k'}t_{ik'n}v_{k'jn}}} \right]^{\frac{p}{p+2}}v_{kjn}. """ p = self.domain flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 p2_p = (p + 2) / p p_p2 = p / (p + 2) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVp2p = ZTV**p2_p ZT_ZTVp2p = ZT[:, :, :, np.newaxis] / ZTVp2p[:, :, np.newaxis, :] num = np.sum(ZT_ZTVp2p * Y2[:, :, np.newaxis, :], axis=(0, 1)) ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZT_ZTV, axis=(0, 1)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TVp2p = TV**p2_p T_TVp2p = T[:, :, :, np.newaxis] / TVp2p[:, :, np.newaxis, :] num = np.sum(T_TVp2p * Y2[:, :, np.newaxis, :], axis=1) T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :] denom = np.sum(T_TV, axis=1) V = ((num / denom) ** p_p2) * V V = flooring_fn(V) self.activation = V def update_latent_me(self) -> None: r"""Update latent variables in NMF by ME algorithm. Update :math:`z_{nk}` as follows: .. math:: z_{nk} &\leftarrow\left[\frac{\displaystyle\sum_{i,j}\frac{t_{ik}v_{kj}} {(\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{2}} |y_{ijn}|^{2}}{\displaystyle\sum_{i,j}\dfrac{t_{ik}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]z_{nk} \\ z_{nk} &\leftarrow\frac{z_{nk}}{\sum_{n'}z_{n'k}}. 
""" if self.domain != 2: raise ValueError("Domain parameter is expected 2, but given {}.".format(self.domain)) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 Z = self.latent T, V = self.basis, self.activation TV = T[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2 = ZTV**2 TV_ZTV2 = TV[np.newaxis, :, :, :] / ZTV2[:, :, np.newaxis, :] num = np.sum(TV_ZTV2 * Y2[:, :, np.newaxis, :], axis=(1, 3)) TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(TV_ZTV, axis=(1, 3)) Z = (num / denom) * Z Z = Z / Z.sum(axis=0) self.latent = Z def update_basis_me( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases by ME algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`t_{ikn}` as follows: .. math:: t_{ik} \leftarrow\left[ \frac{\displaystyle\sum_{j,n}\frac{z_{nk}v_{kj}} {(\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{2}} |y_{ijn}|^{2}}{\displaystyle\sum_{j,n} \dfrac{z_{nk}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]t_{ik}, if ``partitioning=True``. Otherwise .. math:: t_{ikn} \leftarrow\left[\frac{\displaystyle\sum_{j} \dfrac{v_{kjn}}{(\sum_{k'}t_{ik'n}v_{k'jn})^{2}}|y_{ijn}|^{2}} {\displaystyle\sum_{j}\frac{v_{kjn}}{\sum_{k'}t_{ik'n}v_{k'jn}}}\right] t_{ikn}. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.domain != 2: raise ValueError("Domain parameter is expected 2, but given {}.".format(self.domain)) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2 = ZTV**2 ZV_ZTV2 = ZV[:, np.newaxis, :, :] / ZTV2[:, :, np.newaxis, :] num = np.sum(ZV_ZTV2 * Y2[:, :, np.newaxis, :], axis=(0, 3)) ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZV_ZTV, axis=(0, 3)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TV2 = TV**2 V_TV2 = V[:, np.newaxis, :, :] / TV2[:, :, np.newaxis, :] num = np.sum(V_TV2 * Y2[:, :, np.newaxis, :], axis=3) V_TV = V[:, np.newaxis, :, :] / TV[:, :, np.newaxis, :] denom = np.sum(V_TV, axis=3) T = (num / denom) * T T = flooring_fn(T) self.basis = T def update_activation_me( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF activations by ME algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`t_{ikn}` as follows: .. math:: v_{kj} \leftarrow\left[\frac{\displaystyle\sum_{i,n}\frac{z_{nk}t_{ik}} {(\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{2}} |y_{ijn}|^{2}}{\displaystyle\sum_{i,n}\dfrac{z_{nk}t_{ik}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]v_{kj}, if ``partitioning=True``. Otherwise .. 
math:: v_{kjn} \leftarrow \left[\frac{\displaystyle\sum_{i} \dfrac{t_{ikn}}{(\sum_{k'}t_{ik'n}v_{k'jn})^{2}}|y_{ijn}|^{2}} {\displaystyle\sum_{i}\frac{t_{ikn}}{\sum_{k'}t_{ik'n}v_{k'jn}}} \right]v_{kjn}. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.domain != 2: raise ValueError("Domain parameter is expected 2, but given {}.".format(self.domain)) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2 = ZTV**2 ZT_ZTV2 = ZT[:, :, :, np.newaxis] / ZTV2[:, :, np.newaxis, :] num = np.sum(ZT_ZTV2 * Y2[:, :, np.newaxis, :], axis=(0, 1)) ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZT_ZTV, axis=(0, 1)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TV2 = TV**2 T_TV2 = T[:, :, :, np.newaxis] / TV2[:, :, np.newaxis, :] num = np.sum(T_TV2 * Y2[:, :, np.newaxis, :], axis=1) T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :] denom = np.sum(T_TV, axis=1) V = (num / denom) * V V = flooring_fn(V) self.activation = V def update_spatial_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once. - If ``spatial_algorithm`` is ``IP`` or ``IP1``, ``update_spatial_model_ip1`` is called. - If ``spatial_algorithm`` is ``ISS`` or ``ISS1``, ``update_spatial_model_iss1`` is called. - If ``spatial_algorithm`` is ``IP2``, ``update_spatial_model_ip2`` is called. - If ``spatial_algorithm`` is ``ISS2``, ``update_spatial_model_iss2`` is called. - If ``spatial_algorithm`` is ``IPA``, ``update_spatial_model_ipa`` is called. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. 
If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.spatial_algorithm in ["IP", "IP1"]: self.update_spatial_model_ip1(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["IP2"]: self.update_spatial_model_ip2(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["ISS", "ISS1"]: self.update_spatial_model_iss1(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["ISS2"]: self.update_spatial_model_iss2(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["IPA"]: self.update_spatial_model_ipa(flooring_fn=flooring_fn) else: raise NotImplementedError("Not support {}.".format(self.spatial_algorithm)) def update_spatial_model_ip1( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Demixing filters are updated sequentially for :math:`n=1,\ldots,N` as follows: .. math:: \boldsymbol{w}_{in} &\leftarrow\left(\boldsymbol{W}_{in}^{\mathsf{H}}\boldsymbol{U}_{in}\right)^{-1} \ \boldsymbol{e}_{n}, \\ \boldsymbol{w}_{in} &\leftarrow\frac{\boldsymbol{w}_{in}} {\sqrt{\boldsymbol{w}_{in}^{\mathsf{H}}\boldsymbol{U}_{in}\boldsymbol{w}_{in}}}, where .. math:: \boldsymbol{U}_{in} = \frac{1}{J}\sum_{j} \frac{1}{\left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}}} \boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}} if ``partitioning=True``, otherwise .. 
math:: \boldsymbol{U}_{in} = \frac{1}{J}\sum_{j} \frac{1}{\left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}} \boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}. """ p = self.domain flooring_fn = choose_flooring_fn(flooring_fn, method=self) X, W = self.input, self.demix_filter if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2p = ZTV ** (2 / p) varphi = 1 / ZTV2p else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TV2p = TV ** (2 / p) varphi = 1 / TV2p XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) varphi = varphi.transpose(1, 0, 2) varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(varphi_XX, axis=-1) self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn) def update_spatial_model_ip2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using pairwise iterative projection \ following [#nakashima2021faster]_. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\neq n_{2}`), compute weighted covariance matrix as follows: .. math:: \boldsymbol{U}_{in} = \frac{1}{J}\sum_{j} \frac{1}{r_{ijn}}\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, :math:`r_{ijn}` is computed by .. math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}} if ``partitioning=True``. Otherwise, .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. 
Using :math:`\boldsymbol{U}_{in_{1}}` and :math:`\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors. .. math:: \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i} = \lambda_{i} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i}, where .. math:: \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{1}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ), \\ \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{2}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ). After that, we standardize two eigenvectors :math:`\boldsymbol{h}_{in_{1}}` and :math:`\boldsymbol{h}_{in_{2}}`. .. math:: \boldsymbol{h}_{in_{1}} &\leftarrow\frac{\boldsymbol{h}_{in_{1}}} {\sqrt{\boldsymbol{h}_{in_{1}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{1}}}}, \\ \boldsymbol{h}_{in_{2}} &\leftarrow\frac{\boldsymbol{h}_{in_{2}}} {\sqrt{\boldsymbol{h}_{in_{2}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{2}}}}. Then, update :math:`\boldsymbol{w}_{in_{1}}` and :math:`\boldsymbol{w}_{in_{2}}` simultaneously. .. math:: \boldsymbol{w}_{in_{1}} &\leftarrow \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{1}} \\ \boldsymbol{w}_{in_{2}} &\leftarrow \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{2}} At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{2}` for :math:`n_{1}\neq n_{2}`. .. [#nakashima2021faster] T. Nakashima, R. Scheibler, Y. Wakabayashi, and N.
Ono, \ "Faster independent low-rank matrix analysis with pairwise updates of demixing vectors," in *Proc. EUSIPCO*, 2021, pp. 301-305. """ p = self.domain flooring_fn = choose_flooring_fn(flooring_fn, method=self) X, W = self.input, self.demix_filter if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) R = ZTV ** (2 / p) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) R = TV ** (2 / p) varphi = 1 / R XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) varphi = varphi.transpose(1, 0, 2) varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(varphi_XX, axis=-1) self.demix_filter = update_by_ip2( W, U, flooring_fn=flooring_fn, pair_selector=self.pair_selector ) def update_spatial_model_iss1( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update estimated spectrograms once using iterative source steering. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`y_{ijn}` as follows: .. math:: \boldsymbol{y}_{ij} & \leftarrow\boldsymbol{y}_{ij} - \boldsymbol{d}_{in}y_{ijn} \\ d_{inn'} &= \begin{cases} \dfrac{\displaystyle\sum_{j}\dfrac{1}{r_{ijn}} y_{ijn'}y_{ijn}^{*}}{\displaystyle\sum_{j}\dfrac{1} {r_{ijn}}|y_{ijn}|^{2}} & (n'\neq n) \\ 1 - \dfrac{1}{\sqrt{\displaystyle\dfrac{1}{J}\sum_{j}\dfrac{1} {r_{ijn}} |y_{ijn}|^{2}}} & (n'=n) \end{cases}, where .. math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}}, if ``partitioning=True``. Otherwise .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. 
""" p = self.domain flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) R = ZTV ** (2 / p) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) R = TV ** (2 / p) varphi = 1 / R self.output = update_by_iss1(Y, varphi, flooring_fn=flooring_fn) def update_spatial_model_iss2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update estimated spectrograms once using pairwise iterative source steering. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Compute :math:`\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and :math:`\boldsymbol{f}_{in}^{(n_{1},n_{2})}` for :math:`n_{1}\neq n_{2}`: .. math:: \begin{array}{rclc} \boldsymbol{G}_{in}^{(n_{1},n_{2})} &=& {\displaystyle\frac{1}{J}\sum_{j}}\dfrac{1}{r_{ijn}} \boldsymbol{y}_{ij}^{(n_{1},n_{2})}{\boldsymbol{y}_{ij}^{(n_{1},n_{2})}}^{\mathsf{H}} &(n=1,\ldots,N), \\ \boldsymbol{f}_{in}^{(n_{1},n_{2})} &=& {\displaystyle\frac{1}{J}\sum_{j}} \dfrac{1}{r_{ijn}}y_{ijn}^{*}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} &(n\neq n_{1},n_{2}), \end{array} where .. math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}} if ``partitioning=True``. Otherwise, .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. Using :math:`\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and :math:`\boldsymbol{f}_{in}^{(n_{1},n_{2})}`, we compute .. 
math:: \begin{array}{rclc} \boldsymbol{p}_{in} &=& \dfrac{\boldsymbol{h}_{in}} {\sqrt{\boldsymbol{h}_{in}^{\mathsf{H}}\boldsymbol{G}_{in}^{(n_{1},n_{2})} \boldsymbol{h}_{in}}} & (n=n_{1},n_{2}), \\ \boldsymbol{q}_{in} &=& -{\boldsymbol{G}_{in}^{(n_{1},n_{2})}}^{-1}\boldsymbol{f}_{in}^{(n_{1},n_{2})} & (n\neq n_{1},n_{2}), \end{array} where :math:`\boldsymbol{h}_{in}` (:math:`n=n_{1},n_{2}`) is a generalized eigenvector obtained from .. math:: \boldsymbol{G}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{i} = \lambda_{i}\boldsymbol{G}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{i}. Separated signal :math:`y_{ijn}` is updated as follows: .. math:: y_{ijn} &\leftarrow\begin{cases} &\boldsymbol{p}_{in}^{\mathsf{H}}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} & (n=n_{1},n_{2}) \\ &\boldsymbol{q}_{in}^{\mathsf{H}}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} + y_{ijn} & (n\neq n_{1},n_{2}) \end{cases}. """ p = self.domain flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) R = ZTV ** (2 / p) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) R = TV ** (2 / p) varphi = 1 / R self.output = update_by_iss2( Y, varphi, flooring_fn=flooring_fn, pair_selector=self.pair_selector ) def update_spatial_model_ipa( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update estimated spectrograms once using \ iterative projection with adjustment [#scheibler2021independent]_. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Compute :math:`r_{ijn}` as follows: .. 
math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}}, if ``partitioning=True``. Otherwise .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. Then, by defining :math:`\tilde{\boldsymbol{U}}_{in'}`, :math:`\boldsymbol{A}_{in}\in\mathbb{R}^{(N-1)\times(N-1)}`, :math:`\boldsymbol{b}_{in}\in\mathbb{C}^{N-1}`, :math:`\boldsymbol{C}_{in}\in\mathbb{C}^{(N-1)\times(N-1)}`, :math:`\boldsymbol{d}_{in}\in\mathbb{C}^{N-1}`, and :math:`z_{in}\in\mathbb{R}_{\geq 0}` as follows: .. math:: \tilde{\boldsymbol{U}}_{in'} &= \frac{1}{J}\sum_{j}\frac{1}{r_{ijn'}} \boldsymbol{y}_{ij}\boldsymbol{y}_{ij}^{\mathsf{H}}, \\ \boldsymbol{A}_{in} &= \mathrm{diag}(\ldots, \boldsymbol{e}_{n}^{\mathsf{T}}\tilde{\boldsymbol{U}}_{in'}\boldsymbol{e}_{n} ,\ldots)~~(n'\neq n), \\ \boldsymbol{b}_{in} &= (\ldots, \boldsymbol{e}_{n}^{\mathsf{T}}\tilde{\boldsymbol{U}}_{in'}\boldsymbol{e}_{n'} ,\ldots)^{\mathsf{T}}~~(n'\neq n), \\ \boldsymbol{C}_{in} &= \bar{\boldsymbol{E}}_{n}^{\mathsf{T}}(\tilde{\boldsymbol{U}}_{in}^{-1})^{*} \bar{\boldsymbol{E}}_{n}, \\ \boldsymbol{d}_{in} &= \bar{\boldsymbol{E}}_{n}^{\mathsf{T}}(\tilde{\boldsymbol{U}}_{in}^{-1})^{*} \boldsymbol{e}_{n}, \\ z_{in} &= \boldsymbol{e}_{n}^{\mathsf{T}}\tilde{\boldsymbol{U}}_{in}^{-1}\boldsymbol{e}_{n} - \boldsymbol{d}_{in}^{\mathsf{H}}\boldsymbol{C}_{in}^{-1}\boldsymbol{d}_{in}. :math:`\boldsymbol{y}_{ij}` is updated via log-quadratically penalized quadratic minimization (LQPQM). ..
math:: \check{\boldsymbol{q}}_{in} &\leftarrow \mathrm{LQPQM2}(\boldsymbol{H}_{in},\boldsymbol{v}_{in},z_{in}), \\ \boldsymbol{q}_{in} &\leftarrow \boldsymbol{G}_{in}^{-1}\check{\boldsymbol{q}}_{in} - \boldsymbol{A}_{in}^{-1}\boldsymbol{b}_{in}, \\ \tilde{\boldsymbol{q}}_{in} &\leftarrow \boldsymbol{e}_{n} - \bar{\boldsymbol{E}}_{n}\boldsymbol{q}_{in}, \\ \boldsymbol{p}_{in} &\leftarrow \frac{\tilde{\boldsymbol{U}}_{in}^{-1}\tilde{\boldsymbol{q}}_{in}^{*}} {\sqrt{(\tilde{\boldsymbol{q}}_{in}^{*})^{\mathsf{H}}\tilde{\boldsymbol{U}}_{in}^{-1} \tilde{\boldsymbol{q}}_{in}^{*}}}, \\ \boldsymbol{\Upsilon}_{i}^{(n)} &\leftarrow \boldsymbol{I} + \boldsymbol{e}_{n}(\boldsymbol{p}_{in} - \boldsymbol{e}_{n})^{\mathsf{H}} + \bar{\boldsymbol{E}}_{n}\boldsymbol{q}_{in}^{*}\boldsymbol{e}_{n}^{\mathsf{T}}, \\ \boldsymbol{y}_{ij} &\leftarrow \boldsymbol{\Upsilon}_{i}^{(n)}\boldsymbol{y}_{ij}, .. [#scheibler2021independent] R. Scheibler, "Independent vector analysis via log-quadratically penalized quadratic minimization," *IEEE Trans. Signal Processing*, vol. 69, pp. 2509-2524, 2021. """ self.lqpqm_normalization: bool self.newton_iter: int p = self.domain normalization = self.lqpqm_normalization max_iter = self.newton_iter flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) R = ZTV ** (2 / p) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) R = TV ** (2 / p) varphi = 1 / R self.output = update_by_ipa( Y, varphi, normalization=normalization, flooring_fn=flooring_fn, max_iter=max_iter, ) def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} = \frac{1}{J}\sum_{i,j,n}\left(\frac{|y_{ijn}|^{2}}{r_{ijn}} + \log r_{ijn}\right) - 2\sum_{i}\log|\det\boldsymbol{W}_{i}|, where .. 
math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}}, if ``partitioning=True``. Otherwise .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. Returns: Computed loss. """ p = self.domain if self.demix_filter is None: X, Y = self.input, self.output Y2 = np.abs(Y) ** 2 X, Y = X.transpose(1, 0, 2), Y.transpose(1, 0, 2) X_Hermite = X.transpose(0, 2, 1).conj() XX_Hermite = X @ X_Hermite W = Y @ X_Hermite @ np.linalg.inv(XX_Hermite) else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) R = ZTV ** (2 / p) loss = Y2 / R + (2 / p) * np.log(ZTV) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) R = TV ** (2 / p) loss = Y2 / R + (2 / p) * np.log(TV) logdet = self.compute_logdet(W) # (n_bins,) loss = np.sum(loss.mean(axis=-1), axis=0) - 2 * logdet loss = loss.sum(axis=0).item() return loss def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" if self.demix_filter is None: assert self.scale_restoration, "Set self.scale_restoration=True." X, Y = self.input, self.output Y_scaled = projection_back(Y, reference=X, reference_id=self.reference_id) self.output = Y_scaled else: super().apply_projection_back() def apply_minimal_distortion_principle(self) -> None: r"""Apply minimal distortion principle to estimated spectrograms.""" if self.demix_filter is None: X, Y = self.input, self.output Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id) self.output = Y_scaled else: super().apply_minimal_distortion_principle() class TILRMA(ILRMABase): r"""Independent low-rank matrix analysis (ILRMA) on Student's *t* distribution. We assume :math:`y_{ijn}` follows a Student's *t* distribution. .. 
math:: p(y_{ijn}) = \frac{1}{\pi r_{ijn}} \left(1+\frac{2}{\nu}\frac{|y_{ijn}|^{2}}{r_{ijn}}\right)^{-\frac{2+\nu}{2}}, where .. math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}}, if ``partitioning=True``. Otherwise, .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. :math:`\nu` is a degree of freedom parameter. Args: n_basis (int): Number of NMF bases. dof (float): Degree of freedom parameter in Student's *t* distribution. spatial_algorithm (str): Algorithm for demixing filter updates. Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, or ``ISS2``. ``IPA`` is not supported for t-ILRMA. Default: ``IP``. source_algorithm (str): Algorithm for source model updates. Choose ``MM`` or ``ME``. Default: ``MM``. domain (float): Domain parameter. It must satisfy ``0 < domain <= 2``, and must be ``2`` when ``source_algorithm="ME"``. Default: ``2``. partitioning (bool): Whether to use partitioning function. Default: ``False``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. pair_selector (callable, optional): Selector to choose updating pairs in ``IP2`` and ``ISS2``. If ``None`` is given, ``sequential_pair_selector`` is used. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. normalization (bool or str, optional): Normalization of demixing filters and NMF parameters. Choose ``power`` or ``projection_back``. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``.
Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. rng (numpy.random.Generator, optional): Random number generator. This is mainly used to randomly initialize NMF. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. Examples: Update demixing filters by IP: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = TILRMA( ... n_basis=2, ... dof=1000, ... spatial_algorithm="IP", ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128) (2, 2049, 128) Update demixing filters by IP2: .. code-block:: python >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = TILRMA( ... n_basis=2, ... dof=1000, ... spatial_algorithm="IP2", ... pair_selector=sequential_pair_selector, ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128) (2, 2049, 128) Update demixing filters by ISS: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = TILRMA( ... n_basis=2, ... dof=1000, ... spatial_algorithm="ISS", ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128) (2, 2049, 128) Update demixing filters by ISS2: ..
code-block:: python >>> import functools >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = TILRMA( ... n_basis=2, ... dof=1000, ... spatial_algorithm="ISS2", ... pair_selector=functools.partial(sequential_pair_selector, step=2), ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, n_basis: int, dof: float, spatial_algorithm: str = "IP", source_algorithm: str = "MM", domain: float = 2, partitioning: bool = False, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, callbacks: Optional[ Union[Callable[["TILRMA"], None], List[Callable[["TILRMA"], None]]] ] = None, normalization: Optional[Union[bool, str]] = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__( n_basis=n_basis, partitioning=partitioning, flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, rng=rng, ) assert spatial_algorithm in spatial_algorithms, "Not support {}.".format(spatial_algorithms) assert source_algorithm in source_algorithms, "Not support {}.".format(source_algorithm) assert 0 < domain <= 2, "domain parameter should be chosen from [0, 2]." if spatial_algorithm == "IPA": raise ValueError("IPA is not supported for t-ILRMA.") if source_algorithm == "ME": assert domain == 2, "domain parameter should be 2 when you specify ME algorithm." 
self.dof = dof self.spatial_algorithm = spatial_algorithm self.source_algorithm = source_algorithm self.domain = domain self.normalization = normalization if pair_selector is None: if spatial_algorithm in ["IP2", "ISS2"]: self.pair_selector = sequential_pair_selector else: self.pair_selector = pair_selector def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(flooring_fn=self.flooring_fn, **kwargs) # Call __call__ of ILRMABase's parent, i.e. __call__ of IterativeMethodBase super(ILRMABase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() if self.demix_filter is None: pass else: self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "TILRMA(" s += "n_basis={n_basis}" s += ", dof={dof}" s += ", spatial_algorithm={spatial_algorithm}" s += ", source_algorithm={source_algorithm}" s += ", domain={domain}" s += ", partitioning={partitioning}" s += ", normalization={normalization}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", **kwargs, ) -> None: r"""Reset attributes by given keyword arguments. 
Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. kwargs: Keyword arguments to set as attributes of ILRMA. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) super()._reset(flooring_fn=flooring_fn, **kwargs) if self.spatial_algorithm in ["ISS", "ISS1", "ISS2"]: self.demix_filter = None def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF parameters and demixing filters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) self.update_source_model(flooring_fn=flooring_fn) self.update_spatial_model(flooring_fn=flooring_fn) if self.normalization: self.normalize(flooring_fn=flooring_fn) def update_source_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases, activations, and latent variables. - If ``source_algorithm`` is ``MM``, ``update_source_model_mm`` is called. - If ``source_algorithm`` is ``ME``, ``update_source_model_me`` is called. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.source_algorithm == "MM": self.update_source_model_mm(flooring_fn=flooring_fn) elif self.source_algorithm == "ME": self.update_source_model_me(flooring_fn=flooring_fn) else: raise ValueError( "{}-algorithm-based source model updates are not supported.".format( self.source_algorithm ) ) def update_source_model_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases, activations, and latent variables by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.partitioning: self.update_latent_mm() self.update_basis_mm(flooring_fn=flooring_fn) self.update_activation_mm(flooring_fn=flooring_fn) def update_source_model_me( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases, activations, and latent variables by ME algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.partitioning: self.update_latent_me() self.update_basis_me(flooring_fn=flooring_fn) self.update_activation_me(flooring_fn=flooring_fn) def update_latent_mm(self) -> None: r"""Update latent variables in NMF by MM algorithm. Update :math:`z_{nk}` as follows: .. 
math:: z_{nk} &\leftarrow\left[\frac{\displaystyle\sum_{i,j}\frac{t_{ik}v_{kj}} {\tilde{r}_{ijn}\sum_{k'}z_{nk'}t_{ik'}v_{k'j}} |y_{ijn}|^{2}}{\displaystyle\sum_{i,j}\dfrac{t_{ik}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]^{\frac{p}{p+2}}z_{nk} \\ z_{nk} &\leftarrow\frac{z_{nk}}{\sum_{n'}z_{n'k}}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}. """ p = self.domain nu = self.dof if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 p_p2 = p / (p + 2) nu_nu2 = nu / (nu + 2) Z = self.latent T, V = self.basis, self.activation TV = T[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2p = ZTV ** (2 / p) R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2 RZTV = R_tilde * ZTV TV_RZTV = TV[np.newaxis, :, :, :] / RZTV[:, :, np.newaxis, :] num = np.sum(TV_RZTV * Y2[:, :, np.newaxis, :], axis=(1, 3)) TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(TV_ZTV, axis=(1, 3)) Z = ((num / denom) ** p_p2) * Z Z = Z / Z.sum(axis=0) self.latent = Z def update_basis_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`t_{ikn}` as follows: .. 
math:: t_{ik} &\leftarrow\left[ \frac{\displaystyle\sum_{j,n}\frac{z_{nk}v_{kj}} {\tilde{r}_{ijn}\sum_{k'}z_{nk'}t_{ik'}v_{k'j}} |y_{ijn}|^{2}}{\displaystyle\sum_{j,n} \dfrac{z_{nk}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]^{\frac{p}{p+2}}t_{ik}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}, if ``partitioning=True``. Otherwise .. math:: t_{ikn} &\leftarrow \left[\frac{\displaystyle\sum_{j} \dfrac{v_{kjn}}{\tilde{r}_{ijn}\sum_{k'}t_{ik'n}v_{k'jn}}|y_{ijn}|^{2}} {\displaystyle\sum_{j}\frac{v_{kjn}}{\sum_{k'}t_{ik'n}v_{k'jn}}}\right] ^{\frac{p}{p+2}}t_{ikn}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}. """ p = self.domain nu = self.dof flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 p_p2 = p / (p + 2) nu_nu2 = nu / (nu + 2) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2p = ZTV ** (2 / p) R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2 RZTV = R_tilde * ZTV ZV_RZTV = ZV[:, np.newaxis, :, :] / RZTV[:, :, np.newaxis, :] num = np.sum(ZV_RZTV * Y2[:, :, np.newaxis, :], axis=(0, 3)) ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZV_ZTV, axis=(0, 3)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TV2p = TV ** (2 / p) R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2 RTV = R_tilde * TV V_RTV = V[:, np.newaxis, :, :] / RTV[:, :, np.newaxis, :] num = np.sum(V_RTV * Y2[:, :, np.newaxis, :], axis=3) V_TV = V[:, np.newaxis, :, :] / TV[:, :, np.newaxis, :] denom = np.sum(V_TV, axis=3) T = ((num / denom) ** p_p2) * T T = flooring_fn(T) self.basis = T def update_activation_mm( self, flooring_fn: 
Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF activations by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`v_{kjn}` as follows: .. math:: v_{kj} &\leftarrow\left[\frac{\displaystyle\sum_{i,n}\frac{z_{nk}t_{ik}} {\tilde{r}_{ijn}\sum_{k'}z_{nk'}t_{ik'}v_{k'j}} |y_{ijn}|^{2}}{\displaystyle\sum_{i,n}\dfrac{z_{nk}t_{ik}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]^{\frac{p}{p+2}}v_{kj}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}, if ``partitioning=True``. Otherwise .. math:: v_{kjn} &\leftarrow \left[\frac{\displaystyle\sum_{i} \dfrac{t_{ikn}}{\tilde{r}_{ijn}\sum_{k'}t_{ik'n}v_{k'jn}}|y_{ijn}|^{2}} {\displaystyle\sum_{i}\frac{t_{ikn}}{\sum_{k'}t_{ik'n}v_{k'jn}}} \right]^{\frac{p}{p+2}}v_{kjn}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}. 
""" p = self.domain nu = self.dof flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 p_p2 = p / (p + 2) nu_nu2 = nu / (nu + 2) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2p = ZTV ** (2 / p) R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2 RZTV = R_tilde * ZTV ZT_RZTV = ZT[:, :, :, np.newaxis] / RZTV[:, :, np.newaxis, :] num = np.sum(ZT_RZTV * Y2[:, :, np.newaxis, :], axis=(0, 1)) ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZT_ZTV, axis=(0, 1)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TV2p = TV ** (2 / p) R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2 RTV = R_tilde * TV T_RTV = T[:, :, :, np.newaxis] / RTV[:, :, np.newaxis, :] num = np.sum(T_RTV * Y2[:, :, np.newaxis, :], axis=1) T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :] denom = np.sum(T_TV, axis=1) V = ((num / denom) ** p_p2) * V V = flooring_fn(V) self.activation = V def update_latent_me(self) -> None: r"""Update latent variables in NMF by ME algorithm. Update :math:`z_{nk}` as follows: .. math:: z_{nk} &\leftarrow\frac{\displaystyle\sum_{i,j}\frac{t_{ik}v_{kj}} {\tilde{r}_{ijn}\sum_{k'}z_{nk'}t_{ik'}v_{k'j}} |y_{ijn}|^{2}}{\displaystyle\sum_{i,j}\dfrac{t_{ik}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} z_{nk} \\ z_{nk} &\leftarrow\frac{z_{nk}}{\sum_{n'}z_{n'k}}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\sum_{k}z_{nk}t_{ik}v_{kj}+\frac{2}{\nu+2}|y_{ijn}|^{2}. 
""" nu = self.dof if self.domain != 2: raise ValueError("Domain parameter is expected 2, but given {}.".format(self.domain)) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 nu_nu2 = nu / (nu + 2) Z = self.latent T, V = self.basis, self.activation TV = T[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) R_tilde = nu_nu2 * ZTV + (1 - nu_nu2) * Y2 RZTV = R_tilde * ZTV TV_RZTV = TV[np.newaxis, :, :, :] / RZTV[:, :, np.newaxis, :] num = np.sum(TV_RZTV * Y2[:, :, np.newaxis, :], axis=(1, 3)) TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(TV_ZTV, axis=(1, 3)) Z = (num / denom) * Z Z = Z / Z.sum(axis=0) self.latent = Z def update_basis_me( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases by ME algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`t_{ikn}` as follows: .. math:: t_{ik} &\leftarrow \frac{\displaystyle\sum_{j,n}\frac{z_{nk}v_{kj}} {\tilde{r}_{ijn}\sum_{k'}z_{nk'}t_{ik'}v_{k'j}} |y_{ijn}|^{2}}{\displaystyle\sum_{j,n} \dfrac{z_{nk}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} t_{ik}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\sum_{k}z_{nk}t_{ik}v_{kj}+\frac{2}{\nu+2}|y_{ijn}|^{2}, if ``partitioning=True``. Otherwise .. math:: t_{ikn} &\leftarrow\frac{\displaystyle\sum_{j} \dfrac{v_{kjn}}{\tilde{r}_{ijn}\sum_{k'}t_{ik'n}v_{k'jn}}|y_{ijn}|^{2}} {\displaystyle\sum_{j}\frac{v_{kjn}}{\sum_{k'}t_{ik'n}v_{k'jn}}} t_{ikn}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\sum_{k}t_{ikn}v_{kjn}+\frac{2}{\nu+2}|y_{ijn}|^{2}. 
""" nu = self.dof flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.domain != 2: raise ValueError("Domain parameter is expected 2, but given {}.".format(self.domain)) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 nu_nu2 = nu / (nu + 2) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) R_tilde = nu_nu2 * ZTV + (1 - nu_nu2) * Y2 RZTV = R_tilde * ZTV ZV_RZTV = ZV[:, np.newaxis, :, :] / RZTV[:, :, np.newaxis, :] num = np.sum(ZV_RZTV * Y2[:, :, np.newaxis, :], axis=(0, 3)) ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZV_ZTV, axis=(0, 3)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) R_tilde = nu_nu2 * TV + (1 - nu_nu2) * Y2 RTV = R_tilde * TV V_RTV = V[:, np.newaxis, :, :] / RTV[:, :, np.newaxis, :] num = np.sum(V_RTV * Y2[:, :, np.newaxis, :], axis=3) V_TV = V[:, np.newaxis, :, :] / TV[:, :, np.newaxis, :] denom = np.sum(V_TV, axis=3) T = (num / denom) * T T = flooring_fn(T) self.basis = T def update_activation_me( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF activations by ME algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`v_{kjn}` as follows: .. 
math:: v_{kj} &\leftarrow\frac{\displaystyle\sum_{i,n}\frac{z_{nk}t_{ik}} {\tilde{r}_{ijn}\sum_{k'}z_{nk'}t_{ik'}v_{k'j}} |y_{ijn}|^{2}}{\displaystyle\sum_{i,n}\dfrac{z_{nk}t_{ik}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} v_{kj}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\sum_{k}z_{nk}t_{ik}v_{kj}+\frac{2}{\nu+2}|y_{ijn}|^{2}, if ``partitioning=True``. Otherwise .. math:: v_{kjn} &\leftarrow\frac{\displaystyle\sum_{i} \dfrac{t_{ikn}}{\tilde{r}_{ijn}\sum_{k'}t_{ik'n}v_{k'jn}}|y_{ijn}|^{2}} {\displaystyle\sum_{i}\frac{t_{ikn}}{\sum_{k'}t_{ik'n}v_{k'jn}}} v_{kjn}, \\ \tilde{r}_{ijn} &= \frac{\nu}{\nu+2}\sum_{k}t_{ikn}v_{kjn}+\frac{2}{\nu+2}|y_{ijn}|^{2}. """ nu = self.dof flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.domain != 2: raise ValueError("Domain parameter is expected 2, but given {}.".format(self.domain)) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 nu_nu2 = nu / (nu + 2) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) R_tilde = nu_nu2 * ZTV + (1 - nu_nu2) * Y2 RZTV = R_tilde * ZTV ZT_RZTV = ZT[:, :, :, np.newaxis] / RZTV[:, :, np.newaxis, :] num = np.sum(ZT_RZTV * Y2[:, :, np.newaxis, :], axis=(0, 1)) ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZT_ZTV, axis=(0, 1)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) R_tilde = nu_nu2 * TV + (1 - nu_nu2) * Y2 RTV = R_tilde * TV T_RTV = T[:, :, :, np.newaxis] / RTV[:, :, np.newaxis, :] num = np.sum(T_RTV * Y2[:, :, np.newaxis, :], axis=1) T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :] denom = np.sum(T_TV, axis=1) V = (num / denom) * V V = flooring_fn(V) self.activation = V def update_spatial_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once. 
- If ``spatial_algorithm`` is ``IP`` or ``IP1``, ``update_spatial_model_ip1`` is called. - If ``spatial_algorithm`` is ``ISS`` or ``ISS1``, ``update_spatial_model_iss1`` is called. - If ``spatial_algorithm`` is ``IP2``, ``update_spatial_model_ip2`` is called. - If ``spatial_algorithm`` is ``ISS2``, ``update_spatial_model_iss2`` is called. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.spatial_algorithm in ["IP", "IP1"]: self.update_spatial_model_ip1(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["IP2"]: self.update_spatial_model_ip2(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["ISS", "ISS1"]: self.update_spatial_model_iss1(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["ISS2"]: self.update_spatial_model_iss2(flooring_fn=flooring_fn) else: raise NotImplementedError("Not support {}.".format(self.spatial_algorithm)) def update_spatial_model_ip1( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Demixing filters are updated sequentially for :math:`n=1,\ldots,N` as follows: .. 
math:: \boldsymbol{w}_{in} &\leftarrow\left(\boldsymbol{W}_{i}\boldsymbol{U}_{in}\right)^{-1} \ \boldsymbol{e}_{n}, \\ \boldsymbol{w}_{in} &\leftarrow\frac{\boldsymbol{w}_{in}} {\sqrt{\boldsymbol{w}_{in}^{\mathsf{H}}\boldsymbol{U}_{in}\boldsymbol{w}_{in}}}, where .. math:: \boldsymbol{U}_{in} = \frac{1}{J}\sum_{j} \frac{1}{\tilde{r}_{ijn}}\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}. :math:`\tilde{r}_{ijn}` is defined as .. math:: \tilde{r}_{ijn} = \frac{\nu}{\nu+2}\left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}, if ``partitioning=True``. Otherwise .. math:: \tilde{r}_{ijn} = \frac{\nu}{\nu+2}\left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}. """ p = self.domain nu = self.dof X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y2 = np.abs(Y) ** 2 nu_nu2 = nu / (nu + 2) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2p = ZTV ** (2 / p) R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2 else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TV2p = TV ** (2 / p) R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2 varphi = 1 / R_tilde XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) varphi = varphi.transpose(1, 0, 2) varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(varphi_XX, axis=-1) self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn) def update_spatial_model_ip2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using pairwise iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input.
If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\neq n_{2}`), compute weighted covariance matrix as follows: .. math:: \boldsymbol{U}_{in} = \frac{1}{J}\sum_{j} \frac{1}{\tilde{r}_{ijn}}\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, :math:`\tilde{r}_{ijn}` is computed by .. math:: \tilde{r}_{ijn} = \frac{\nu}{\nu+2}\left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}, if ``partitioning=True``. \ Otherwise, .. math:: \tilde{r}_{ijn} = \frac{\nu}{\nu+2}\left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}. Using :math:`\boldsymbol{U}_{in_{1}}` and :math:`\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors. .. math:: \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i} = \lambda_{i} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i}, where .. math:: \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{1}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ), \\ \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{2}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ). After that, we standardize two eigenvectors :math:`\boldsymbol{h}_{in_{1}}` and :math:`\boldsymbol{h}_{in_{2}}`. .. 
math:: \boldsymbol{h}_{in_{1}} &\leftarrow\frac{\boldsymbol{h}_{in_{1}}} {\sqrt{\boldsymbol{h}_{in_{1}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{1}}}}, \\ \boldsymbol{h}_{in_{2}} &\leftarrow\frac{\boldsymbol{h}_{in_{2}}} {\sqrt{\boldsymbol{h}_{in_{2}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{2}}}}. Then, update :math:`\boldsymbol{w}_{in_{1}}` and :math:`\boldsymbol{w}_{in_{2}}` simultaneously. .. math:: \boldsymbol{w}_{in_{1}} &\leftarrow \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{1}} \\ \boldsymbol{w}_{in_{2}} &\leftarrow \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{2}} At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{2}` for :math:`n_{1}\neq n_{2}`. """ nu = self.dof p = self.domain flooring_fn = choose_flooring_fn(flooring_fn, method=self) X, W = self.input, self.demix_filter nu_nu2 = nu / (nu + 2) Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2p = ZTV ** (2 / p) R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2 else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TV2p = TV ** (2 / p) R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2 varphi = 1 / R_tilde XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) varphi = varphi.transpose(1, 0, 2) varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(varphi_XX, axis=-1) self.demix_filter = update_by_ip2( W, U, flooring_fn=flooring_fn, pair_selector=self.pair_selector ) def update_spatial_model_iss1( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) ->
None: r"""Update estimated spectrograms once using iterative source steering. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`y_{ijn}` as follows: .. math:: \boldsymbol{y}_{ij} & \leftarrow\boldsymbol{y}_{ij} - \boldsymbol{d}_{in}y_{ijn} \\ d_{inn'} &= \begin{cases} \dfrac{\displaystyle\sum_{j}\dfrac{1}{\tilde{r}_{ijn}} y_{ijn'}y_{ijn}^{*}}{\displaystyle\sum_{j}\dfrac{1} {\tilde{r}_{ijn}}|y_{ijn}|^{2}} & (n'\neq n) \\ 1 - \dfrac{1}{\sqrt{\displaystyle\dfrac{1}{J}\sum_{j}\dfrac{1} {\tilde{r}_{ijn}}|y_{ijn}|^{2}}} & (n'=n) \end{cases}. :math:`\tilde{r}_{ijn}` is defined as .. math:: \tilde{r}_{ijn} = \frac{\nu}{\nu+2}\left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}, if ``partitioning=True``. Otherwise .. math:: \tilde{r}_{ijn} = \frac{\nu}{\nu+2}\left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}. """ p = self.domain nu = self.dof flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output Y2 = np.abs(Y) ** 2 nu_nu2 = nu / (nu + 2) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2p = ZTV ** (2 / p) R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2 else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TV2p = TV ** (2 / p) R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2 varphi = 1 / R_tilde self.output = update_by_iss1(Y, varphi, flooring_fn=flooring_fn) def update_spatial_model_iss2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update estimated spectrograms once using pairwise iterative source steering. 
Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Compute :math:`\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and :math:`\boldsymbol{f}_{in}^{(n_{1},n_{2})}` for :math:`n_{1}\neq n_{2}`: .. math:: \begin{array}{rclc} \boldsymbol{G}_{in}^{(n_{1},n_{2})} &=& {\displaystyle\frac{1}{J}\sum_{j}}\dfrac{1}{\tilde{r}_{ijn}} \boldsymbol{y}_{ij}^{(n_{1},n_{2})}{\boldsymbol{y}_{ij}^{(n_{1},n_{2})}}^{\mathsf{H}} &(n=1,\ldots,N), \\ \boldsymbol{f}_{in}^{(n_{1},n_{2})} &=& {\displaystyle\frac{1}{J}\sum_{j}} \dfrac{1}{\tilde{r}_{ijn}}y_{ijn}^{*}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} &(n\neq n_{1},n_{2}), \end{array} where .. math:: \tilde{r}_{ijn} = \frac{\nu}{\nu+2}\left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2} if ``partitioning=True``. Otherwise, .. math:: \tilde{r}_{ijn} = \frac{\nu}{\nu+2}\left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}} + \frac{2}{\nu+2}|y_{ijn}|^{2}. Using :math:`\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and :math:`\boldsymbol{f}_{in}^{(n_{1},n_{2})}`, we compute .. math:: \begin{array}{rclc} \boldsymbol{p}_{in} &=& \dfrac{\boldsymbol{h}_{in}} {\sqrt{\boldsymbol{h}_{in}^{\mathsf{H}}\boldsymbol{G}_{in}^{(n_{1},n_{2})} \boldsymbol{h}_{in}}} & (n=n_{1},n_{2}), \\ \boldsymbol{q}_{in} &=& -{\boldsymbol{G}_{in}^{(n_{1},n_{2})}}^{-1}\boldsymbol{f}_{in}^{(n_{1},n_{2})} & (n\neq n_{1},n_{2}), \end{array} where :math:`\boldsymbol{h}_{in}` (:math:`n=n_{1},n_{2}`) is a generalized eigenvector obtained from .. math:: \boldsymbol{G}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{i} = \lambda_{i}\boldsymbol{G}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{i}. Separated signal :math:`y_{ijn}` is updated as follows: .. 
math:: y_{ijn} &\leftarrow\begin{cases} &\boldsymbol{p}_{in}^{\mathsf{H}}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} & (n=n_{1},n_{2}) \\ &\boldsymbol{q}_{in}^{\mathsf{H}}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} + y_{ijn} & (n\neq n_{1},n_{2}) \end{cases}. """ p = self.domain nu = self.dof flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output Y2 = np.abs(Y) ** 2 nu_nu2 = nu / (nu + 2) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTV2p = ZTV ** (2 / p) R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2 else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TV2p = TV ** (2 / p) R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2 varphi = 1 / R_tilde self.output = update_by_iss2( Y, varphi, flooring_fn=flooring_fn, pair_selector=self.pair_selector ) def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} = \frac{1}{J}\sum_{i,j,n} \left\{\left(1+\frac{\nu}{2}\right)\log\left(1+\frac{2}{\nu} \frac{|y_{ijn}|^{2}}{r_{ijn}}\right) + \log r_{ijn}\right\} -2\sum_{i}\log\left|\det\boldsymbol{W}_{i}\right|, where .. math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}}, if ``partitioning=True``, otherwise .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. Returns: Computed loss.
""" nu = self.dof p = self.domain if self.demix_filter is None: X, Y = self.input, self.output Y2 = np.abs(Y) ** 2 X, Y = X.transpose(1, 0, 2), Y.transpose(1, 0, 2) X_Hermite = X.transpose(0, 2, 1).conj() XX_Hermite = X @ X_Hermite W = Y @ X_Hermite @ np.linalg.inv(XX_Hermite) else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2 = np.abs(Y) ** 2 if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) Y2ZTV2p = Y2 / (ZTV ** (2 / p)) loss = (1 + nu / 2) * np.log(1 + (2 / nu) * Y2ZTV2p) + (2 / p) * np.log(ZTV) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) Y2TV2p = Y2 / (TV ** (2 / p)) loss = (1 + nu / 2) * np.log(1 + (2 / nu) * Y2TV2p) + (2 / p) * np.log(TV) logdet = self.compute_logdet(W) # (n_bins,) loss = np.sum(loss.mean(axis=-1), axis=0) - 2 * logdet loss = loss.sum(axis=0).item() return loss def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" if self.demix_filter is None: assert self.scale_restoration, "Set self.scale_restoration=True." X, Y = self.input, self.output Y_scaled = projection_back(Y, reference=X, reference_id=self.reference_id) self.output = Y_scaled else: super().apply_projection_back() def apply_minimal_distortion_principle(self) -> None: r"""Apply minimal distortion principle to estimated spectrograms.""" if self.demix_filter is None: X, Y = self.input, self.output Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id) self.output = Y_scaled else: super().apply_minimal_distortion_principle() class GGDILRMA(ILRMABase): r"""Independent low-rank matrix analysis (ILRMA) on a generalized Gaussian distribution. We assume :math:`y_{ijn}` follows a generalized Gaussian distribution. .. 
math:: p(y_{ijn}) = \frac{\beta}{2\pi r_{ijn}\Gamma\left(\frac{2}{\beta}\right)} \exp\left\{-\left(\frac{|y_{ijn}|^{2}}{r_{ijn}}\right)^{\frac{\beta}{2}}\right\}, where .. math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}}, if ``partitioning=True``. Otherwise, .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. :math:`\beta` is a shape parameter of a generalized Gaussian distribution. Args: n_basis (int): Number of NMF bases. beta (float): Shape parameter in generalized Gaussian distribution. It should be chosen from :math:`(0, 2)`. spatial_algorithm (str): Algorithm for demixing filter updates. Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, or ``ISS2``. Default: ``IP``. source_algorithm (str): Algorithm for source model updates. Only ``MM`` is supported. Default: ``MM``. domain (float): Domain parameter. It should be chosen from :math:`(0, 2]`. Default: ``2``. partitioning (bool): Whether to use partitioning function. Default: ``False``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. pair_selector (callable, optional): Selector to choose updating pair in ``IP2`` and ``ISS2``. If ``None`` is given, ``sequential_pair_selector`` is used. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. normalization (bool or str, optional): Normalization of demixing filters and NMF parameters. Choose ``power`` or ``projection_back``. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``.
record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize NMF. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. Examples: Update demixing filters by IP: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = GGDILRMA( ... n_basis=2, ... beta=1.99, ... spatial_algorithm="IP", ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by IP2: .. code-block:: python >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = GGDILRMA( ... n_basis=2, ... beta=1.99, ... spatial_algorithm="IP2", ... pair_selector=sequential_pair_selector, ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = GGDILRMA( ... n_basis=2, ... beta=1.99, ... spatial_algorithm="ISS", ... rng=np.random.default_rng(42), ... 
) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS2: .. code-block:: python >>> import functools >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> ilrma = GGDILRMA( ... n_basis=2, ... beta=1.99, ... spatial_algorithm="ISS2", ... pair_selector=functools.partial(sequential_pair_selector, step=2), ... rng=np.random.default_rng(42), ... ) >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, n_basis: int, beta: float, spatial_algorithm: str = "IP", source_algorithm: str = "MM", domain: float = 2, partitioning: bool = False, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, callbacks: Optional[ Union[Callable[["GGDILRMA"], None], List[Callable[["GGDILRMA"], None]]] ] = None, normalization: Optional[Union[bool, str]] = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__( n_basis=n_basis, partitioning=partitioning, flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, rng=rng, ) assert 0 < beta < 2, "Shape parameter {} shoule be chosen from (0, 2).".format(beta) assert spatial_algorithm in spatial_algorithms, "Not support {}.".format(spatial_algorithms) assert source_algorithm == "MM", "Not support {}.".format(source_algorithm) assert 0 < domain <= 2, "domain parameter should be chosen from [0, 2]." 
if spatial_algorithm == "IPA": raise ValueError("IPA is not supported for GGD-ILRMA.") self.beta = beta self.spatial_algorithm = spatial_algorithm self.source_algorithm = source_algorithm self.domain = domain self.normalization = normalization if pair_selector is None: if spatial_algorithm in ["IP2", "ISS2"]: self.pair_selector = sequential_pair_selector else: self.pair_selector = pair_selector def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(flooring_fn=self.flooring_fn, **kwargs) # Call __call__ of ILRMABase's parent, i.e. 
__call__ of IterativeMethodBase super(ILRMABase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() if self.demix_filter is None: pass else: self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "GGDILRMA(" s += "n_basis={n_basis}" s += ", beta={beta}" s += ", spatial_algorithm={spatial_algorithm}" s += ", source_algorithm={source_algorithm}" s += ", domain={domain}" s += ", partitioning={partitioning}" s += ", normalization={normalization}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", **kwargs, ) -> None: r"""Reset attributes by given keyword arguments. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. kwargs: Keyword arguments to set as attributes of ILRMA. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) super()._reset(flooring_fn=flooring_fn, **kwargs) if self.spatial_algorithm in ["ISS", "ISS1", "ISS2"]: self.demix_filter = None def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF parameters and demixing filters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. 
Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) self.update_source_model(flooring_fn=flooring_fn) self.update_spatial_model(flooring_fn=flooring_fn) if self.normalization: self.normalize(flooring_fn=flooring_fn) def update_source_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases, activations, and latent variables. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.source_algorithm == "MM": self.update_source_model_mm(flooring_fn=flooring_fn) else: raise ValueError( "{}-algorithm-based source model updates are not supported.".format( self.source_algorithm ) ) def update_source_model_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases, activations, and latent variables by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.partitioning: self.update_latent_mm() self.update_basis_mm(flooring_fn=flooring_fn) self.update_activation_mm(flooring_fn=flooring_fn) def update_latent_mm(self) -> None: r"""Update latent variables in NMF by MM algorithm. Update :math:`z_{nk}` as follows: .. 
math:: z_{nk} &\leftarrow\left[ \frac{\beta}{2} \frac{\displaystyle\sum_{i,j}\frac{t_{ik}v_{kj}} {(\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\frac{\beta+p}{2}}}|y_{ijn}|^{\beta}} {\displaystyle\sum_{i,j}\frac{t_{ik}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]^{\frac{p}{\beta+p}}z_{nk}, \\ z_{nk} &\leftarrow\frac{z_{nk}}{\displaystyle\sum_{n'}z_{n'k}}. """ p = self.domain beta = self.beta if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Yb = np.abs(Y) ** beta p_bp = p / (beta + p) bp_p = (beta + p) / p Z = self.latent T, V = self.basis, self.activation TV = T[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVbpp = ZTV**bp_p TV_RZTV = TV[np.newaxis, :, :, :] / ZTVbpp[:, :, np.newaxis, :] num = (beta / 2) * np.sum(TV_RZTV * Yb[:, :, np.newaxis, :], axis=(1, 3)) TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(TV_ZTV, axis=(1, 3)) Z = ((num / denom) ** p_bp) * Z Z = Z / Z.sum(axis=0) self.latent = Z def update_basis_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`t_{ikn}` as follows: .. math:: t_{ik} \leftarrow\left[ \frac{\beta}{2} \frac{\displaystyle\sum_{j,n}\frac{z_{nk}v_{kj}} {(\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\frac{\beta+p}{p}}}|y_{ijn}|^{\beta}} {\displaystyle\sum_{j,n}\frac{z_{nk}v_{kj}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]^{\frac{p}{\beta+p}}t_{ik}, if ``partitioning=True``. Otherwise .. 
math:: t_{ikn} \leftarrow\left[ \frac{\beta}{2} \frac{\displaystyle\sum_{j}\frac{v_{kjn}} {(\sum_{k'}t_{ik'n}v_{k'jn})^{\frac{\beta+p}{p}}}|y_{ijn}|^{\beta}} {\displaystyle\sum_{j}\frac{v_{kjn}}{\sum_{k'}t_{ik'n}v_{k'jn}}} \right]^{\frac{p}{\beta+p}}t_{ikn}. """ p = self.domain beta = self.beta flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Yb = np.abs(Y) ** beta p_bp = p / (beta + p) bp_p = (beta + p) / p if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVbpp = ZTV**bp_p ZV_ZTVbpp = ZV[:, np.newaxis, :, :] / ZTVbpp[:, :, np.newaxis, :] num = (beta / 2) * np.sum(ZV_ZTVbpp * Yb[:, :, np.newaxis, :], axis=(0, 3)) ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZV_ZTV, axis=(0, 3)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TVbpp = TV**bp_p V_TVbpp = V[:, np.newaxis, :, :] / TVbpp[:, :, np.newaxis, :] num = (beta / 2) * np.sum(V_TVbpp * Yb[:, :, np.newaxis, :], axis=3) V_TV = V[:, np.newaxis, :, :] / TV[:, :, np.newaxis, :] denom = np.sum(V_TV, axis=3) T = ((num / denom) ** p_bp) * T T = flooring_fn(T) self.basis = T def update_activation_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF activations by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`v_{kjn}` as follows: .. 
math:: v_{kj} \leftarrow\left[ \frac{\beta}{2} \frac{\displaystyle\sum_{i,n}\frac{z_{nk}t_{ik}} {(\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\frac{\beta+p}{p}}}|y_{ijn}|^{\beta}} {\displaystyle\sum_{i,n}\frac{z_{nk}t_{ik}}{\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}} \right]^{\frac{p}{\beta+p}}v_{kj}, if ``partitioning=True``. Otherwise .. math:: v_{kj} \leftarrow\left[ \frac{\beta}{2} \frac{\displaystyle\sum_{i}\frac{t_{ikn}} {(\sum_{k'}t_{ik'n}v_{k'jn})^{\frac{\beta+p}{p}}}|y_{ijn}|^{\beta}} {\displaystyle\sum_{i}\frac{t_{ik}}{\sum_{k'}t_{ik'n}v_{k'jn}}} \right]^{\frac{p}{\beta+p}}v_{kjn}. """ p = self.domain beta = self.beta flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Yb = np.abs(Y) ** beta p_bp = p / (beta + p) bp_p = (beta + p) / p if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :] ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVbpp = ZTV**bp_p ZT_ZTVbpp = ZT[:, :, :, np.newaxis] / ZTVbpp[:, :, np.newaxis, :] num = (beta / 2) * np.sum(ZT_ZTVbpp * Yb[:, :, np.newaxis, :], axis=(0, 1)) ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :] denom = np.sum(ZT_ZTV, axis=(0, 1)) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TVbpp = TV**bp_p T_TVbpp = T[:, :, :, np.newaxis] / TVbpp[:, :, np.newaxis, :] num = (beta / 2) * np.sum(T_TVbpp * Yb[:, :, np.newaxis, :], axis=1) T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :] denom = np.sum(T_TV, axis=1) V = ((num / denom) ** p_bp) * V V = flooring_fn(V) self.activation = V def update_spatial_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once. - If ``spatial_algorithm`` is ``IP`` or ``IP1``, ``update_spatial_model_ip1`` is called. 
- If ``spatial_algorithm`` is ``ISS`` or ``ISS1``, ``update_spatial_model_iss1`` is called. - If ``spatial_algorithm`` is ``IP2``, ``update_spatial_model_ip2`` is called. - If ``spatial_algorithm`` is ``ISS2``, ``update_spatial_model_iss2`` is called. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.spatial_algorithm in ["IP", "IP1"]: self.update_spatial_model_ip1(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["IP2"]: self.update_spatial_model_ip2(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["ISS", "ISS1"]: self.update_spatial_model_iss1(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["ISS2"]: self.update_spatial_model_iss2(flooring_fn=flooring_fn) else: raise NotImplementedError("Not support {}.".format(self.spatial_algorithm)) def update_spatial_model_ip1( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Demixing filters are updated sequentially for :math:`n=1,\ldots,N` as follows: .. 
math:: \boldsymbol{w}_{in} &\leftarrow\left(\boldsymbol{W}_{in}^{\mathsf{H}}\boldsymbol{U}_{in}\right)^{-1} \ \boldsymbol{e}_{n}, \\ \boldsymbol{w}_{in} &\leftarrow\frac{\boldsymbol{w}_{in}} {\sqrt{\boldsymbol{w}_{in}^{\mathsf{H}}\boldsymbol{U}_{in}\boldsymbol{w}_{in}}}, where .. math:: \boldsymbol{U}_{in} \leftarrow\frac{1}{J}\sum_{i,j,n} \frac{\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}}{\tilde{r}_{ijn}}. :math:`\tilde{r}_{ijn}` is computed as .. math:: \tilde{r}_{ijn} = \frac{2|y_{ijn}|^{2-\beta}}{\beta} \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{\beta}{p}}, if ``partitioning=True``. Otherwise, .. math:: \tilde{r}_{ijn} = \frac{2|y_{ijn}|^{2-\beta}}{\beta} \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{\beta}{p}}. """ p = self.domain beta = self.beta flooring_fn = choose_flooring_fn(flooring_fn, method=self) X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2b = np.abs(Y) ** (2 - beta) Y2b = flooring_fn(Y2b) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVbp = ZTV ** (beta / p) R_tilde = (2 / beta) * Y2b * ZTVbp else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TVbp = TV ** (beta / p) R_tilde = (2 / beta) * Y2b * TVbp varphi = 1 / R_tilde XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) varphi = varphi.transpose(1, 0, 2) varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(varphi_XX, axis=-1) self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn) def update_spatial_model_ip2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using pairwise iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. 
If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\neq n_{2}`), compute weighted covariance matrix as follows: .. math:: \boldsymbol{U}_{in} = \frac{1}{J}\sum_{j} \frac{1}{\tilde{r}_{ijn}}\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, :math:`\tilde{r}_{ijn}` is computed by .. math:: \tilde{r}_{ijn} = \frac{2|y_{ijn}|^{2-\beta}}{\beta} \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{\beta}{p}}, if ``partitioning=True``. \ Otherwise, .. math:: \tilde{r}_{ijn} = \frac{2|y_{ijn}|^{2-\beta}}{\beta} \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{\beta}{p}}. Using :math:`\boldsymbol{U}_{in_{1}}` and :math:`\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors. .. math:: \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i} = \lambda_{i} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i}, where .. math:: \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{1}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ), \\ \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{2}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ). After that, we standardize two eigenvectors :math:`\boldsymbol{h}_{in_{1}}` and :math:`\boldsymbol{h}_{in_{2}}`. .. 
math:: \boldsymbol{h}_{in_{1}} &\leftarrow\frac{\boldsymbol{h}_{in_{1}}} {\sqrt{\boldsymbol{h}_{in_{1}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{1}}}}, \\ \boldsymbol{h}_{in_{2}} &\leftarrow\frac{\boldsymbol{h}_{in_{2}}} {\sqrt{\boldsymbol{h}_{in_{2}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{2}}}}. Then, update :math:`\boldsymbol{w}_{in_{1}}` and :math:`\boldsymbol{w}_{in_{2}}` simultaneously. .. math:: \boldsymbol{w}_{in_{1}} &\leftarrow \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{1}} \\ \boldsymbol{w}_{in_{2}} &\leftarrow \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{2}} At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{1}` for :math:`n_{1}\neq n_{2}`. """ p = self.domain beta = self.beta flooring_fn = choose_flooring_fn(flooring_fn, method=self) X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y2b = np.abs(Y) ** (2 - beta) Y2b = flooring_fn(Y2b) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVbp = ZTV ** (beta / p) R_tilde = (2 / beta) * Y2b * ZTVbp else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TVbp = TV ** (beta / p) R_tilde = (2 / beta) * Y2b * TVbp varphi = 1 / R_tilde XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) varphi = varphi.transpose(1, 0, 2) varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(varphi_XX, axis=-1) self.demix_filter = update_by_ip2( W, U, flooring_fn=flooring_fn, pair_selector=self.pair_selector ) def update_spatial_model_iss1( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> 
None: r"""Update estimated spectrograms once using iterative source steering. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Update :math:`y_{ijn}` as follows: .. math:: \boldsymbol{y}_{ij} & \leftarrow\boldsymbol{y}_{ij} - \boldsymbol{d}_{in}y_{ijn} \\ d_{inn'} &= \begin{cases} \dfrac{\displaystyle\sum_{j}\dfrac{1}{\tilde{r}_{ijn}} y_{ijn'}y_{ijn}^{*}}{\displaystyle\sum_{j}\dfrac{1} {\tilde{r}_{ijn}}|y_{ijn}|^{2}} & (n'\neq n) \\ 1 - \dfrac{1}{\sqrt{\displaystyle\dfrac{1}{J}\sum_{j}\dfrac{1} {\tilde{r}_{ijn}}|y_{ijn}|^{2}}} & (n'=n) \end{cases}, where :math:`\tilde{r}_{ijn}` is computed as .. math:: \tilde{r}_{ijn} = \frac{2|y_{ijn}|^{2-\beta}}{\beta} \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{\beta}{p}}, if ``partitioning=True``. Otherwise, .. math:: \tilde{r}_{ijn} = \frac{2|y_{ijn}|^{2-\beta}}{\beta} \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{\beta}{p}}. """ p = self.domain beta = self.beta flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output Y2b = np.abs(Y) ** (2 - beta) Y2b = flooring_fn(Y2b) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVbp = ZTV ** (beta / p) R_bar = Y2b * ZTVbp else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TVbp = TV ** (beta / p) R_bar = Y2b * TVbp varphi = beta / (2 * R_bar) self.output = update_by_iss1(Y, varphi, flooring_fn=flooring_fn) def update_spatial_model_iss2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update estimated spectrograms once using pairwise iterative source steering. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. 
This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Compute :math:`\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and :math:`\boldsymbol{f}_{in}^{(n_{1},n_{2})}` for :math:`n_{1}\neq n_{2}`: .. math:: \begin{array}{rclc} \boldsymbol{G}_{in}^{(n_{1},n_{2})} &=& {\displaystyle\frac{1}{J}\sum_{j}}\dfrac{1}{\tilde{r}_{ijn}} \boldsymbol{y}_{ij}^{(n_{1},n_{2})}{\boldsymbol{y}_{ij}^{(n_{1},n_{2})}}^{\mathsf{H}} &(n=1,\ldots,N), \\ \boldsymbol{f}_{in}^{(n_{1},n_{2})} &=& {\displaystyle\frac{1}{J}\sum_{j}} \dfrac{1}{\tilde{r}_{ijn}}y_{ijn}^{*}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} &(n\neq n_{1},n_{2}), \end{array} where .. math:: \tilde{r}_{ijn} = \frac{2}{\beta}|y_{ijn}|^{2-\beta} \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{\beta}{p}}, if ``partitioning=True``. Otherwise, .. math:: \tilde{r}_{ijn} = \frac{2}{\beta}|y_{ijn}|^{2-\beta} \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{\beta}{p}}. Using :math:`\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and :math:`\boldsymbol{f}_{in}^{(n_{1},n_{2})}`, we compute .. math:: \begin{array}{rclc} \boldsymbol{p}_{in} &=& \dfrac{\boldsymbol{h}_{in}} {\sqrt{\boldsymbol{h}_{in}^{\mathsf{H}}\boldsymbol{G}_{in}^{(n_{1},n_{2})} \boldsymbol{h}_{in}}} & (n=n_{1},n_{2}), \\ \boldsymbol{q}_{in} &=& -{\boldsymbol{G}_{in}^{(n_{1},n_{2})}}^{-1}\boldsymbol{f}_{in}^{(n_{1},n_{2})} & (n\neq n_{1},n_{2}), \end{array} where :math:`\boldsymbol{h}_{in}` (:math:`n=n_{1},n_{2}`) is a generalized eigenvector obtained from .. math:: \boldsymbol{G}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{i} = \lambda_{i}\boldsymbol{G}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{i}. Separated signal :math:`y_{ijn}` is updated as follows: .. 
math:: y_{ijn} &\leftarrow\begin{cases} &\boldsymbol{p}_{in}^{\mathsf{H}}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} & (n=n_{1},n_{2}) \\ &\boldsymbol{q}_{in}^{\mathsf{H}}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} + y_{ijn} & (n\neq n_{1},n_{2}) \end{cases}. """ p = self.domain beta = self.beta flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output Y2b = np.abs(Y) ** (2 - beta) Y2b = flooring_fn(Y2b) if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) ZTVbp = ZTV ** (beta / p) R_tilde = (2 / beta) * Y2b * ZTVbp else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) TVbp = TV ** (beta / p) R_tilde = (2 / beta) * Y2b * TVbp varphi = 1 / R_tilde self.output = update_by_iss2( Y, varphi, flooring_fn=flooring_fn, pair_selector=self.pair_selector ) def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} = \frac{1}{J}\sum_{i,j,n} \left\{\left(\frac{|y_{ijn}|^{2}}{r_{ijn}}\right)^{\frac{\beta}{2}} + \log r_{ijn}\right\} - 2\sum_{i}\log|\det\boldsymbol{W}_{i}|, where .. math:: r_{ijn} = \left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)^{\frac{2}{p}}, if ``partitioning=True``. Otherwise .. math:: r_{ijn} = \left(\sum_{k}t_{ikn}v_{kjn}\right)^{\frac{2}{p}}. Returns: Computed loss. 
""" beta = self.beta p = self.domain if self.demix_filter is None: X, Y = self.input, self.output Yb = np.abs(Y) ** beta X, Y = X.transpose(1, 0, 2), Y.transpose(1, 0, 2) X_Hermite = X.transpose(0, 2, 1).conj() XX_Hermite = X @ X_Hermite W = Y @ X_Hermite @ np.linalg.inv(XX_Hermite) else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Yb = np.abs(Y) ** beta if self.partitioning: Z = self.latent T, V = self.basis, self.activation ZTV = self.reconstruct_nmf(T, V, latent=Z) R = ZTV ** (beta / p) loss = Yb / R + (2 / p) * np.log(ZTV) else: T, V = self.basis, self.activation TV = self.reconstruct_nmf(T, V) R = TV ** (beta / p) loss = Yb / R + (2 / p) * np.log(TV) logdet = self.compute_logdet(W) # (n_bins,) loss = np.sum(loss.mean(axis=-1), axis=0) - 2 * logdet loss = loss.sum(axis=0).item() return loss def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" if self.demix_filter is None: assert self.scale_restoration, "Set self.scale_restoration=True." 
X, Y = self.input, self.output Y_scaled = projection_back(Y, reference=X, reference_id=self.reference_id) self.output = Y_scaled else: super().apply_projection_back() def apply_minimal_distortion_principle(self) -> None: r"""Apply minimal distortion principle to estimated spectrograms.""" if self.demix_filter is None: X, Y = self.input, self.output Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id) self.output = Y_scaled else: super().apply_minimal_distortion_principle() ================================================ FILE: ssspy/bss/ipsdta.py ================================================ import functools from typing import Callable, List, Optional, Tuple, Union import numpy as np from ..algorithm import ( MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS, PROJECTION_BACK_KEYWORDS, minimal_distortion_principle, projection_back, ) from ..linalg.mean import gmeanmh from ..linalg.quadratic import quadratic from ..linalg.sqrtm import invsqrtmh, sqrtmh from ..special.flooring import identity, max_flooring from ..special.psd import to_psd from ..utils.flooring import choose_flooring_fn from ._update_spatial_model import update_by_block_decomposition_vcd from .base import IterativeMethodBase spatial_algorithms = ["FPI", "VCD"] source_algorithms = ["EM", "MM"] EPS = 1e-10 class IPSDTABase(IterativeMethodBase): r"""Base class of independent positive semidefinite tensor analysis (IPSDTA). Args: n_basis (int): Number of PSDTF bases. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. 
scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize PSDTF. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ def __init__( self, n_basis: int, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["IPSDTABase"], None], List[Callable[["IPSDTABase"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: self.source_normalization: Optional[Union[bool, str]] super().__init__(callbacks=callbacks, record_loss=record_loss) self.n_basis = n_basis if flooring_fn is None: self.flooring_fn = identity else: self.flooring_fn = flooring_fn self.input = None self.scale_restoration = scale_restoration if reference_id is None and scale_restoration: raise ValueError("Specify 'reference_id' if scale_restoration=True.") else: self.reference_id = reference_id if rng is None: rng = np.random.default_rng() self.rng = rng def __call__(self, input: np.ndarray, n_iter: int = 100, **kwargs) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). 
""" self.input = input.copy() self._reset(**kwargs) super().__call__(n_iter=n_iter) if self.scale_restoration: self.restore_scale() self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "IPSDTA(" s += "n_basis={n_basis}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", **kwargs, ) -> None: r"""Reset attributes by given keyword arguments. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. kwargs: Keyword arguments to set as attributes of IPSDTA. """ assert self.input is not None, "Specify data!" flooring_fn = choose_flooring_fn(flooring_fn, method=self) for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input n_channels, n_bins, n_frames = X.shape n_sources = n_channels # n_channels == n_sources self.n_sources, self.n_channels = n_sources, n_channels self.n_bins, self.n_frames = n_bins, n_frames if not hasattr(self, "demix_filter"): W = np.eye(n_sources, n_channels, dtype=np.complex128) W = np.tile(W, reps=(n_bins, 1, 1)) else: if self.demix_filter is None: W = None else: # To avoid overwriting ``demix_filter`` given by keyword arguments. W = self.demix_filter.copy() self.demix_filter = W self.output = self.separate(X, demix_filter=W) self._init_psdtf(flooring_fn=flooring_fn, rng=self.rng) def _init_psdtf( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", rng: Optional[np.random.Generator] = None, ) -> None: r"""Initialize PSDTF. 
Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. rng (numpy.random.Generator, optional): Random number generator. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ n_basis = self.n_basis n_sources = self.n_sources n_bins, n_frames = self.n_bins, self.n_frames flooring_fn = choose_flooring_fn(flooring_fn, method=self) if rng is None: rng = np.random.default_rng() if not hasattr(self, "basis"): # should be positive semi-definite eye = np.eye(n_bins, dtype=np.complex128) rand = rng.random((n_sources, n_basis, n_bins)) T = rand[..., np.newaxis] * eye else: # To avoid overwriting. T = self.basis.copy() if not hasattr(self, "activation"): V = rng.random((n_sources, n_basis, n_frames)) V = flooring_fn(V) else: # To avoid overwriting. V = self.activation.copy() self.basis, self.activation = T, V if self.source_normalization: self.normalize_psdtf() def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray: r"""Separate ``input`` using ``demixing_filter``. .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij} Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). demix_filter (numpy.ndarray): The demixing filters to separate ``input``. The shape is (n_bins, n_sources, n_channels). Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_sources, n_bins, n_frames). """ X, W = input, demix_filter Y = W @ X.transpose(1, 0, 2) output = Y.transpose(1, 0, 2) return output def reconstruct_psdtf( self, basis: np.ndarray, activation: np.ndarray, axis1: int = -2, axis2: int = -1, ) -> np.ndarray: r"""Reconstruct PSDTF. 
Args:
    basis (numpy.ndarray):
        Basis matrix.
        With the default ``axis1=-2`` and ``axis2=-1``, the shape is
        (n_sources, n_basis, n_bins, n_bins).
        With ``axis1=1`` and ``axis2=2`` (equivalently ``-3`` and ``-2``),
        the shape is (n_sources, n_bins, n_bins, n_basis).
    activation (numpy.ndarray):
        Activation matrix.
        The shape is (n_sources, n_basis, n_frames).
    axis1 (int):
        First axis of covariance matrix. Default: ``-2``.
    axis2 (int):
        Second axis of covariance matrix. Default: ``-1``.

Returns:
    numpy.ndarray of reconstructed PSDTF.
    The shape is (n_sources, n_frames, n_bins, n_bins).

        """
        T, V = basis, activation
        n_dims = T.ndim

        # Convert negative axes to non-negative positions so they can be compared below.
        axis1 = n_dims + axis1 if axis1 < 0 else axis1
        axis2 = n_dims + axis2 if axis2 < 0 else axis2

        # Only two layouts of ``basis`` are accepted.
        assert (axis1 == 1 and axis2 == 2) or (axis1 == 2 and axis2 == 3)

        if axis1 == 1 and axis2 == 2:
            # (n_sources, n_bins, n_bins, n_basis) -> (n_sources, n_basis, n_bins, n_bins)
            T = T.transpose(0, 3, 1, 2)

        # Activation-weighted sum of basis matrices over the basis axis:
        # R[n, j] = sum_k V[n, k, j] * T[n, k] -> (n_sources, n_frames, n_bins, n_bins).
        R = np.sum(T[:, :, np.newaxis, :, :] * V[:, :, :, np.newaxis, np.newaxis], axis=1)
        # NOTE(review): ``to_psd`` is defined elsewhere; presumably it projects each
        # matrix onto the positive semidefinite cone — confirm against its definition.
        R = to_psd(R, axis1=2, axis2=3)

        return R

    def update_once(self) -> None:
        r"""Update demixing filters once.

        Subclasses must override this method.
        """
        raise NotImplementedError("Implement 'update_once' method.")

    def normalize_psdtf(self) -> None:
        r"""Normalize PSDTF parameters.

        Each basis matrix is divided by the real part of its trace, and the
        corresponding activation is multiplied by the same value. The product
        of basis and activation is therefore unchanged.
        """
        source_normalization = self.source_normalization
        T, V = self.basis, self.activation

        assert source_normalization, "Set source_normalization."

        # One trace per (source, basis) pair; broadcast back onto basis and activation.
        trace = np.trace(T, axis1=-2, axis2=-1).real
        T = T / trace[:, :, np.newaxis, np.newaxis]
        V = V * trace[:, :, np.newaxis]

        self.basis, self.activation = T, V

    def compute_loss(self) -> float:
        r"""Compute loss :math:`\mathcal{L}`.

        Subclasses must override this method.

        Returns:
            Computed loss.

        """
        raise NotImplementedError("Implement 'compute_loss' method.")

    def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray:
        r"""Compute log-determinant of demixing filter.

        Args:
            demix_filter (numpy.ndarray):
                Demixing filters with shape of (n_bins, n_sources, n_channels).
                More leading (batch) dimensions are also accepted, because
                ``np.linalg.slogdet`` operates on the last two axes.

        Returns:
            numpy.ndarray of computed log-determinant values, i.e. the logarithm of
            the absolute value of each matrix determinant. The shape equals the
            leading (batch) shape of ``demix_filter``.

        """
        # slogdet returns (sign, log|det|); only the log-magnitude is needed.
        _, logdet = np.linalg.slogdet(demix_filter)  # (n_bins,) for 3D input

        return logdet

    def restore_scale(self) -> None:
        r"""Restore scale ambiguity.

If ``self.scale_restoration="projection_back``, we use projection back technique. """ scale_restoration = self.scale_restoration assert scale_restoration, "Set self.scale_restoration=True." if type(scale_restoration) is bool: scale_restoration = PROJECTION_BACK_KEYWORDS[0] if scale_restoration in PROJECTION_BACK_KEYWORDS: self.apply_projection_back() elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS: self.apply_minimal_distortion_principle() else: raise ValueError("{} is not supported for scale restoration.".format(scale_restoration)) def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." X, W = self.input, self.demix_filter W_scaled = projection_back(W, reference_id=self.reference_id) Y_scaled = self.separate(X, demix_filter=W_scaled) self.output, self.demix_filter = Y_scaled, W_scaled def apply_minimal_distortion_principle(self) -> None: r"""Apply minimal distortion principle to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id) X = X.transpose(1, 0, 2) Y = Y_scaled.transpose(1, 0, 2) X_Hermite = X.transpose(0, 2, 1).conj() W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite) self.output, self.demix_filter = Y_scaled, W_scaled class BlockDecompositionIPSDTABase(IPSDTABase): r"""Base class of independent positive semidefinite tensor analysis (IPSDTA) \ using block decomposition of bases. Args: n_basis (int): Number of PSDTF bases. n_blocks (int): Number of sub-blocks. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. 
Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize PSDTF. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ def __init__( self, n_basis: int, n_blocks: int, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[ Callable[["BlockDecompositionIPSDTABase"], None], List[Callable[["BlockDecompositionIPSDTABase"], None]], ] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__( n_basis=n_basis, flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, rng=rng, ) self.n_blocks = n_blocks def __repr__(self) -> str: s = "IPSDTA(" s += "n_basis={n_basis}" s += ", n_blocks={n_blocks}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", **kwargs, ) -> None: r"""Reset attributes by given keyword arguments. 
Args:
    flooring_fn (callable or str, optional):
        A flooring function for numerical stability.
        This function is expected to return the same shape tensor as the input.
        If you explicitly set ``flooring_fn=None``,
        the identity function (``lambda x: x``) is used.
        If ``self`` is given as str, ``self.flooring_fn`` is used.
        Default: ``self``.
    kwargs:
        Keyword arguments to set as attributes of IPSDTA.

        """
        assert self.input is not None, "Specify data!"

        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        # Attributes given here (e.g. ``demix_filter``, ``basis``, ``activation``)
        # take precedence over the default initialization below.
        for key in kwargs.keys():
            setattr(self, key, kwargs[key])

        X = self.input

        n_channels, n_bins, n_frames = X.shape
        n_sources = n_channels  # n_channels == n_sources

        self.n_sources, self.n_channels = n_sources, n_channels
        self.n_bins, self.n_frames = n_bins, n_frames

        if not hasattr(self, "demix_filter"):
            # Identity demixing matrix repeated for every frequency bin:
            # (n_bins, n_sources, n_channels).
            W = np.eye(n_sources, n_channels, dtype=np.complex128)
            W = np.tile(W, reps=(n_bins, 1, 1))
        else:
            if self.demix_filter is None:
                W = None
            else:
                # To avoid overwriting ``demix_filter`` given by keyword arguments.
                W = self.demix_filter.copy()

        self.demix_filter = W
        self.output = self.separate(X, demix_filter=W)

        self._init_block_decomposition_psdtf(flooring_fn=flooring_fn, rng=self.rng)

    def _init_block_decomposition_psdtf(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
        rng: Optional[np.random.Generator] = None,
    ) -> None:
        r"""Initialize PSDTF using block decomposition of bases.

        The frequency bins are grouped into ``n_blocks`` blocks of
        ``n_bins // n_blocks`` bins each. When ``n_bins`` is not divisible by
        ``n_blocks``, the last ``n_remains`` blocks hold one extra bin. In that
        case, ``basis`` is stored as a tuple ``(basis_low, basis_high)``. The
        first item covers the ``n_blocks - n_remains`` regular blocks and the
        second covers the ``n_remains`` enlarged blocks.

        Args:
            flooring_fn (callable or str, optional):
                A flooring function for numerical stability.
                This function is expected to return the same shape tensor as the input.
                If you explicitly set ``flooring_fn=None``,
                the identity function (``lambda x: x``) is used.
                If ``self`` is given as str, ``self.flooring_fn`` is used.
                Default: ``self``.
            rng (numpy.random.Generator, optional):
                Random number generator. If ``None`` is given,
                ``np.random.default_rng()`` is used.
                Default: ``None``.

        """
        n_basis = self.n_basis
        n_sources = self.n_sources
        n_bins, n_frames = self.n_bins, self.n_frames
        n_blocks = self.n_blocks
        n_remains = self.n_remains
        n_neighbors = n_bins // n_blocks

        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        if rng is None:
            rng = np.random.default_rng()

        if not hasattr(self, "basis"):
            # should be positive semi-definite
            # Diagonal matrices whose diagonal entries are drawn from [0, 1):
            # (n_sources, n_basis, n_blocks - n_remains, n_neighbors, n_neighbors).
            eye = np.eye(n_neighbors, dtype=np.complex128)
            rand = rng.random((n_sources, n_basis, n_blocks - n_remains, n_neighbors))
            T = rand[..., np.newaxis] * eye

            if n_remains > 0:
                # Blocks holding one extra frequency bin.
                eye = np.eye(n_neighbors + 1, dtype=np.complex128)
                rand = rng.random((n_sources, n_basis, n_remains, n_neighbors + 1))
                T_high = rand[..., np.newaxis] * eye
                T = T, T_high
        else:
            # To avoid overwriting.
            if n_remains > 0:
                T_low, T_high = self.basis
                T = T_low.copy(), T_high.copy()
            else:
                T = self.basis.copy()

        if not hasattr(self, "activation"):
            # (n_sources, n_basis, n_frames)
            V = rng.random((n_sources, n_basis, n_frames))
            V = flooring_fn(V)
        else:
            # To avoid overwriting.
            V = self.activation.copy()

        self.basis, self.activation = T, V

        if self.source_normalization:
            self.normalize_block_decomposition_psdtf()

    @property
    def n_remains(self) -> int:
        r"""Number of blocks holding one extra frequency bin, i.e. ``n_bins % n_blocks``.

        Raises:
            AttributeError: If ``n_bins`` has not been set yet (it is set in ``_reset``).
        """
        if not hasattr(self, "n_bins"):
            raise AttributeError("Since n_bins is not defined, n_remains cannot be computed.")

        return self.n_bins % self.n_blocks

    def reconstruct_block_decomposition_psdtf(
        self, basis: np.ndarray, activation: np.ndarray, axis1: int = -2, axis2: int = -1
    ) -> np.ndarray:
        r"""Reconstruct PSDTF using block decomposition of bases.

        Args:
            basis (numpy.ndarray or tuple of numpy.ndarray):
                Basis matrix.
                With the default ``axis1=-2`` and ``axis2=-1``, the shape is
                (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors).
                With ``axis1=2`` and ``axis2=3`` (equivalently ``-3`` and ``-2``),
                the shape is (n_sources, n_blocks, n_neighbors, n_neighbors, n_basis).
                A tuple ``(basis_low, basis_high)`` is accepted when ``n_remains > 0``.
            activation (numpy.ndarray):
                Activation matrix.
                The shape is (n_sources, n_basis, n_frames).
            axis1 (int):
                First axis of covariance matrix. Default: ``-2``.
            axis2 (int):
                Second axis of covariance matrix. Default: ``-1``.

        Returns:
            numpy.ndarray of reconstructed PSDTF.
            The shape is (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).
            If ``basis`` is a tuple, a tuple of two such arrays (low, high) is returned.

        """

        def _reconstruct(
            basis: np.ndarray, activation: np.ndarray, axis1: int = -2, axis2: int = -1
        ) -> np.ndarray:
            r"""Reconstruct PSDTF using block decomposition of bases.

            Args:
                basis (numpy.ndarray):
                    Basis matrix.
                    With the default ``axis1=-2`` and ``axis2=-1``, the shape is
                    (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors).
                    With ``axis1=2`` and ``axis2=3`` (equivalently ``-3`` and ``-2``),
                    the shape is (n_sources, n_blocks, n_neighbors, n_neighbors, n_basis).
                activation (numpy.ndarray):
                    Activation matrix.
                    The shape is (n_sources, n_basis, n_frames).
                axis1 (int):
                    First axis of covariance matrix. Default: ``-2``.
                axis2 (int):
                    Second axis of covariance matrix. Default: ``-1``.

            Returns:
                numpy.ndarray of reconstructed PSDTF.
                The shape is (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).

            """
            na = np.newaxis
            T, V = basis, activation
            n_dims = T.ndim

            # Convert negative axes to non-negative positions so they can be compared below.
            axis1 = n_dims + axis1 if axis1 < 0 else axis1
            axis2 = n_dims + axis2 if axis2 < 0 else axis2

            # Only two layouts of ``basis`` are accepted.
            assert (axis1 == 2 and axis2 == 3) or (axis1 == 3 and axis2 == 4)

            if axis1 == 2 and axis2 == 3:
                # (n_sources, n_blocks, n_nb, n_nb, n_basis) -> (n_sources, n_basis, n_blocks, n_nb, n_nb)
                T = T.transpose(0, 4, 1, 2, 3)

            # Activation-weighted sum of block basis matrices over the basis axis.
            R = np.sum(
                T[:, :, na, :, :, :] * V[:, :, :, na, na, na],
                axis=1,
            )
            # NOTE(review): ``to_psd`` is defined elsewhere; presumably it projects each
            # block onto the positive semidefinite cone — confirm against its definition.
            R = to_psd(R, axis1=3, axis2=4)

            return R

        if type(basis) is tuple:
            # Low and high groups have different block sizes, so reconstruct them separately.
            assert self.n_remains > 0, "n_remains is expected to be positive."

            T_low, T_high = basis
            V = activation
            R_low = _reconstruct(T_low, V, axis1=axis1, axis2=axis2)
            R_high = _reconstruct(T_high, V, axis1=axis1, axis2=axis2)
            R = R_low, R_high
        else:
            T = basis
            V = activation
            R = _reconstruct(T, V, axis1=axis1, axis2=axis2)

        return R

    def normalize_block_decomposition_psdtf(self, axis1: int = -2, axis2: int = -1) -> None:
        r"""Normalize PSDTF parameters using block decomposition of bases.

        Args:
            axis1 (int):
                First axis of covariance matrix. Default: ``-2``.
            axis2 (int):
                Second axis of covariance matrix. Default: ``-1``.
""" source_normalization = self.source_normalization n_remains = self.n_remains na = np.newaxis T, V = self.basis, self.activation assert source_normalization, "Set source_normalization." if n_remains > 0: T_low, T_high = T trace_low = np.trace(T_low, axis1=axis1, axis2=axis2).real trace_high = np.trace(T_high, axis1=axis1, axis2=axis2).real trace = np.sum(trace_low, axis=-1) + np.sum(trace_high, axis=-1) T_low = T_low / trace[:, :, na, na, na] T_high = T_high / trace[:, :, na, na, na] T = T_low, T_high else: trace = np.trace(T, axis1=axis1, axis2=axis2).real trace = np.sum(trace, axis=-1) T = T / trace[:, :, na, na, na] V = V * trace[:, :, na] self.basis, self.activation = T, V class GaussIPSDTA(BlockDecompositionIPSDTABase): r"""Independent positive semidefinite tensor analysis (IPSDTA) \ on Gaussian distribution. Args: n_basis (int): Number of PSDTF bases. n_blocks (int): Number of sub-blocks. source_algorithm (str): Algorithm for PSDTF updates. Choose ``EM``, or ``MM``. Default: ``MM``. spatial_algorithm (str): Algorithm for demixing filter updates. Choose ``FPI``, or ``VCD``. Default: ``VCD``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. source_normalization (bool): If ``source_normalization=True``, normalize PSDTF parameters. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. 
record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize PSDTF. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ def __init__( self, n_basis: int, n_blocks: int, source_algorithm: str = "MM", spatial_algorithm: str = "VCD", flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[ Callable[["GaussIPSDTA"], None], List[Callable[["GaussIPSDTA"], None]], ] ] = None, source_normalization: Optional[Union[bool, str]] = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__( n_basis, n_blocks, flooring_fn, callbacks, scale_restoration, record_loss, reference_id, rng, ) assert source_algorithm in source_algorithms, "Not support {}.".format(source_algorithms) assert spatial_algorithm in spatial_algorithms, "Not support {}.".format(spatial_algorithms) self.source_algorithm = source_algorithm self.spatial_algorithm = spatial_algorithm self.source_normalization = source_normalization def __repr__(self) -> str: s = "GaussIPSDTA(" s += "n_basis={n_basis}" s += ", n_blocks={n_blocks}" s += ", source_algorithm={source_algorithm}" s += ", spatial_algorithm={spatial_algorithm}" s += ", source_normalization={source_normalization}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of IPSDTA. 
""" super()._reset(**kwargs) if self.spatial_algorithm == "FPI": if not hasattr(self, "fixed_point"): n_sources = self.n_sources n_bins = self.n_bins self.fixed_point = np.ones((n_sources, n_bins), dtype=np.complex128) else: self.fixed_point = self.fixed_point.copy() raise NotImplementedError("IPSDTA with fixed-point iteration is not supported.") def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update PSDTF parameters and demixing filters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) self.update_source_model(flooring_fn=flooring_fn) self.update_spatial_model(flooring_fn=flooring_fn) def update_source_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update PSDTF basis matrices and activations. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.source_algorithm == "MM": self.update_source_model_mm(flooring_fn=flooring_fn) else: raise NotImplementedError("Not support {}.".format(self.source_algorithm)) if self.source_normalization: self.normalize_block_decomposition_psdtf() def update_source_model_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update PSDTF basis matrices and activations by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) self.update_basis_mm(flooring_fn=flooring_fn) self.update_activation_mm() def update_basis_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update PSDTF basis matrices by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ n_sources = self.n_sources n_frames = self.n_frames flooring_fn = choose_flooring_fn(flooring_fn, method=self) def _update_basis_mm( basis: np.ndarray, activation: np.ndarray, separated: np.ndarray = None ) -> np.ndarray: r""" Args: basis: (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors) activation: (n_sources, n_basis, n_frames) separated: (n_sources, n_blocks, n_neighbors, n_frames) Returns: numpy.ndarray of updated basis matrix. 
""" T, V = basis, activation Y = separated na = np.newaxis R = self.reconstruct_block_decomposition_psdtf(T, V) R_inverse = np.linalg.inv(R) Y = Y.transpose(0, 3, 1, 2) YY_Hermite = Y[:, :, :, :, na] @ Y[:, :, :, na, :].conj() RYYR = R_inverse @ YY_Hermite @ R_inverse P = np.mean( V[:, :, :, na, na, na] * R_inverse[:, na, :, :, :, :], axis=2, ) Q = np.mean( V[:, :, :, na, na, na] * RYYR[:, na, :, :, :, :], axis=2, ) TQT = T @ Q @ T P = to_psd(P, flooring_fn=flooring_fn) TQT = to_psd(TQT, flooring_fn=flooring_fn) # geometric mean of P^(-1) and TQT T = gmeanmh(P, TQT, type=2) T = to_psd(T, flooring_fn=flooring_fn) return T n_bins = self.n_bins n_blocks = self.n_blocks n_remains = self.n_remains n_neighbors = n_bins // n_blocks X, W = self.input, self.demix_filter T, V = self.basis, self.activation Y = self.separate(X, demix_filter=W) if n_remains > 0: T_low, T_high = T Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1) Y_low = Y_low.reshape(n_sources, n_blocks - n_remains, n_neighbors, n_frames) Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames) T_low = _update_basis_mm(T_low, V, separated=Y_low) T_high = _update_basis_mm(T_high, V, separated=Y_high) T = T_low, T_high else: Y = Y.reshape(n_sources, n_blocks, n_neighbors, n_frames) T = _update_basis_mm(T, V, separated=Y) self.basis = T def update_activation_mm(self) -> None: r"""Update PSDTF activations by MM algorithm.""" def _compute_traces( basis: np.ndarray, activation: np.ndarray, separated: np.ndarray = None ) -> Tuple[np.ndarray, np.ndarray]: r""" Args: basis: (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors) activation: (n_sources, n_basis, n_frames) separated: (n_sources, n_blocks, n_neighbors, n_frames) Returns: Tuple of numerator and denominator. Type of each item is ``numpy.ndarray``. 
""" T, V = basis, activation Y = separated na = np.newaxis R = self.reconstruct_block_decomposition_psdtf(T, V) R_inverse = np.linalg.inv(R) Y = Y.transpose(0, 3, 1, 2) YY_Hermite = Y[:, :, :, :, na] @ Y[:, :, :, na, :].conj() RYYR = R_inverse @ YY_Hermite @ R_inverse num = np.trace(RYYR[:, na, :] @ T[:, :, na], axis1=-2, axis2=-1) denom = np.trace(R_inverse[:, na, :] @ T[:, :, na], axis1=-2, axis2=-1) num = np.real(num).sum(axis=-1) denom = np.real(denom).sum(axis=-1) return num, denom n_sources = self.n_sources n_bins, n_frames = self.n_bins, self.n_frames n_blocks = self.n_blocks n_remains = self.n_remains n_neighbors = n_bins // n_blocks X, W = self.input, self.demix_filter T, V = self.basis, self.activation Y = self.separate(X, demix_filter=W) if n_remains > 0: T_low, T_high = T Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1) Y_low = Y_low.reshape(n_sources, n_blocks - n_remains, n_neighbors, n_frames) Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames) num_low, denom_low = _compute_traces(T_low, V, separated=Y_low) num_high, denom_high = _compute_traces(T_high, V, separated=Y_high) num = num_low + num_high denom = denom_low + denom_high else: Y = Y.reshape(n_sources, n_blocks, n_neighbors, n_frames) num, denom = _compute_traces(T, V, separated=Y) self.activation = V * np.sqrt(num / denom) def update_spatial_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.spatial_algorithm == "VCD": self.update_spatial_model_vcd(flooring_fn=flooring_fn) else: raise NotImplementedError("Not support {}.".format(self.spatial_algorithm)) def update_spatial_model_vcd( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once by VCD. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) def _update(input: np.ndarray, demix_filter: np.ndarray, covariance: np.ndarray): r""" Args: input (np.ndarray): Mixture spectrogram. The shape is (n_channels, n_blocks, n_neighbors, n_frames). demix_filter (np.ndarray): Demixing filter split by frequnecy bands. The shape is (n_blocks, n_neighbors, n_sources, n_channels). covariance (np.ndarray): Rconstructed PSDTF. The shape is (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors). Returns: np.ndarray of demixing filters after update. 
""" X, W = input, demix_filter R = covariance XX = X[:, na, :, :, na] * X[na, :, :, na, :].conj() XX = XX.transpose(2, 3, 4, 0, 1, 5) R_inverse = np.linalg.inv(R) R_inverse = R_inverse.transpose(2, 4, 3, 0, 1) RXX = np.mean(R_inverse[:, :, :, :, na, na] * XX[:, :, :, na, :, :], axis=-1) W = update_by_block_decomposition_vcd( W, weighted_covariance=RXX, singular_fn=lambda x: np.abs(x) < flooring_fn(0) ) return W n_sources, n_channels = self.n_sources, self.n_channels n_bins, n_frames = self.n_bins, self.n_frames n_blocks = self.n_blocks n_remains = self.n_remains na = np.newaxis n_neighbors = n_bins // n_blocks X, W = self.input, self.demix_filter T, V = self.basis, self.activation R = self.reconstruct_block_decomposition_psdtf(T, V) if n_remains > 0: X_low, X_high = np.split(X, [(n_blocks - n_remains) * n_neighbors], axis=1) W_low, W_high = np.split(W, [(n_blocks - n_remains) * n_neighbors], axis=0) R_low, R_high = R # Lower frequency X_low = X_low.reshape(n_channels, n_blocks - n_remains, n_neighbors, n_frames) W_low = W_low.reshape(n_blocks - n_remains, n_neighbors, n_sources, n_channels) W_low = _update(X_low, demix_filter=W_low, covariance=R_low) # Higher frequency X_high = X_high.reshape(n_channels, n_remains, n_neighbors + 1, n_frames) W_high = W_high.reshape(n_remains, n_neighbors + 1, n_sources, n_channels) W_high = _update(X_high, demix_filter=W_high, covariance=R_high) W_low = W_low.reshape((n_blocks - n_remains) * n_neighbors, n_sources, n_channels) W_high = W_high.reshape(n_remains * (n_neighbors + 1), n_sources, n_channels) W = np.concatenate([W_low, W_high], axis=0) else: X = X.reshape(n_channels, n_blocks, n_neighbors, n_frames) W = W.reshape(n_blocks, n_neighbors, n_sources, n_channels) W = _update(X, demix_filter=W, covariance=R) W = W.reshape(n_blocks * n_neighbors, n_sources, n_channels) self.demix_filter = W def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. Returns: Computed loss. 
""" def _compute_block_decomposition_loss( separated: np.ndarray, demix_filter: np.ndarray, covariance: np.ndarray ) -> float: r""" Args: separated (np.ndarray): Separated signal with shape of (n_sources, n_frames, n_blocks, n_neighbors). demix_filter (np.ndarray): Demixing filters with shape of (n_blocks, n_neighbors, n_sources, n_channels). covariance: Covariance matrix with shape of (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors). """ Y, W = separated, demix_filter R = covariance n_sources, n_frames, n_blocks, n_neighbors = Y.shape Y = Y.reshape(n_sources, n_frames, n_blocks, n_neighbors, 1) R_inverse = np.linalg.inv(R) Y_Hermite = np.swapaxes(Y, 3, 4).conj() YRY = np.sum(Y_Hermite @ R_inverse @ Y, axis=(0, 2, 3, 4)) YRY = np.real(YRY) YRY = np.maximum(YRY, 0) _, logdetR = np.linalg.slogdet(R) logdetR = logdetR.sum(axis=(0, 2)) logdetW = self.compute_logdet(W) loss = np.mean(YRY + logdetR, axis=0) - 2 * logdetW.sum(axis=(0, 1)) loss = loss.item() return loss n_sources, n_channels = self.n_sources, self.n_channels n_bins, n_frames = self.n_bins, self.n_frames n_blocks = self.n_blocks n_remains = self.n_remains n_neighbors = n_bins // n_blocks X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y = Y.transpose(0, 2, 1) T, V = self.basis, self.activation R = self.reconstruct_block_decomposition_psdtf(T, V) if n_remains > 0: Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=2) W_low, W_high = np.split(W, [(n_blocks - n_remains) * n_neighbors], axis=0) R_low, R_high = R Y_low = Y_low.reshape(n_sources, n_frames, (n_blocks - n_remains), n_neighbors) Y_high = Y_high.reshape(n_sources, n_frames, n_remains, n_neighbors + 1) W_low = W_low.reshape((n_blocks - n_remains), n_neighbors, n_sources, n_channels) W_high = W_high.reshape(n_remains, n_neighbors + 1, n_sources, n_channels) loss_low = _compute_block_decomposition_loss( Y_low, demix_filter=W_low, covariance=R_low ) loss_high = _compute_block_decomposition_loss( 
Y_high, demix_filter=W_high, covariance=R_high ) loss = loss_low + loss_high else: Y = Y.reshape(n_sources, n_frames, n_blocks, n_neighbors) W = W.reshape(n_blocks, n_neighbors, n_sources, n_channels) loss = _compute_block_decomposition_loss(Y, demix_filter=W, covariance=R) return loss class TIPSDTA(BlockDecompositionIPSDTABase): r"""Independent positive semidefinite tensor analysis (IPSDTA) \ on Student's t distribution. Args: n_basis (int): Number of PSDTF bases. n_blocks (int): Number of sub-blocks. dof (float): Degree of freedom parameter. source_algorithm (str): Algorithm for PSDTF updates. Only ``MM`` is supported. Default: ``MM``. spatial_algorithm (str): Algorithm for demixing filter updates. Only ``VCD`` is supported. Default: ``VCD``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. source_normalization (bool): If ``source_normalization=True``, normalize PSDTF parameters. Default: ``True``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize PSDTF. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. 
""" def __init__( self, n_basis: int, n_blocks: int, dof: float, source_algorithm: str = "MM", spatial_algorithm: str = "VCD", flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[ Callable[["GaussIPSDTA"], None], List[Callable[["GaussIPSDTA"], None]], ] ] = None, source_normalization: Optional[Union[bool, str]] = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__( n_basis, n_blocks, flooring_fn, callbacks, scale_restoration, record_loss, reference_id, rng, ) assert source_algorithm in source_algorithms, "Not support {}.".format(source_algorithm) assert spatial_algorithm in spatial_algorithms, "Not support {}.".format(spatial_algorithm) self.dof = dof self.source_algorithm = source_algorithm self.source_normalization = source_normalization self.spatial_algorithm = spatial_algorithm def __repr__(self) -> str: s = "TIPSDTA(" s += "n_basis={n_basis}" s += ", n_blocks={n_blocks}" s += ", dof={dof}" s += ", source_algorithm={source_algorithm}" s += ", spatial_algorithm={spatial_algorithm}" s += ", source_normalization={source_normalization}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update PSDTF parameters and demixing filters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) self.update_source_model(flooring_fn=flooring_fn) self.update_spatial_model(flooring_fn=flooring_fn) def update_source_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update PSDTF basis matrices and activations. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.source_algorithm == "MM": self.update_source_model_mm(flooring_fn=flooring_fn) else: raise NotImplementedError("Not support {}.".format(self.source_algorithm)) if self.source_normalization: self.normalize_block_decomposition_psdtf() def update_source_model_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update PSDTF basis matrices and activations by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) self.update_basis_mm(flooring_fn=flooring_fn) self.update_activation_mm() def update_basis_mm( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update PSDTF basis matrices by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. 
If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ n_sources = self.n_sources n_frames = self.n_frames flooring_fn = choose_flooring_fn(flooring_fn, method=self) def _quadratic(Y: np.ndarray, R: np.ndarray) -> np.ndarray: r""" Args: Y (np.ndarray): Separated spectrams with shape of (n_sources, n_blocks, n_neighbors, n_frames). R (np.ndarray): Covariance matrix with shape of (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors). Returns: Quadratic forms with shape of (n_sources, n_frames). """ Y = Y.transpose(0, 3, 1, 2) R_inverse = np.linalg.inv(R) YRY = quadratic(Y, R_inverse) YRY = np.real(YRY) YRY = np.maximum(YRY, 0) YRY = YRY.sum(axis=-1) return YRY def _update_basis_mm( basis: np.ndarray, activation: np.ndarray, separated: np.ndarray = None, weight: np.ndarray = None, ) -> np.ndarray: r""" Args: basis: (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors) activation: (n_sources, n_basis, n_frames) separated: (n_sources, n_blocks, n_neighbors, n_frames) weight: (n_sources, n_frames) Returns: numpy.ndarray of updated basis matrix. 
""" T, V = basis, activation Y = separated pi = weight na = np.newaxis R = self.reconstruct_block_decomposition_psdtf(T, V) R_inverse = np.linalg.inv(R) Y = Y.transpose(0, 3, 1, 2) YY_Hermite = Y[:, :, :, :, na] @ Y[:, :, :, na, :].conj() RYYR = R_inverse @ YY_Hermite @ R_inverse piRYYR = pi[:, :, na, na, na] * RYYR P = np.mean( V[:, :, :, na, na, na] * R_inverse[:, na, :, :, :, :], axis=2, ) Q = np.mean( V[:, :, :, na, na, na] * piRYYR[:, na, :, :, :, :], axis=2, ) Q = to_psd(Q, flooring_fn=flooring_fn) Q_sqrt = sqrtmh(Q) QTPTQ = Q_sqrt @ T @ P @ T @ Q_sqrt QTPTQ = to_psd(QTPTQ, flooring_fn=flooring_fn) T = T @ Q_sqrt @ invsqrtmh(QTPTQ, flooring_fn=flooring_fn) @ Q_sqrt @ T T = to_psd(T, flooring_fn=flooring_fn) return T n_bins = self.n_bins n_blocks = self.n_blocks n_remains = self.n_remains n_neighbors = n_bins // n_blocks nu = self.dof X, W = self.input, self.demix_filter T, V = self.basis, self.activation Y = self.separate(X, demix_filter=W) R = self.reconstruct_block_decomposition_psdtf(T, V) if n_remains > 0: T_low, T_high = T Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1) Y_low = Y_low.reshape(n_sources, n_blocks - n_remains, n_neighbors, n_frames) Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames) R_low, R_high = R YRY_low = _quadratic(Y_low, R_low) YRY_high = _quadratic(Y_high, R_high) YRY = YRY_low + YRY_high pi = (nu + 2 * n_bins) / (nu + 2 * YRY) T_low = _update_basis_mm(T_low, V, separated=Y_low, weight=pi) T_high = _update_basis_mm(T_high, V, separated=Y_high, weight=pi) T = T_low, T_high else: Y = Y.reshape(n_sources, n_blocks, n_neighbors, n_frames) YRY = _quadratic(Y, R) pi = (nu + 2 * n_bins) / (nu + 2 * YRY) T = _update_basis_mm(T, V, separated=Y, weight=pi) self.basis = T def update_activation_mm(self) -> None: r"""Update PSDTF activations by MM algorithm.""" def _quadratic(Y: np.ndarray, R: np.ndarray) -> np.ndarray: r""" Args: Y (np.ndarray): Separated spectrams with shape of (n_sources, 
n_blocks, n_neighbors, n_frames). R (np.ndarray): Covariance matrix with shape of (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors). Returns: Quadratic forms with shape of (n_sources, n_frames). """ Y = Y.transpose(0, 3, 1, 2) R_inverse = np.linalg.inv(R) YRY = quadratic(Y, R_inverse) YRY = np.real(YRY) YRY = np.maximum(YRY, 0) YRY = YRY.sum(axis=-1) return YRY def _compute_traces( basis: np.ndarray, activation: np.ndarray, separated: np.ndarray = None, weight: np.ndarray = None, ) -> Tuple[np.ndarray, np.ndarray]: r""" Args: basis: (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors) activation: (n_sources, n_basis, n_frames) separated: (n_sources, n_blocks, n_neighbors, n_frames) Returns: Tuple of numerator and denominator. Type of each item is ``numpy.ndarray``. """ T, V = basis, activation Y = separated.transpose(0, 3, 1, 2) pi = weight na = np.newaxis R = self.reconstruct_block_decomposition_psdtf(T, V) R_inverse = np.linalg.inv(R) YY_Hermite = Y[:, :, :, :, na] @ Y[:, :, :, na, :].conj() RYYR = R_inverse @ YY_Hermite @ R_inverse piRYYR = pi[:, :, na, na, na] * RYYR num = np.trace(piRYYR[:, na, :] @ T[:, :, na], axis1=-2, axis2=-1) denom = np.trace(R_inverse[:, na, :] @ T[:, :, na], axis1=-2, axis2=-1) num = np.real(num).sum(axis=-1) denom = np.real(denom).sum(axis=-1) return num, denom n_sources = self.n_sources n_bins, n_frames = self.n_bins, self.n_frames n_blocks = self.n_blocks n_remains = self.n_remains n_neighbors = n_bins // n_blocks nu = self.dof X, W = self.input, self.demix_filter T, V = self.basis, self.activation Y = self.separate(X, demix_filter=W) R = self.reconstruct_block_decomposition_psdtf(T, V) if n_remains > 0: T_low, T_high = T Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1) Y_low = Y_low.reshape(n_sources, n_blocks - n_remains, n_neighbors, n_frames) Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames) R_low, R_high = R YRY_low = _quadratic(Y_low, R_low) YRY_high = 
_quadratic(Y_high, R_high) YRY = YRY_low + YRY_high pi = (nu + 2 * n_bins) / (nu + 2 * YRY) num_low, denom_low = _compute_traces(T_low, V, separated=Y_low, weight=pi) num_high, denom_high = _compute_traces(T_high, V, separated=Y_high, weight=pi) num = num_low + num_high denom = denom_low + denom_high else: Y = Y.reshape(n_sources, n_blocks, n_neighbors, n_frames) YRY = _quadratic(Y, R) pi = (nu + 2 * n_bins) / (nu + 2 * YRY) num, denom = _compute_traces(T, V, separated=Y, weight=pi) self.activation = V * np.sqrt(num / denom) def update_spatial_model( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.spatial_algorithm == "VCD": self.update_spatial_model_vcd(flooring_fn=flooring_fn) else: raise NotImplementedError("Not support {}.".format(self.spatial_algorithm)) def update_spatial_model_vcd( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once by VCD. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) def _quadratic(Y: np.ndarray, R: np.ndarray) -> np.ndarray: r""" Args: Y (np.ndarray): Separated spectrams with shape of (n_sources, n_blocks, n_neighbors, n_frames). 
R (np.ndarray): Covariance matrix with shape of (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors). Returns: Quadratic forms with shape of (n_sources, n_frames). """ Y = Y.transpose(2, 3, 0, 1) R_inverse = np.linalg.inv(R) YRY = quadratic(Y, R_inverse) YRY = np.real(YRY) YRY = np.maximum(YRY, 0) YRY = YRY.sum(axis=-1) return YRY def _update( input: np.ndarray, demix_filter: np.ndarray, covariance: np.ndarray, weight: np.ndarray = None, ): X, W = input, demix_filter R = covariance pi = weight na = np.newaxis XX = X[:, na, :, :, na] * X[na, :, :, na, :].conj() XX = XX.transpose(2, 3, 4, 0, 1, 5) R_inverse = np.linalg.inv(R) R_inverse = R_inverse.transpose(2, 4, 3, 0, 1) pi_R_inverse = pi * R_inverse RXX = np.mean(pi_R_inverse[:, :, :, :, na, na] * XX[:, :, :, na, :, :], axis=-1) W = update_by_block_decomposition_vcd( W, weighted_covariance=RXX, singular_fn=lambda x: np.abs(x) < flooring_fn(0) ) return W n_sources, n_channels = self.n_sources, self.n_channels n_bins, n_frames = self.n_bins, self.n_frames n_blocks = self.n_blocks n_remains = self.n_remains n_neighbors = n_bins // n_blocks nu = self.dof X, W = self.input, self.demix_filter T, V = self.basis, self.activation R = self.reconstruct_block_decomposition_psdtf(T, V) if n_remains > 0: X_low, X_high = np.split(X, [(n_blocks - n_remains) * n_neighbors], axis=1) W_low, W_high = np.split(W, [(n_blocks - n_remains) * n_neighbors], axis=0) R_low, R_high = R # Lower frequency X_low = X_low.reshape(n_channels, n_blocks - n_remains, n_neighbors, n_frames) W_low = W_low.reshape(n_blocks - n_remains, n_neighbors, n_sources, n_channels) Y_low = W_low @ X_low.transpose(1, 2, 0, 3) # Higher frequency X_high = X_high.reshape(n_channels, n_remains, n_neighbors + 1, n_frames) W_high = W_high.reshape(n_remains, n_neighbors + 1, n_sources, n_channels) Y_high = W_high @ X_high.transpose(1, 2, 0, 3) YRY_low = _quadratic(Y_low, R_low) YRY_high = _quadratic(Y_high, R_high) YRY = YRY_low + YRY_high pi = (nu + 2 * n_bins) / (nu 
+ 2 * YRY) W_low = _update(X_low, demix_filter=W_low, covariance=R_low, weight=pi) W_high = _update(X_high, demix_filter=W_high, covariance=R_high, weight=pi) W_low = W_low.reshape((n_blocks - n_remains) * n_neighbors, n_sources, n_channels) W_high = W_high.reshape(n_remains * (n_neighbors + 1), n_sources, n_channels) W = np.concatenate([W_low, W_high], axis=0) else: X = X.reshape(n_channels, n_blocks, n_neighbors, n_frames) W = W.reshape(n_blocks, n_neighbors, n_sources, n_channels) Y = W @ X.transpose(1, 2, 0, 3) YRY = _quadratic(Y, R) pi = (nu + 2 * n_bins) / (nu + 2 * YRY) W = _update(X, demix_filter=W, covariance=R, weight=pi) W = W.reshape(n_blocks * n_neighbors, n_sources, n_channels) self.demix_filter = W def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. Returns: Computed loss. """ def _quadratic(Y: np.ndarray, R: np.ndarray) -> np.ndarray: r""" Args: Y (np.ndarray): Separated spectrams with shape of (n_sources, n_blocks, n_neighbors, n_frames). R (np.ndarray): Covariance matrix with shape of (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors). Returns: Quadratic forms with shape of (n_sources, n_frames). 
""" Y = Y.transpose(0, 3, 1, 2) R_inverse = np.linalg.inv(R) YRY = quadratic(Y, R_inverse) YRY = np.real(YRY) YRY = np.maximum(YRY, 0) YRY = YRY.sum(axis=-1) return YRY n_sources, n_channels = self.n_sources, self.n_channels n_bins, n_frames = self.n_bins, self.n_frames nu = self.dof n_blocks = self.n_blocks n_remains = self.n_remains n_neighbors = n_bins // n_blocks X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) T, V = self.basis, self.activation R = self.reconstruct_block_decomposition_psdtf(T, V) if n_remains > 0: Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1) W_low, W_high = np.split(W, [(n_blocks - n_remains) * n_neighbors], axis=0) R_low, R_high = R Y_low = Y_low.reshape(n_sources, (n_blocks - n_remains), n_neighbors, n_frames) Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames) W_low = W_low.reshape((n_blocks - n_remains), n_neighbors, n_sources, n_channels) W_high = W_high.reshape(n_remains, n_neighbors + 1, n_sources, n_channels) YRY_low = _quadratic(Y_low, R_low) YRY_high = _quadratic(Y_high, R_high) YRY = YRY_low + YRY_high loss = np.sum(((nu + 2 * n_bins) / 2) * np.log(1 + (2 / nu) * YRY), axis=0) _, logdetR_low = np.linalg.slogdet(R_low) logdetR_low = logdetR_low.sum(axis=(0, 2)) _, logdetR_high = np.linalg.slogdet(R_high) logdetR_high = logdetR_high.sum(axis=(0, 2)) logdetR = logdetR_low + logdetR_high logdetW_low = self.compute_logdet(W_low) logdetW_high = self.compute_logdet(W_high) logdetW = logdetW_low.sum(axis=(0, 1)) + logdetW_high.sum(axis=(0, 1)) else: Y = Y.reshape(n_sources, n_blocks, n_neighbors, n_frames) W = W.reshape(n_blocks, n_neighbors, n_sources, n_channels) YRY = _quadratic(Y, R) loss = np.sum(((nu + 2 * n_bins) / 2) * np.log(1 + (2 / nu) * YRY), axis=0) _, logdetR = np.linalg.slogdet(R) logdetR = logdetR.sum(axis=(0, 2)) logdetW = self.compute_logdet(W) logdetW = logdetW.sum(axis=(0, 1)) loss = np.mean(loss + logdetR, axis=0) - 2 * logdetW loss = 
loss.item() return loss ================================================ FILE: ssspy/bss/iva.py ================================================ import functools from typing import Callable, Iterable, List, Optional, Tuple, Union import numpy as np from ..algorithm import ( MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS, PROJECTION_BACK_KEYWORDS, minimal_distortion_principle, projection_back, ) from ..linalg import eigh, prox from ..special.flooring import identity, max_flooring from ..transform import whiten from ..utils.flooring import choose_flooring_fn from ..utils.select_pair import sequential_pair_selector from ._update_spatial_model import ( update_by_ip1, update_by_ip2_one_pair, update_by_ipa, update_by_iss1, update_by_iss2, ) from .admmbss import ADMMBSS from .base import IterativeMethodBase from .pdsbss import PDSBSS __all__ = [ "GradIVA", "NaturalGradIVA", "FastIVA", "FasterIVA", "AuxIVA", "PDSIVA", "ADMMIVA", "GradLaplaceIVA", "GradGaussIVA", "NaturalGradLaplaceIVA", "NaturalGradGaussIVA", "AuxLaplaceIVA", "AuxGaussIVA", ] spatial_algorithms = ["IP", "IP1", "IP2", "ISS", "ISS1", "ISS2", "IPA"] EPS = 1e-10 class IVABase(IterativeMethodBase): r"""Base class of independent vector analysis (IVA). Args: flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. 
record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. """ def __init__( self, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["IVABase"], None], List[Callable[["IVABase"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__(callbacks=callbacks, record_loss=record_loss) if flooring_fn is None: self.flooring_fn = identity else: self.flooring_fn = flooring_fn self.input = None self.scale_restoration = scale_restoration if reference_id is None and scale_restoration: raise ValueError("Specify 'reference_id' if scale_restoration=True.") else: self.reference_id = reference_id def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): Mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): Number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) raise NotImplementedError("Implement '__call__' method.") def __repr__(self) -> str: s = "IVA(" s += "scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of IVA. 
""" assert self.input is not None, "Specify data!" for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input n_channels, n_bins, n_frames = X.shape n_sources = n_channels # n_channels == n_sources self.n_sources, self.n_channels = n_sources, n_channels self.n_bins, self.n_frames = n_bins, n_frames if not hasattr(self, "demix_filter"): W = np.eye(n_sources, n_channels, dtype=np.complex128) W = np.tile(W, reps=(n_bins, 1, 1)) else: if self.demix_filter is None: W = None else: # To avoid overwriting ``demix_filter`` given by keyword arguments. W = self.demix_filter.copy() self.demix_filter = W self.output = self.separate(X, demix_filter=W) def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray: r"""Separate ``input`` using ``demixing_filter``. .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij} Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). demix_filter (numpy.ndarray): The demixing filters to separate ``input``. The shape is (n_bins, n_sources, n_channels). Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_sources, n_bins, n_frames). """ X, W = input, demix_filter Y = W @ X.transpose(1, 0, 2) output = Y.transpose(1, 0, 2) return output def update_once(self) -> None: r"""Update demixing filters once.""" raise NotImplementedError("Implement 'update_once' method.") def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} \ &= \frac{1}{J}\sum_{j,n}G(\vec{\boldsymbol{y}}_{jn}) \ - 2\sum_{i}\log|\det\boldsymbol{W}_{i}|, \\ G(\vec{\boldsymbol{y}}_{jn}) \ &= - \log p(\vec{\boldsymbol{y}}_{jn}) Returns: Computed loss. 
""" X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) # (n_sources, n_bins, n_frames) logdet = self.compute_logdet(W) # (n_bins,) G = self.contrast_fn(Y) # (n_sources, n_frames) loss = np.sum(np.mean(G, axis=1), axis=0) - 2 * np.sum(logdet, axis=0) loss = loss.item() return loss def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray: r"""Compute log-determinant of demixing filter. Args: demix_filter (numpy.ndarray): Demixing filters with shape of (n_bins, n_sources, n_channels). Returns: numpy.ndarray of computed log-determinant values. """ _, logdet = np.linalg.slogdet(demix_filter) # (n_bins,) return logdet def restore_scale(self) -> None: r"""Restore scale ambiguity. If ``self.scale_restoration=projection_back``, we use projection back technique. If ``self.scale_restoration=minimal_distortion_principle``, we use minimal distortion principle. """ scale_restoration = self.scale_restoration assert scale_restoration, "Set self.scale_restoration=True." if type(scale_restoration) is bool: scale_restoration = PROJECTION_BACK_KEYWORDS[0] if scale_restoration in PROJECTION_BACK_KEYWORDS: self.apply_projection_back() elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS: self.apply_minimal_distortion_principle() else: raise ValueError("{} is not supported for scale restoration.".format(scale_restoration)) def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." X, W = self.input, self.demix_filter W_scaled = projection_back(W, reference_id=self.reference_id) Y_scaled = self.separate(X, demix_filter=W_scaled) self.output, self.demix_filter = Y_scaled, W_scaled def apply_minimal_distortion_principle(self) -> None: r"""Apply minimal distortion principle to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." 
X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id) X = X.transpose(1, 0, 2) Y = Y_scaled.transpose(1, 0, 2) X_Hermite = X.transpose(0, 2, 1).conj() W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite) self.output, self.demix_filter = Y_scaled, W_scaled class GradIVABase(IVABase): r"""Base class of independent vector analysis (IVA) using gradient descent. Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(\vec{\boldsymbol{y}}_{jn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_frames). score_fn (callable): A score function which corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. 
reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. """ def __init__( self, step_size: float = 1e-1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["GradIVABase"], None], List[Callable[["GradIVABase"], None]]] ] = None, is_holonomic: bool = False, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.step_size = step_size if contrast_fn is None: raise ValueError("Specify contrast function.") else: self.contrast_fn = contrast_fn if score_fn is None: raise ValueError("Specify score function.") else: self.score_fn = score_fn self.is_holonomic = is_holonomic def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. \ The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. \ Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray: The separated signal in frequency-domain. \ The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) # Call __call__ of IVABase's parent, i.e. 
__call__ of IterativeMethodBase super(IVABase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "GradIVA(" s += "step_size={step_size}" s += ", is_holonomic={is_holonomic}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) class FastIVABase(IVABase): r"""Base class of fast independent vector analysis (FastIVA). Args: flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. 
""" def __init__( self, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["IVABase"], None], List[Callable[["IVABase"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def __repr__(self) -> str: s = "FastIVA(" s += "scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: super()._reset(**kwargs) X, W = self.input, self.demix_filter Z = whiten(X) Y = self.separate(Z, demix_filter=W, use_whitening=False) self.whitened_input = Z self.output = Y def separate( self, input: np.ndarray, demix_filter: np.ndarray, use_whitening: bool = True ) -> np.ndarray: r"""Separate ``input`` using ``demixing_filter``. .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij} Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). demix_filter (numpy.ndarray): The demixing filters to separate ``input``. The shape is (n_bins, n_sources, n_channels). use_whitening (bool): If ``use_whitening=True``, use_whitening (sphering) is applied to ``input``. Default: True. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_sources, n_bins, n_frames). """ if use_whitening: whitened_input = whiten(input) else: whitened_input = input output = super().separate(whitened_input, demix_filter=demix_filter) return output def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. 
math:: \mathcal{L} \ &= \frac{1}{J}\sum_{j,n}G(\vec{\boldsymbol{y}}_{jn}), \\ G(\vec{\boldsymbol{y}}_{jn}) \ &= - \log p(\vec{\boldsymbol{y}}_{jn}) Returns: Computed loss. """ Z, W = self.whitened_input, self.demix_filter Y = self.separate(Z, demix_filter=W, use_whitening=False) # (n_sources, n_bins, n_frames) G = self.contrast_fn(Y) # (n_sources, n_frames) loss = np.sum(np.mean(G, axis=1), axis=0).item() return loss def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." reference_id = self.reference_id X, Z = self.input, self.whitened_input W = self.demix_filter Y = self.separate(Z, demix_filter=W, use_whitening=False) Y_scaled = projection_back(Y, reference=X, reference_id=reference_id) Z = Z.transpose(1, 0, 2) Z_Hermite = Z.transpose(0, 2, 1).conj() ZZ_Hermite = Z @ Z_Hermite W_scaled = Y_scaled.transpose(1, 0, 2) @ Z_Hermite @ np.linalg.inv(ZZ_Hermite) self.output, self.demix_filter = Y_scaled, W_scaled class AuxIVABase(IVABase): r"""Base class of auxiliary-function-based independent vector analysis (IVA). Args: contrast_fn (callable): A contrast function corresponds to :math:`-\log p(\vec{\boldsymbol{y}}_{jn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_frames). d_contrast_fn (callable): A derivative of the contrast function. This function is expected to receive (n_channels, n_frames) and return (n_channels, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. 
If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. """ def __init__( self, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["AuxIVABase"], None], List[Callable[["AuxIVABase"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.contrast_fn = contrast_fn self.d_contrast_fn = d_contrast_fn def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). 
""" return super().__call__(input, n_iter=n_iter, initial_call=initial_call, **kwargs) def __repr__(self) -> str: s = "AuxIVA(" s += "scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) class GradIVA(GradIVABase): r"""Independent vector analysis (IVA) [#kim2006independent]_ using gradient descent. Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. contrast_fn (callable): A contrast function corresponds to :math:`-\log p(\vec{\boldsymbol{y}}_{jn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_frames). score_fn (callable): A score function corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. 
reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def score_fn(y): ... norm = np.linalg.norm(y, axis=1, keepdims=True) ... return y / np.maximum(norm, 1e-10) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = GradIVA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=True, ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def score_fn(y): ... norm = np.linalg.norm(y, axis=1, keepdims=True) ... return y / np.maximum(norm, 1e-10) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = GradIVA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=False, ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) .. [#kim2006independent] T. Kim, H. T. Attias, S.-Y. Lee, and T.-W. Lee, "Blind source separation exploiting higher-order frequency dependencies," in *IEEE Trans. ASLP*, vol. 15, no. 1, pp. 70-79, 2007. 
""" def __init__( self, step_size: float = 1e-1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["GradIVA"], None], List[Callable[["GradIVA"], None]]] ] = None, is_holonomic: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, is_holonomic=is_holonomic, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def update_once(self) -> None: r"""Update demixing filters once using the gradient descent. If ``is_holonomic=True``, demixing filters are updated as follows: .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\left(\frac{1}{J}\sum_{j} \ \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j})\boldsymbol{y}_{ij}^{\mathsf{H}} \ -\boldsymbol{I}\right)\boldsymbol{W}_{i}^{-\mathsf{H}}, where .. math:: \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j}) &= \left(\phi_{i}(\vec{\boldsymbol{y}}_{j1}),\ldots,\ \phi_{i}(\vec{\boldsymbol{y}}_{jn}),\ldots,\ \phi_{i}(\vec{\boldsymbol{y}}_{jN}))\ \right)^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \phi_{i}(\vec{\boldsymbol{y}}_{jn}) &= \frac{\partial G(\vec{\boldsymbol{y}}_{jn})}{\partial y_{ijn}^{*}}, \\ G(\vec{\boldsymbol{y}}_{jn}) &= -\log p(\vec{\boldsymbol{y}}_{jn}). Otherwise (``is_holonomic=False``), .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{J}\sum_{j} \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j})\boldsymbol{y}_{ij}^{\mathsf{H}}\right) \boldsymbol{W}_{i}^{-\mathsf{H}}. 
""" X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Phi = self.score_fn(Y) Y_conj = Y.conj() PhiY = np.mean(Phi[:, np.newaxis, :, :] * Y_conj[np.newaxis, :, :, :], axis=-1) PhiY = PhiY.transpose(2, 0, 1) # (n_bins, n_sources, n_sources) W_inv = np.linalg.inv(W) W_inv_Hermite = W_inv.transpose(0, 2, 1).conj() eye = np.eye(self.n_sources) if self.is_holonomic: delta = (PhiY - eye) @ W_inv_Hermite else: delta = ((1 - eye) * PhiY) @ W_inv_Hermite W = W - self.step_size * delta Y = self.separate(X, demix_filter=W) self.demix_filter = W self.output = Y class NaturalGradIVA(GradIVABase): r"""Independent vector analysis (IVA) using natural gradient descent. Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. contrast_fn (callable): A contrast function corresponds to :math:`-\log p(\vec{\boldsymbol{y}}_{jn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_frames). score_fn (callable): A score function corresponds to the partial derivative of the contrast function. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_bins, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. 
You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def score_fn(y): ... norm = np.linalg.norm(y, axis=1, keepdims=True) ... return y / np.maximum(norm, 1e-10) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = NaturalGradIVA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=True, ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=500) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def score_fn(y): ... norm = np.linalg.norm(y, axis=1, keepdims=True) ... return y / np.maximum(norm, 1e-10) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = NaturalGradIVA( ... contrast_fn=contrast_fn, ... score_fn=score_fn, ... is_holonomic=False, ... 
) >>> spectrogram_est = iva(spectrogram_mix, n_iter=500) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, step_size: float = 1e-1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, score_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["NaturalGradIVA"], None], List[Callable[["NaturalGradIVA"], None]]] ] = None, is_holonomic: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, is_holonomic=is_holonomic, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def update_once(self) -> None: r"""Update demixing filters once using the natural gradient descent. If ``is_holonomic=True``, demixing filters are updated as follows: .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\left(\frac{1}{J}\sum_{j} \ \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j})\boldsymbol{y}_{ij}^{\mathsf{H}} \ -\boldsymbol{I}\right)\boldsymbol{W}_{i}, where .. math:: \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j}) &= \left(\phi_{i}(\vec{\boldsymbol{y}}_{j1}),\ldots,\ \phi_{i}(\vec{\boldsymbol{y}}_{jn}),\ldots,\ \phi_{i}(\vec{\boldsymbol{y}}_{jN}))\ \right)^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \phi_{i}(\vec{\boldsymbol{y}}_{jn}) &= \frac{\partial G(\vec{\boldsymbol{y}}_{jn})}{\partial y_{ijn}^{*}}, \\ G(\vec{\boldsymbol{y}}_{jn}) &= -\log p(\vec{\boldsymbol{y}}_{jn}). Otherwise (``is_holonomic=False``), .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{J}\sum_{j} \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j})\boldsymbol{y}_{ij}^{\mathsf{H}}\right) \boldsymbol{W}_{i}. 
""" X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Phi = self.score_fn(Y) Y_conj = Y.conj() PhiY = np.mean(Phi[:, np.newaxis, :, :] * Y_conj[np.newaxis, :, :, :], axis=-1) PhiY = PhiY.transpose(2, 0, 1) # (n_bins, n_sources, n_sources) eye = np.eye(self.n_sources) if self.is_holonomic: delta = (PhiY - eye) @ W else: delta = ((1 - eye) * PhiY) @ W W = W - self.step_size * delta Y = self.separate(X, demix_filter=W) self.demix_filter = W self.output = Y class FastIVA(FastIVABase): r"""Fast independent vector analysis (Fast IVA) [#lee2007fast]_. Args: contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(\vec{\boldsymbol{y}}_{jn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_frames). d_contrast_fn (callable): A derivative of the contrast function. This function is expected to receive (n_channels, n_frames) and return (n_channels, n_frames). dd_contrast_fn (callable): Second order derivative of the contrast function. This function is expected to receive (n_channels, n_frames) and return (n_channels, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. 
reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: .. code-block:: python >>> from ssspy.transform import whiten >>> from ssspy.algorithm import projection_back >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def d_contrast_fn(y): ... return 2 * np.ones_like(y) >>> def dd_contrast_fn(y): ... return 2 * np.zeros_like(y) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = FastIVA( ... contrast_fn=contrast_fn, ... d_contrast_fn=d_contrast_fn, ... dd_contrast_fn=dd_contrast_fn, ... scale_restoration=False, ... ) >>> spectrogram_mix_whitened = whiten(spectrogram_mix) >>> spectrogram_est = iva(spectrogram_mix_whitened, n_iter=100) >>> spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) .. [#lee2007fast] I. Lee et al., "Fast fixed-point independent vector analysis algorithms \ for convolutive blind source separation," *Signal Processing*, vol. 87, no. 8, pp. 1859-1871, 2007. 
""" def __init__( self, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None, dd_contrast_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["FastIVA"], None], List[Callable[["FastIVA"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) if contrast_fn is None: raise ValueError("Specify contrast function.") else: self.contrast_fn = contrast_fn if d_contrast_fn is None: raise ValueError("Specify derivative of contrast function.") else: self.d_contrast_fn = d_contrast_fn if dd_contrast_fn is None: raise ValueError("Specify second order derivative of contrast function.") else: self.dd_contrast_fn = dd_contrast_fn def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) # Call __call__ of IVABase's parent, i.e. 
__call__ of IterativeMethodBase super(IVABase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() self.output = self.separate( self.whitened_input, demix_filter=self.demix_filter, use_whitening=False ) return self.output def __repr__(self) -> str: s = "FastIVA(" s += "scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Demixing filters are updated as follows: .. math:: \boldsymbol{w}_{in} \leftarrow&\frac{1}{J}\sum_{j} \frac{G'_{\mathbb{R}}(\|\vec{\boldsymbol{y}}_{jn}\|_{2})} {2\|\vec{\boldsymbol{y}}_{jn}\|_{2}} \left(\boldsymbol{w}_{in}-y_{ijn}^{*}\boldsymbol{x}_{ij}\right) \notag \\ &-\frac{1}{J}\sum_{j}\frac{|y_{ijn}|^{2}}{2\|\vec{\boldsymbol{y}}_{jn}\|_{2}}\left( \frac{G'_{\mathbb{R}}(\|\vec{\boldsymbol{y}}_{jn}\|_{2})} {\|\vec{\boldsymbol{y}}_{jn}\|_{2}} - G''_{\mathbb{R}}(\|\vec{\boldsymbol{y}}_{jn}\|_{2}) \right)\boldsymbol{w}_{in} \\ \boldsymbol{W}_{i} \leftarrow&\left(\boldsymbol{W}_{i}\boldsymbol{W}_{i}^{\mathsf{H}}\right)^{-\frac{1}{2}} \boldsymbol{W}_{i}. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) Z, W = self.whitened_input, self.demix_filter Y = self.separate(Z, demix_filter=W, use_whitening=False) norm = np.linalg.norm(Y, axis=1) varphi = self.d_contrast_fn(norm) / flooring_fn(2 * norm) # (n_sources, n_frames) Y_conj = Y.conj() YZ = Y_conj[:, np.newaxis, :, :] * Z W_Hermite = W.transpose(1, 2, 0).conj() W_YZ = W_Hermite[:, :, :, np.newaxis] - YZ W_YZ = np.mean(varphi[:, np.newaxis, np.newaxis, :] * W_YZ, axis=-1) Y_GG = (2 * varphi - self.dd_contrast_fn(norm)) / flooring_fn(2 * norm) YY_GG = Y_GG[:, np.newaxis, :] * (np.abs(Y) ** 2) YY_GGW = np.mean(W_Hermite[:, :, :, np.newaxis] * YY_GG[:, np.newaxis, :, :], axis=-1) # Update W_Hermite = W_YZ - YY_GGW W = W_Hermite.transpose(2, 0, 1).conj() u, _, v_Hermite = np.linalg.svd(W) W = u @ v_Hermite self.demix_filter = W class FasterIVA(FastIVABase): r"""Faster independent vector analysis (Faster IVA) [#brendel2021fasteriva]_. Args: contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(\vec{\boldsymbol{y}}_{jn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_frames). d_contrast_fn (callable): A derivative of the contrast function. This function is expected to receive (n_channels, n_frames) and return (n_channels, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. 
You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: .. code-block:: python >>> from ssspy.transform import whiten >>> from ssspy.algorithm import projection_back >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def d_contrast_fn(y): ... return 2 * np.ones_like(y) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = FasterIVA( ... contrast_fn=contrast_fn, ... d_contrast_fn=d_contrast_fn, ... scale_restoration=False, ... ) >>> spectrogram_mix_whitened = whiten(spectrogram_mix) >>> spectrogram_est = iva(spectrogram_mix_whitened, n_iter=100) >>> spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) .. [#brendel2021fasteriva] A. Brendel and W. Kellermann, "Faster IVA: Update rules for independent vector analysis based on negentropy \ and the majorize-minimize principle," in *Proc. WASPAA*, pp. 131-135, 2021. 
""" def __init__( self, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["FasterIVA"], None], List[Callable[["FasterIVA"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: super().__init__( flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) if contrast_fn is None: raise ValueError("Specify contrast function.") else: self.contrast_fn = contrast_fn if d_contrast_fn is None: raise ValueError("Specify derivative of contrast function.") else: self.d_contrast_fn = d_contrast_fn def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) # Call __call__ of IVABase's parent, i.e. 
__call__ of IterativeMethodBase super(IVABase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() self.output = self.separate( self.whitened_input, demix_filter=self.demix_filter, use_whitening=False ) return self.output def __repr__(self) -> str: s = "FasterIVA(" s += "scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. In FasterIVA, we compute the eigenvector of :math:`\boldsymbol{U}_{in}` which corresponds to the largest eigenvalue by solving .. math:: \boldsymbol{U}_{in}\boldsymbol{w}_{in} = \lambda_{in}\boldsymbol{w}_{in}. Then, .. math:: \boldsymbol{W}_{i} \leftarrow\left(\boldsymbol{W}_{i}\boldsymbol{W}_{i}^{\mathsf{H}}\right)^{-\frac{1}{2}} \boldsymbol{W}_{i}. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) Z, W = self.whitened_input, self.demix_filter Y = self.separate(Z, demix_filter=W, use_whitening=False) ZZ_Hermite = Z[:, np.newaxis, :, :] * Z[np.newaxis, :, :, :].conj() ZZ_Hermite = ZZ_Hermite.transpose(2, 0, 1, 3) # (n_bins, n_channels, n_channels, n_frames) norm = np.linalg.norm(Y, axis=1) varphi = self.d_contrast_fn(norm) / flooring_fn(2 * norm) # (n_sources, n_frames) varphi_ZZ = varphi[:, np.newaxis, np.newaxis, :] * ZZ_Hermite[:, np.newaxis, :, :, :] U = np.mean(varphi_ZZ, axis=-1) # (n_bins, n_sources, n_channels, n_channels) _, w = eigh(U) # (n_bins, n_sources, n_channels, n_channels) W = w[..., -1].conj() # eigenvector that corresponds to largest eigenvalue u, _, v_Hermite = np.linalg.svd(W) W = u @ v_Hermite self.demix_filter = W class AuxIVA(AuxIVABase): r"""Auxiliary-function-based independent vector analysis (IVA) [#ono2011stable]_. Args: spatial_algorithm (str): Algorithm for demixing filter updates. Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, ``ISS2``, or ``IPA``. Default: ``IP``. contrast_fn (callable): A contrast function which corresponds to :math:`-\log p(\vec{\boldsymbol{y}}_{jn})`. This function is expected to receive (n_channels, n_bins, n_frames) and return (n_channels, n_frames). d_contrast_fn (callable): A derivative of the contrast function. This function is expected to receive (n_channels, n_frames) and return (n_channels, n_frames). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. pair_selector (callable, optional): Selector to choose updaing pair in ``IP2`` and ``ISS2``. If ``None`` is given, ``sequential_pair_selector`` is used. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. 
Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the demixing filter update if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. lqpqm_normalization (bool): This keyword argument can be specified when ``spatial_algorithm='IPA'``. If ``True``, normalization by trace is applied to positive semi-definite matrix in LQPQM. Default: ``True``. newton_iter (int): This keyword argument can be specified when ``spatial_algorithm='IPA'``. Number of iterations in Newton method. Default: ``1``. Examples: Update demixing filters by IP: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def d_contrast_fn(y): ... return 2 * np.ones_like(y) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxIVA( ... spatial_algorithm="IP", ... contrast_fn=contrast_fn, ... d_contrast_fn=d_contrast_fn, ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by IP2: .. code-block:: python >>> from ssspy.utils.select_pair import sequential_pair_selector >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def d_contrast_fn(y): ... return 2 * np.ones_like(y) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxIVA( ... 
spatial_algorithm="IP2", ... contrast_fn=contrast_fn, ... d_contrast_fn=d_contrast_fn, ... pair_selector=sequential_pair_selector, ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def d_contrast_fn(y): ... return 2 * np.ones_like(y) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxIVA( ... spatial_algorithm="ISS", ... contrast_fn=contrast_fn, ... d_contrast_fn=d_contrast_fn, ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS2: .. code-block:: python >>> import functools >>> from ssspy.utils.select_pair import sequential_pair_selector >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def d_contrast_fn(y): ... return 2 * np.ones_like(y) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxIVA( ... spatial_algorithm="ISS2", ... contrast_fn=contrast_fn, ... d_contrast_fn=d_contrast_fn, ... pair_selector=functools.partial(sequential_pair_selector, step=2), ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by IPA: .. code-block:: python >>> def contrast_fn(y): ... return 2 * np.linalg.norm(y, axis=1) >>> def d_contrast_fn(y): ... return 2 * np.ones_like(y) >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... 
+ 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxIVA( ... spatial_algorithm="IPA", ... contrast_fn=contrast_fn, ... d_contrast_fn=d_contrast_fn, ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) .. [#ono2011stable] N. Ono, "Stable and fast update rules for independent vector analysis based on \ auxiliary function technique," in *Proc. WASPAA*, 2011, p.189-192. """ _ipa_default_kwargs = {"lqpqm_normalization": True, "newton_iter": 1} _default_kwargs = _ipa_default_kwargs def __init__( self, spatial_algorithm: str = "IP", contrast_fn: Callable[[np.ndarray], np.ndarray] = None, d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, callbacks: Optional[ Union[Callable[["AuxIVA"], None], List[Callable[["AuxIVA"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, **kwargs, ) -> None: super().__init__( contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, flooring_fn=flooring_fn, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) assert spatial_algorithm in spatial_algorithms, "Not support {}.".format(spatial_algorithm) self.spatial_algorithm = spatial_algorithm if pair_selector is None: if spatial_algorithm in ["IP2", "ISS2"]: self.pair_selector = sequential_pair_selector else: self.pair_selector = pair_selector if spatial_algorithm == "IPA": valid_keys = set(self.__class__._ipa_default_kwargs.keys()) else: valid_keys = set() invalid_keys = set(kwargs) - valid_keys assert invalid_keys == set(), "Invalid keywords {} are given.".format(invalid_keys) for key, value in kwargs.items(): setattr(self, key, value) # set default values if necessary for key in 
valid_keys: if not hasattr(self, key): value = self.__class__._default_kwargs[key] setattr(self, key, value) def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) # Call __call__ of IVABase's parent, i.e. __call__ of IterativeMethodBase super(IVABase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() if self.demix_filter is None: pass else: self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "AuxIVA(" s += "spatial_algorithm={spatial_algorithm}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of IVA. """ super()._reset(**kwargs) if self.spatial_algorithm in ["ISS", "ISS1", "ISS2", "IPA"]: self.demix_filter = None def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once. - If ``self.spatial_algorithm`` is ``IP`` or ``IP1``, ``update_once_ip1`` is called. - If ``self.spatial_algorithm`` is ``IP2``, ``update_once_ip2`` is called. 
- If ``self.spatial_algorithm`` is ``ISS`` or ``ISS1``, ``update_once_iss1`` is called. - If ``self.spatial_algorithm`` is ``ISS2``, ``update_once_iss2`` is called. - If ``self.spatial_algorithm`` is ``IPA``, ``update_once_ipa`` is called. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) if self.spatial_algorithm in ["IP", "IP1"]: self.update_once_ip1(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["IP2"]: self.update_once_ip2(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["ISS", "ISS1"]: self.update_once_iss1(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["ISS2"]: self.update_once_iss2(flooring_fn=flooring_fn) elif self.spatial_algorithm in ["IPA"]: self.update_once_ipa(flooring_fn=flooring_fn) else: raise NotImplementedError("Not support {}.".format(self.spatial_algorithm)) def update_once_ip1( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Compute auxiliary variables: .. math:: \bar{r}_{jn} \leftarrow\|\vec{\boldsymbol{y}}_{jn}\|_{2} Then, demixing filters are updated sequentially for :math:`n=1,\ldots,N` as follows: .. 
math:: \boldsymbol{w}_{in} &\leftarrow\left(\boldsymbol{W}_{in}^{\mathsf{H}}\boldsymbol{U}_{in}\right)^{-1} \ \boldsymbol{e}_{n}, \\ \boldsymbol{w}_{in} &\leftarrow\frac{\boldsymbol{w}_{in}} {\sqrt{\boldsymbol{w}_{in}^{\mathsf{H}}\boldsymbol{U}_{in}\boldsymbol{w}_{in}}}, \\ where .. math:: \boldsymbol{U}_{in} &= \frac{1}{J}\sum_{j} \varphi(\bar{r}_{jn})\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, \\ \varphi(\bar{r}_{jn}) &= \frac{G'_{\mathbb{R}}(\bar{r}_{jn})}{2\bar{r}_{jn}}, \\ G(\vec{\boldsymbol{y}}_{jn}) &= -\log p(\vec{\boldsymbol{y}}_{jn}), \\ G_{\mathbb{R}}(\|\vec{\boldsymbol{y}}_{jn}\|_{2}) &= G(\vec{\boldsymbol{y}}_{jn}). """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) # (n_bins, n_channels, n_channels, n_frames) norm = np.linalg.norm(Y, axis=1) denom = flooring_fn(2 * norm) weight = self.d_contrast_fn(norm) / denom # (n_sources, n_frames) GXX = weight[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(GXX, axis=-1) # (n_bins, n_sources, n_channels, n_channels) self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn) def update_once_ip2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using pairwise iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\neq n_{2}`), compute auxiliary variables: .. 
math:: \bar{r}_{jn_{1}} &\leftarrow\|\vec{\boldsymbol{y}}_{jn_{1}}\|_{2} \\ \bar{r}_{jn_{2}} &\leftarrow\|\vec{\boldsymbol{y}}_{jn_{2}}\|_{2} Then, for :math:`n=n_{1},n_{2}`, compute weighted covariance matrix as follows: .. math:: \boldsymbol{U}_{in_{1}} &= \frac{1}{J}\sum_{j} \varphi(\bar{r}_{jn_{1}})\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, \\ \boldsymbol{U}_{in_{2}} &= \frac{1}{J}\sum_{j} \varphi(\bar{r}_{jn_{2}})\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, where .. math:: \varphi(\bar{r}_{jn}) = \frac{G'_{\mathbb{R}}(\bar{r}_{jn})}{2\bar{r}_{jn}}. Using :math:`\boldsymbol{U}_{in_{1}}` and :math:`\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors. .. math:: \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i} = \lambda_{i} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i}, where .. math:: \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{1}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ), \\ \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{2}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ). After that, we standardize two eigenvectors :math:`\boldsymbol{h}_{in_{1}}` and :math:`\boldsymbol{h}_{in_{2}}`. .. 
math:: \boldsymbol{h}_{in_{1}} &\leftarrow\frac{\boldsymbol{h}_{in_{1}}} {\sqrt{\boldsymbol{h}_{in_{1}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{1}}}}, \\ \boldsymbol{h}_{in_{2}} &\leftarrow\frac{\boldsymbol{h}_{in_{2}}} {\sqrt{\boldsymbol{h}_{in_{2}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{2}}}}. Then, update :math:`\boldsymbol{w}_{in_{1}}` and :math:`\boldsymbol{w}_{in_{2}}` simultaneously. .. math:: \boldsymbol{w}_{in_{1}} &\leftarrow \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{1}} \\ \boldsymbol{w}_{in_{2}} &\leftarrow \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{2}}. At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{1}` for :math:`n_{1}\neq n_{2}`. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) n_sources = self.n_sources X, W = self.input, self.demix_filter XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) for m, n in self.pair_selector(n_sources): W_mn = W[:, (m, n), :] Y_mn = self.separate(X, demix_filter=W_mn) norm = np.linalg.norm(Y_mn, axis=1) weight = self.d_contrast_fn(norm) / flooring_fn(2 * norm) GXX_mn = weight[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U_mn = np.mean(GXX_mn, axis=-1) W[:, (m, n), :] = update_by_ip2_one_pair( W, U_mn, pair=(m, n), flooring_fn=flooring_fn, ) self.demix_filter = W def update_once_iss1( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update estimated spectrograms once using \ iterative source steering [#scheibler2020fast]_. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. 
This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. First, update auxiliary variables .. math:: \bar{r}_{jn} \leftarrow\|\vec{\boldsymbol{y}}_{jn}\|_{2}. Then, update :math:`y_{ijn}` as follows: .. math:: \boldsymbol{y}_{ij} & \leftarrow\boldsymbol{y}_{ij} - \boldsymbol{d}_{in}y_{ijn}, \\ d_{inn'} &= \begin{cases} \dfrac{\sum_{j}\dfrac{G'_{\mathbb{R}}(\bar{r}_{jn'})}{2\bar{r}_{jn'}} y_{ijn'}y_{ijn}^{*}}{\sum_{j}\dfrac{G'_{\mathbb{R}}(\bar{r}_{jn'})} {2\bar{r}_{jn'}}|y_{ijn}|^{2}} & (n'\neq n) \\ 1 - \dfrac{1}{\sqrt{\dfrac{1}{J}\sum_{j}\dfrac{G'_{\mathbb{R}}(\bar{r}_{jn'})} {2\bar{r}_{jn'}} |y_{ijn}|^{2}}} & (n'=n) \end{cases}. .. [#scheibler2020fast] R. Scheibler and N. Ono, "Fast and stable blind source separation with rank-1 updates," in *Proc. ICASSP*, 2020, pp. 236-240. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output r = np.linalg.norm(Y, axis=1) denom = flooring_fn(2 * r) varphi = self.d_contrast_fn(r) / denom # (n_sources, n_frames) self.output = update_by_iss1(Y, varphi[:, np.newaxis, :], flooring_fn=flooring_fn) def update_once_iss2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update estimated spectrograms once using \ pairwise iterative source steering [#ikeshita2022iss2]_. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. First, we compute auxiliary variables: .. math:: \bar{r}_{jn} \leftarrow\|\vec{\boldsymbol{y}}_{jn}\|_{2}, where .. 
math:: G(\vec{\boldsymbol{y}}_{jn}) &= -\log p(\vec{\boldsymbol{y}}_{jn}), \\ G_{\mathbb{R}}(\|\vec{\boldsymbol{y}}_{jn}\|_{2}) &= G(\vec{\boldsymbol{y}}_{jn}). Then, we compute :math:`\boldsymbol{G}_{in}^{(n_{1},n_{2})}` \ and :math:`\boldsymbol{f}_{in}^{(n_{1},n_{2})}` for :math:`n_{1}\neq n_{2}`: .. math:: \begin{array}{rclc} \boldsymbol{G}_{in}^{(n_{1},n_{2})} &=& {\displaystyle\frac{1}{J}\sum_{j}}\varphi(\bar{r}_{jn}) \boldsymbol{y}_{ij}^{(n_{1},n_{2})}{\boldsymbol{y}_{ij}^{(n_{1},n_{2})}}^{\mathsf{H}} &(n=1,\ldots,N), \\ \boldsymbol{f}_{in}^{(n_{1},n_{2})} &=& {\displaystyle\frac{1}{J}\sum_{j}} \varphi(\bar{r}_{jn})y_{ijn}^{*}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} &(n\neq n_{1},n_{2}), \\ \varphi(\bar{r}_{jn}) &=&\dfrac{G'_{\mathbb{R}}(\bar{r}_{jn})}{2\bar{r}_{jn}}. \end{array} Using :math:`\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and \ :math:`\boldsymbol{f}_{in}^{(n_{1},n_{2})}`, we compute .. math:: \begin{array}{rclc} \boldsymbol{p}_{in} &=& \dfrac{\boldsymbol{h}_{in}} {\sqrt{\boldsymbol{h}_{in}^{\mathsf{H}}\boldsymbol{G}_{in}^{(n_{1},n_{2})} \boldsymbol{h}_{in}}} & (n=n_{1},n_{2}), \\ \boldsymbol{q}_{in} &=& -{\boldsymbol{G}_{in}^{(n_{1},n_{2})}}^{-1}\boldsymbol{f}_{in}^{(n_{1},n_{2})} & (n\neq n_{1},n_{2}), \end{array} where :math:`\boldsymbol{h}_{in}` (:math:`n=n_{1},n_{2}`) is \ a generalized eigenvector obtained from .. math:: \boldsymbol{G}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{i} = \lambda_{i}\boldsymbol{G}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{i}. Separated signal :math:`y_{ijn}` is updated as follows: .. math:: y_{ijn} &\leftarrow\begin{cases} &\boldsymbol{p}_{in}^{\mathsf{H}}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} & (n=n_{1},n_{2}) \\ &\boldsymbol{q}_{in}^{\mathsf{H}}\boldsymbol{y}_{ij}^{(n_{1},n_{2})} + y_{ijn} & (n\neq n_{1},n_{2}) \end{cases}. .. [#ikeshita2022iss2] R. Ikeshita and T. Nakatani, "ISS2: An extension of iterative source steering algorithm for \ majorization-minimization-based independent vector analysis," *arXiv:2202.00875*, 2022. 
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output # Auxiliary variables r = np.linalg.norm(Y, axis=1) varphi = self.d_contrast_fn(r) / flooring_fn(2 * r) self.output = update_by_iss2( Y, varphi[:, np.newaxis, :], flooring_fn=flooring_fn, pair_selector=self.pair_selector, ) def update_once_ipa( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update estimated spectrograms once using \ iterative projection with adjustment [#scheibler2021independent]_. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. First, we compute auxiliary variables: .. math:: \bar{r}_{jn} \leftarrow\|\vec{\boldsymbol{y}}_{jn}\|_{2}, where .. math:: G(\vec{\boldsymbol{y}}_{jn}) &= -\log p(\vec{\boldsymbol{y}}_{jn}), \\ G_{\mathbb{R}}(\|\vec{\boldsymbol{y}}_{jn}\|_{2}) &= G(\vec{\boldsymbol{y}}_{jn}). Then, by defining, :math:`\tilde{\boldsymbol{U}}_{in'}`, :math:`\boldsymbol{A}_{in}\in\mathbb{R}^{(N-1)\times(N-1)}`, :math:`\boldsymbol{b}_{in}\in\mathbb{C}^{N-1}`, :math:`\boldsymbol{C}_{in}\in\mathbb{C}^{(N-1)\times(N-1)}`, :math:`\boldsymbol{d}_{in}\in\mathbb{C}^{N-1}`, and :math:`z_{in}\in\mathbb{R}_{\geq 0}` as follows: .. 
math:: \tilde{\boldsymbol{U}}_{in'} &= \frac{1}{J}\sum_{j}\frac{G'_{\mathbb{R}}(\bar{r}_{jn'})}{2\bar{r}_{jn'}} \boldsymbol{y}_{ij}\boldsymbol{y}_{ij}^{\mathsf{H}}, \\ \boldsymbol{A}_{in} &= \mathrm{diag}(\ldots, \boldsymbol{e}_{n}^{\mathsf{T}}\tilde{\boldsymbol{U}}_{in'}\boldsymbol{e}_{n} ,\ldots)~~(n'\neq n), \\ \boldsymbol{b}_{in} &= (\ldots, \boldsymbol{e}_{n}^{\mathsf{T}}\tilde{\boldsymbol{U}}_{in'}\boldsymbol{e}_{n'} ,\ldots)^{\mathsf{T}}~~(n'\neq n), \\ \boldsymbol{C}_{in} &= \bar{\boldsymbol{E}}_{n}^{\mathsf{T}}(\tilde{\boldsymbol{U}}_{in}^{-1})^{*} \bar{\boldsymbol{E}}_{n}, \\ \boldsymbol{d}_{in} &= \bar{\boldsymbol{E}}_{n}^{\mathsf{T}}(\tilde{\boldsymbol{U}}_{in}^{-1})^{*} \boldsymbol{e}_{n}, \\ z_{in} &= \boldsymbol{e}_{n}^{\mathsf{T}}\tilde{\boldsymbol{U}}_{in}^{-1}\boldsymbol{e}_{n} - \boldsymbol{d}_{in}^{\mathsf{H}}\boldsymbol{C}_{in}^{-1}\boldsymbol{d}_{in}, :math:`\boldsymbol{y}_{ij}` is updated via log-quadratically penelized quadratic minimization (LQPQM). .. math:: \check{\boldsymbol{q}}_{in} &\leftarrow \mathrm{LQPQM2}(\boldsymbol{H}_{in},\boldsymbol{v}_{in},z_{in}), \\ \boldsymbol{q}_{in} &\leftarrow \boldsymbol{G}_{in}^{-1}\check{\boldsymbol{q}}_{in} - \boldsymbol{A}_{in}^{-1}\boldsymbol{b}_{in}, \\ \tilde{\boldsymbol{q}}_{in} &\leftarrow \boldsymbol{e}_{n} - \bar{\boldsymbol{E}}_{n}\boldsymbol{q}_{in}, \\ \boldsymbol{p}_{in} &\leftarrow \frac{\tilde{\boldsymbol{U}}_{in}^{-1}\tilde{\boldsymbol{q}}_{in}^{*}} {\sqrt{(\tilde{\boldsymbol{q}}_{in}^{*})^{\mathsf{H}}\tilde{\boldsymbol{U}}_{in}^{-1} \tilde{\boldsymbol{q}}_{in}^{*}}}, \\ \boldsymbol{\Upsilon}_{i}^{(n)} &\leftarrow \boldsymbol{I} + \boldsymbol{e}_{n}(\boldsymbol{p}_{in} - \boldsymbol{e}_{n})^{\mathsf{H}} + \bar{\boldsymbol{E}}_{n}\boldsymbol{q}_{in}^{*}\boldsymbol{e}_{n}^{\mathsf{T}}, \\ \boldsymbol{y}_{ij} &\leftarrow \boldsymbol{\Upsilon}_{i}^{(n)}\boldsymbol{y}_{ij}, .. [#scheibler2021independent] R. 
Scheibler, "Independent vector analysis via log-quadratically penalized quadratic minimization," *IEEE Trans. Signal Processing*, vol. 69, pp. 2509-2524, 2021. """ self.lqpqm_normalization: bool self.newton_iter: int flooring_fn = choose_flooring_fn(flooring_fn, method=self) Y = self.output r = np.linalg.norm(Y, axis=1) denom = flooring_fn(2 * r) varphi = self.d_contrast_fn(r) / denom normalization = self.lqpqm_normalization max_iter = self.newton_iter self.output = update_by_ipa( Y, varphi[:, np.newaxis, :], normalization=normalization, flooring_fn=flooring_fn, max_iter=max_iter, ) def compute_loss(self) -> float: r"""Compute loss.""" if self.demix_filter is None: X, Y = self.input, self.output G = self.contrast_fn(Y) # (n_sources, n_frames) X, Y = X.transpose(1, 0, 2), Y.transpose(1, 0, 2) X_Hermite = X.transpose(0, 2, 1).conj() XX_Hermite = X @ X_Hermite # (n_bins, n_channels, n_channels) W = Y @ X_Hermite @ np.linalg.inv(XX_Hermite) logdet = self.compute_logdet(W) # (n_bins,) loss = np.sum(np.mean(G, axis=1), axis=0) - 2 * np.sum(logdet, axis=0) loss = loss.item() return loss else: return super().compute_loss() def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" if self.demix_filter is None: assert self.scale_restoration, "Set self.scale_restoration=True." 
X, Y = self.input, self.output Y_scaled = projection_back(Y, reference=X, reference_id=self.reference_id) self.output = Y_scaled else: super().apply_projection_back() def apply_minimal_distortion_principle(self) -> None: r"""Apply minimal distortion principle to estimated spectrograms.""" if self.demix_filter is None: X, Y = self.input, self.output Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id) self.output = Y_scaled else: super().apply_minimal_distortion_principle() class PDSIVA(PDSBSS): def __init__( self, mu1: float = 1, mu2: float = 1, alpha: float = None, relaxation: float = 1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None, callbacks: Optional[ Union[Callable[["PDSIVA"], None], List[Callable[["PDSIVA"], None]]] ] = None, scale_restoration: bool = True, record_loss: bool = True, reference_id: int = 0, ) -> None: if contrast_fn is not None and prox_penalty is None: raise ValueError("Set prox_penalty.") elif contrast_fn is None and prox_penalty is not None: raise ValueError("Set contrast_fn.") elif contrast_fn is None and prox_penalty is None: def _contrast_fn(y: np.ndarray) -> np.ndarray: return np.linalg.norm(y, axis=1) def _prox_penalty(x: np.ndarray, step_size: float = 1) -> np.ndarray: return prox.l21(x, step_size=step_size, axis2=1) contrast_fn = _contrast_fn prox_penalty = _prox_penalty def penalty_fn(y: np.ndarray) -> float: r"""Sum of contrast function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: Computed loss. 
""" G = contrast_fn(y) # (n_sources, n_frames) loss = np.sum(G, axis=(0, 1)) loss = loss.item() return loss super().__init__( mu1=mu1, mu2=mu2, alpha=alpha, relaxation=relaxation, penalty_fn=penalty_fn, prox_penalty=prox_penalty, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.contrast_fn = contrast_fn class ADMMIVA(ADMMBSS): def __init__( self, rho: float = 1, alpha: float = None, relaxation: float = 1, contrast_fn: Callable[[np.ndarray], np.ndarray] = None, prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None, callbacks: Optional[ Union[Callable[["ADMMIVA"], None], List[Callable[["ADMMIVA"], None]]] ] = None, scale_restoration: bool = True, record_loss: bool = True, reference_id: int = 0, ) -> None: if contrast_fn is not None and prox_penalty is None: raise ValueError("Set prox_penalty.") elif contrast_fn is None and prox_penalty is not None: raise ValueError("Set contrast_fn.") elif contrast_fn is None and prox_penalty is None: def _contrast_fn(y: np.ndarray) -> np.ndarray: return np.linalg.norm(y, axis=1) def _prox_penalty(x: np.ndarray, step_size: float = 1) -> np.ndarray: return prox.l21(x, step_size=step_size, axis2=1) contrast_fn = _contrast_fn prox_penalty = _prox_penalty def penalty_fn(y: np.ndarray) -> float: r"""Sum of contrast function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: Computed loss. """ G = contrast_fn(y) # (n_sources, n_frames) loss = np.sum(G, axis=(0, 1)) loss = loss.item() return loss super().__init__( rho=rho, alpha=alpha, relaxation=relaxation, penalty_fn=penalty_fn, prox_penalty=prox_penalty, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.contrast_fn = contrast_fn class GradLaplaceIVA(GradIVA): r"""Independent vector analysis (IVA) using the gradient descent on a Laplace distribution. 
We assume :math:`\vec{\boldsymbol{y}}_{jn}` follows a Laplace distribution. .. math:: p(\vec{\boldsymbol{y}}_{jn})\propto\exp(\|\vec{\boldsymbol{y}}_{jn}\|_{2}) Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = GradLaplaceIVA(is_holonomic=True) >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... 
+ 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = GradLaplaceIVA(is_holonomic=False) >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, step_size: float = 1e-1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["GradLaplaceIVA"], None], List[Callable[["GradLaplaceIVA"], None]]] ] = None, is_holonomic: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: numpy.ndarray of the shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def score_fn(y: np.ndarray) -> np.ndarray: r"""Score function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: numpy.ndarray of the shape is (n_sources, n_bins, n_frames). """ norm = np.linalg.norm(y, axis=1, keepdims=True) norm = self.flooring_fn(norm) return y / norm super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, is_holonomic=is_holonomic, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def update_once(self) -> None: r"""Update demixing filters once using the gradient descent. If ``is_holonomic=True``, demixing filters are updated as follows: .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\left(\frac{1}{J}\sum_{j} \ \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j})\boldsymbol{y}_{ij}^{\mathsf{H}} \ -\boldsymbol{I}\right)\boldsymbol{W}_{i}^{-\mathsf{H}}, where .. 
math:: \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j}) &= \left(\phi_{i}(\vec{\boldsymbol{y}}_{j1}),\ldots,\ \phi_{i}(\vec{\boldsymbol{y}}_{jn}),\ldots,\ \phi_{i}(\vec{\boldsymbol{y}}_{jN}))\ \right)^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \phi_{i}(\vec{\boldsymbol{y}}_{jn}) &= \frac{y_{ijn}}{\|\vec{\boldsymbol{y}}_{jn}\|_{2}}. Otherwise (``is_holonomic=False``), .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{J}\sum_{j} \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j})\boldsymbol{y}_{ij}^{\mathsf{H}}\right) \boldsymbol{W}_{i}^{-\mathsf{H}}. """ return super().update_once() def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} \ = \frac{2}{J}\sum_{j,n}\|\vec{\boldsymbol{y}}_{jn}\|_{2} \ - 2\sum_{i}\log|\det\boldsymbol{W}_{i}|. Returns: Computed loss. """ return super().compute_loss() class GradGaussIVA(GradIVA): r"""Independent vector analysis (IVA) using the gradient descent on \ a time-varying Gaussian distribution. We assume :math:`\vec{\boldsymbol{y}}_{jn}` follows a time-varying Gaussian distribution. .. math:: p(\vec{\boldsymbol{y}}_{jn}) \propto\frac{1}{\alpha_{jn}^{I}} \exp\left(\frac{\|\vec{\boldsymbol{y}}_{jn}\|_{2}^{2}}{\alpha_{jn}}\right). Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. 
scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = GradGaussIVA(is_holonomic=True) >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = GradGaussIVA(is_holonomic=False) >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, step_size: float = 1e-1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["GradGaussIVA"], None], List[Callable[["GradGaussIVA"], None]]] ] = None, is_holonomic: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r""" Args: y (numpy.ndarray): Separated signal with shape of (n_sources, n_bins, n_frames). Returns: numpy.ndarray of computed contrast function. 
The shape is (n_sources, n_frames). """ n_bins = self.n_bins alpha = self.variance norm = np.linalg.norm(y, axis=1) return n_bins * np.log(alpha) + (norm**2) / alpha def score_fn(y: np.ndarray) -> np.ndarray: r""" Args: y (numpy.ndarray): Norm of separated signal. The shape is (n_sources, n_bins, n_frames). Returns: numpy.ndarray of computed contrast function. The shape is (n_sources, n_frames). """ alpha = self.variance return y / alpha[:, np.newaxis, :] super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, is_holonomic=is_holonomic, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. We also set variance of Gaussian distribution. Args: kwargs: Keyword arguments to set as attributes of IVA. """ super()._reset(**kwargs) n_sources, n_frames = self.n_sources, self.n_frames self.variance = np.ones((n_sources, n_frames)) def update_once(self) -> None: r"""Update variance and demixing filters and once.""" self.update_source_model() super().update_once() def update_source_model(self) -> None: r"""Update variance of Gaussian distribution.""" X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) self.variance = np.mean(np.abs(Y) ** 2, axis=1) class NaturalGradLaplaceIVA(NaturalGradIVA): r"""Independent vector analysis (IVA) using the natural gradient descent \ on a Laplace distribution. We assume :math:`\vec{\boldsymbol{y}}_{jn}` follows a Laplace distribution. .. math:: p(\vec{\boldsymbol{y}}_{jn}) \propto\frac{1}{\alpha_{jn}^{I}} \exp\left(\frac{\|\vec{\boldsymbol{y}}_{jn}\|_{2}}{\alpha_{jn}}\right) Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. 
If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = NaturalGradLaplaceIVA(is_holonomic=True) >>> spectrogram_est = iva(spectrogram_mix, n_iter=500) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... 
+ 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = NaturalGradLaplaceIVA(is_holonomic=False) >>> spectrogram_est = iva(spectrogram_mix, n_iter=500) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, step_size: float = 1e-1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[ Callable[["NaturalGradLaplaceIVA"], None], List[Callable[["NaturalGradLaplaceIVA"], None]], ] ] = None, is_holonomic: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: numpy.ndarray of the shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def score_fn(y: np.ndarray) -> np.ndarray: r"""Score function. Args: y (numpy.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: numpy.ndarray of the shape is (n_sources, n_bins, n_frames). """ norm = np.linalg.norm(y, axis=1, keepdims=True) norm = self.flooring_fn(norm) return y / norm super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, is_holonomic=is_holonomic, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def update_once(self) -> None: r"""Update demixing filters once using the natural gradient descent. If ``is_holonomic=True``, demixing filters are updated as follows: .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\left(\frac{1}{J}\sum_{j} \ \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j})\boldsymbol{y}_{ij}^{\mathsf{H}} \ -\boldsymbol{I}\right)\boldsymbol{W}_{i}, where .. 
math:: \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j}) &= \left(\phi_{i}(\vec{\boldsymbol{y}}_{j1}),\ldots,\ \phi_{i}(\vec{\boldsymbol{y}}_{jn}),\ldots,\ \phi_{i}(\vec{\boldsymbol{y}}_{jN}))\ \right)^{\mathsf{T}}\in\mathbb{C}^{N}, \\ \phi_{i}(\vec{\boldsymbol{y}}_{jn}) &= \frac{y_{ijn}}{\|\vec{\boldsymbol{y}}_{jn}\|_{2}}. Otherwise (``is_holonomic=False``), .. math:: \boldsymbol{W}_{i} \leftarrow\boldsymbol{W}_{i} - \eta\cdot\mathrm{offdiag}\left(\frac{1}{J}\sum_{j} \boldsymbol{\phi}_{i}(\vec{\boldsymbol{Y}}_{j})\boldsymbol{y}_{ij}^{\mathsf{H}}\right) \boldsymbol{W}_{i}. """ return super().update_once() def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is given as follows: .. math:: \mathcal{L} \ = \frac{2}{J}\sum_{j,n}\|\vec{\boldsymbol{y}}_{jn}\|_{2} \ - 2\sum_{i}\log|\det\boldsymbol{W}_{i}|. Returns: Computed loss. """ return super().compute_loss() class NaturalGradGaussIVA(NaturalGradIVA): r"""Independent vector analysis (IVA) using the natural gradient descent \ on a time-varying Gaussian distribution. We assume :math:`\vec{\boldsymbol{y}}_{jn}` follows a time-varying Gaussian distribution. .. math:: p(\vec{\boldsymbol{y}}_{jn}) \propto\frac{1}{\alpha_{jn}^{I}} \exp\left(\frac{\|\vec{\boldsymbol{y}}_{jn}\|_{2}^{2}}{\alpha_{jn}}\right). Args: step_size (float): A step size of the gradient descent. Default: ``1e-1``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. is_holonomic (bool): If ``is_holonomic=True``, Holonomic-type update is used. Otherwise, Nonholonomic-type update is used. Default: ``False``. 
scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the gradient descent if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters using Holonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = NaturalGradGaussIVA(is_holonomic=True) >>> spectrogram_est = iva(spectrogram_mix, n_iter=500) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters using Nonholonomic-type update: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = NaturalGradGaussIVA(is_holonomic=False) >>> spectrogram_est = iva(spectrogram_mix, n_iter=500) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, step_size: float = 1e-1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[ Callable[["NaturalGradGaussIVA"], None], List[Callable[["NaturalGradGaussIVA"], None]], ] ] = None, is_holonomic: bool = True, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, ) -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r""" Args: y (numpy.ndarray): Separated signal with shape of (n_sources, n_bins, n_frames). Returns: numpy.ndarray of computed contrast function. 
The shape is (n_sources, n_frames). """ n_bins = self.n_bins alpha = self.variance norm = np.linalg.norm(y, axis=1) return n_bins * np.log(alpha) + (norm**2) / alpha def score_fn(y: np.ndarray) -> np.ndarray: r"""Score function. Args: y (numpy.ndarray): Separated signal. The shape is (n_sources, n_bins, n_frames). Returns: numpy.ndarray of computed score function, i.e. ``y`` divided by the variance. The shape is (n_sources, n_bins, n_frames). """ alpha = self.variance return y / alpha[:, np.newaxis, :] super().__init__( step_size=step_size, contrast_fn=contrast_fn, score_fn=score_fn, flooring_fn=flooring_fn, callbacks=callbacks, is_holonomic=is_holonomic, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. We also set variance of Gaussian distribution. Args: kwargs: Keyword arguments to set as attributes of IVA. """ super()._reset(**kwargs) n_sources, n_frames = self.n_sources, self.n_frames self.variance = np.ones((n_sources, n_frames)) def update_once(self) -> None: r"""Update variance and demixing filters once.""" self.update_source_model() super().update_once() def update_source_model(self) -> None: r"""Update variance of Gaussian distribution as the mean power of separated signal over frequency bins.""" X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) self.variance = np.mean(np.abs(Y) ** 2, axis=1) class AuxLaplaceIVA(AuxIVA): r"""Auxiliary-function-based independent vector analysis (IVA) \ on a Laplace distribution. We assume :math:`\vec{\boldsymbol{y}}_{jn}` follows a Laplace distribution. .. math:: p(\vec{\boldsymbol{y}}_{jn})\propto\exp(-\|\vec{\boldsymbol{y}}_{jn}\|_{2}) Args: spatial_algorithm (str): Algorithm for demixing filter updates. Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, or ``ISS2``. Default: ``IP``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input.
If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. pair_selector (callable, optional): Selector to choose updating pair in ``IP2`` and ``ISS2``. If ``None`` is given, ``sequential_pair_selector`` is used. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the demixing filter update if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters by IP: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxLaplaceIVA(spatial_algorithm="IP") >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by IP2: .. code-block:: python >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxLaplaceIVA( ... spatial_algorithm="IP2", ... pair_selector=sequential_pair_selector, ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS: ..
code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxLaplaceIVA(spatial_algorithm="ISS") >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS2: .. code-block:: python >>> import functools >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxLaplaceIVA( ... spatial_algorithm="ISS2", ... pair_selector=functools.partial(sequential_pair_selector, step=2), ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) """ def __init__( self, spatial_algorithm: str = "IP", flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, callbacks: Optional[ Union[Callable[["AuxLaplaceIVA"], None], List[Callable[["AuxLaplaceIVA"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, **kwargs, ) -> None: def contrast_fn(y) -> np.ndarray: r"""Contrast function. Args: y (numpy.ndarray): Separated signal. The shape is (n_sources, n_bins, n_frames). Returns: numpy.ndarray of shape (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y) -> np.ndarray: r"""Derivative of contrast function. Args: y (numpy.ndarray): Norm of separated signal. The shape is (n_sources, n_frames). Returns: numpy.ndarray filled with ``2`` and of the same shape as ``y``.
""" return 2 * np.ones_like(y) super().__init__( spatial_algorithm=spatial_algorithm, contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, flooring_fn=flooring_fn, pair_selector=pair_selector, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, **kwargs, ) class AuxGaussIVA(AuxIVA): r"""Auxiliary-function-based independent vector analysis (IVA) \ on a time-varying Gaussian distribution [#ono2012auxiliary]_. We assume :math:`\vec{\boldsymbol{y}}_{jn}` follows a time-varying Gaussian distribution. .. math:: p(\vec{\boldsymbol{y}}_{jn}) \propto\frac{1}{\alpha_{jn}^{I}} \exp\left(\frac{\|\vec{\boldsymbol{y}}_{jn}\|_{2}^{2}}{\alpha_{jn}}\right). Args: spatial_algorithm (str): Algorithm for demixing filter updates. Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, or ``ISS2``. Default: ``IP``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. pair_selector (callable, optional): Selector to choose updaing pair in ``IP2`` and ``ISS2``. If ``None`` is given, ``sequential_pair_selector`` is used. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` or ``minimal_distortion_principle``. Default: ``True``. record_loss (bool): Record the loss at each iteration of the demixing filter update if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back and minimal distortion principle. Default: ``0``. Examples: Update demixing filters by IP: .. 
code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxGaussIVA(spatial_algorithm="IP") >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by IP2: .. code-block:: python >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxGaussIVA( ... spatial_algorithm="IP2", ... pair_selector=sequential_pair_selector, ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS: .. code-block:: python >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxGaussIVA(spatial_algorithm="ISS") >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) Update demixing filters by ISS2: .. code-block:: python >>> import functools >>> from ssspy.utils.select_pair import sequential_pair_selector >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \ ... + 1j * np.random.randn(n_channels, n_bins, n_frames) >>> iva = AuxGaussIVA( ... spatial_algorithm="ISS2", ... pair_selector=functools.partial(sequential_pair_selector, step=2), ... ) >>> spectrogram_est = iva(spectrogram_mix, n_iter=100) >>> print(spectrogram_mix.shape, spectrogram_est.shape) (2, 2049, 128), (2, 2049, 128) .. [#ono2012auxiliary] N. 
Ono, "Auxiliary-function-based independent vector analysis with power of \ vector-norm type weighting functions," in *Proc. APSIPA ASC*, 2012, pp. 1-4. """ def __init__( self, spatial_algorithm: str = "IP", flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, callbacks: Optional[ Union[Callable[["AuxGaussIVA"], None], List[Callable[["AuxGaussIVA"], None]]] ] = None, scale_restoration: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, **kwargs, ) -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (numpy.ndarray): Separated signal with shape of (n_sources, n_bins, n_frames). Returns: numpy.ndarray: Computed contrast function. The shape is (n_sources, n_frames). """ n_bins = self.n_bins alpha = self.variance norm = np.linalg.norm(y, axis=1) return n_bins * np.log(alpha) + (norm**2) / alpha def d_contrast_fn(y: np.ndarray, variance: Optional[np.ndarray] = None) -> np.ndarray: r"""Derivative of contrast function. Args: y (numpy.ndarray): Norm of separated signal. The shape is (n_sources, n_frames). variance (numpy.ndarray, optional): Variance of Gaussian distribution with the same shape as ``y``. If ``None`` is given, ``self.variance`` is used. Default: ``None``. Returns: numpy.ndarray of computed derivative of contrast function, i.e. ``2 * y / variance``. The shape is the same as ``y``. """ if variance is None: alpha = self.variance else: alpha = variance return 2 * y / alpha super().__init__( spatial_algorithm=spatial_algorithm, contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, flooring_fn=flooring_fn, pair_selector=pair_selector, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, **kwargs, ) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. We also set variance of Gaussian distribution. Args: kwargs: Keyword arguments to set as attributes of IVA.
""" super()._reset(**kwargs) n_sources, n_frames = self.n_sources, self.n_frames self.variance = np.ones((n_sources, n_frames)) def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update variance and demixing filters and once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ self.update_source_model() super().update_once(flooring_fn=flooring_fn) def update_once_ip2( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update demixing filters once using pairwise iterative projection. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\neq n_{2}`), compute auxiliary variables: .. math:: \bar{r}_{jn_{1}} &\leftarrow\|\vec{\boldsymbol{y}}_{jn_{1}}\|_{2} \\ \bar{r}_{jn_{2}} &\leftarrow\|\vec{\boldsymbol{y}}_{jn_{2}}\|_{2} Then, for :math:`n=n_{1},n_{2}`, compute weighted covariance matrix as follows: .. math:: \boldsymbol{U}_{in_{1}} &= \frac{1}{J}\sum_{j} \varphi(\bar{r}_{jn_{1}})\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, \\ \boldsymbol{U}_{in_{2}} &= \frac{1}{J}\sum_{j} \varphi(\bar{r}_{jn_{2}})\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}, where .. math:: \varphi(\bar{r}_{jn}) = \frac{G'_{\mathbb{R}}(\bar{r}_{jn})}{2\bar{r}_{jn}}. 
Using :math:`\boldsymbol{U}_{in_{1}}` and :math:`\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors. .. math:: \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i} = \lambda_{i} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right)\boldsymbol{h}_{i}, where .. math:: \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{1}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ), \\ \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})} &= (\boldsymbol{W}_{i}\boldsymbol{U}_{in_{2}})^{-1} ( \begin{array}{cc} \boldsymbol{e}_{n_{1}} & \boldsymbol{e}_{n_{2}} \end{array} ). After that, we standardize two eigenvectors :math:`\boldsymbol{h}_{in_{1}}` and :math:`\boldsymbol{h}_{in_{2}}`. .. math:: \boldsymbol{h}_{in_{1}} &\leftarrow\frac{\boldsymbol{h}_{in_{1}}} {\sqrt{\boldsymbol{h}_{in_{1}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{1}} \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{1}}}}, \\ \boldsymbol{h}_{in_{2}} &\leftarrow\frac{\boldsymbol{h}_{in_{2}}} {\sqrt{\boldsymbol{h}_{in_{2}}^{\mathsf{H}} \left({\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\mathsf{H}}\boldsymbol{U}_{in_{2}} \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\right) \boldsymbol{h}_{in_{2}}}}. Then, update :math:`\boldsymbol{w}_{in_{1}}` and :math:`\boldsymbol{w}_{in_{2}}` simultaneously. .. math:: \boldsymbol{w}_{in_{1}} &\leftarrow \boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{1}} \\ \boldsymbol{w}_{in_{2}} &\leftarrow \boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\boldsymbol{h}_{in_{2}}. At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{2}` for :math:`n_{1}\neq n_{2}`.
""" flooring_fn = choose_flooring_fn(flooring_fn, method=self) n_sources = self.n_sources X, W = self.input, self.demix_filter R = self.variance XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) for m, n in self.pair_selector(n_sources): W_mn = W[:, (m, n), :] Y_mn = self.separate(X, demix_filter=W_mn) R_mn = R[(m, n), :] norm = np.linalg.norm(Y_mn, axis=1) weight_mn = self.d_contrast_fn(norm, variance=R_mn) / flooring_fn(2 * norm) GXX_mn = weight_mn[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U_mn = np.mean(GXX_mn, axis=-1) W[:, (m, n), :] = update_by_ip2_one_pair( W, U_mn, pair=(m, n), flooring_fn=flooring_fn, ) self.demix_filter = W def update_source_model(self) -> None: r"""Update variance of Gaussian distribution.""" if self.demix_filter is None: Y = self.output else: X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) self.variance = np.mean(np.abs(Y) ** 2, axis=1) ================================================ FILE: ssspy/bss/mnmf.py ================================================ import functools from typing import Callable, Iterable, List, Optional, Tuple, Union import numpy as np from ..linalg._solve import solve from ..linalg.mean import gmeanmh from ..special.flooring import identity, max_flooring from ..special.psd import to_psd from ..utils.flooring import choose_flooring_fn from ..utils.select_pair import sequential_pair_selector from ._update_spatial_model import update_by_ip1, update_by_ip2 from .base import IterativeMethodBase __all__ = ["GaussMNMF", "FastGaussMNMF"] diagonalizer_algorithms = ["IP", "IP1", "IP2"] EPS = 1e-10 class MNMFBase(IterativeMethodBase): r"""Base class of multichannel nonnegative matrix factorization (MNMF). Args: n_basis (int): Number of NMF bases. n_sources (int, optional): Number of sources to be separated. If ``None`` is given, ``n_sources`` is determined by number of channels in input spectrogram. 
Default: ``None``. partitioning (bool): Whether to use partitioning function. Default: ``False``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. normalization (bool): Normalization flag stored as ``self.normalization`` and used by subclasses. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel in multichannel Wiener filter. Default: ``0``. rng (numpy.random.Generator, optional): Random number generator. This is mainly used to randomly initialize NMF parameters. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ def __init__( self, n_basis: int, n_sources: Optional[int] = None, partitioning: bool = False, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["MNMFBase"], None], List[Callable[["MNMFBase"], None]]] ] = None, normalization: bool = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__(callbacks=callbacks, record_loss=record_loss) self.n_basis = n_basis self.n_sources = n_sources self.partitioning = partitioning if flooring_fn is None: self.flooring_fn = identity else: self.flooring_fn = flooring_fn self.normalization = normalization self.input = None self.reference_id = reference_id if rng is None: rng = np.random.default_rng() self.rng = rng def __call__( self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs ) -> np.ndarray: r"""Separate a frequency-domain multichannel signal.
Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): The number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) super().__call__(n_iter=n_iter, initial_call=initial_call) self.output = self.separate(self.input) return self.output def __repr__(self) -> str: s = "MNMF(" s += "n_basis={n_basis}" if self.n_sources is not None: s += ", n_sources={n_sources}" if hasattr(self, "n_channels"): s += ", n_channels={n_channels}" s += ", partitioning={partitioning}" s += ", normalization={normalization}" s += ", record_loss={record_loss}" s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of MNMF. """ assert self.input is not None, "Specify data!" for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input n_sources = self.n_sources n_channels, n_bins, n_frames = X.shape if n_sources is None: n_sources = n_channels self.n_sources, self.n_channels = n_sources, n_channels self.n_bins, self.n_frames = n_bins, n_frames self._init_instant_covariance() self._init_nmf(rng=self.rng) self.output = self.separate(X) def _init_instant_covariance( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Initialize instantaneous covariance of input. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. 
If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) X = self.input XX = X[:, np.newaxis] * X[np.newaxis, :].conj() XX = XX.transpose(2, 3, 0, 1) self.instant_covariance = to_psd(XX, flooring_fn=flooring_fn) def _init_nmf( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", rng: Optional[np.random.Generator] = None, ) -> None: r"""Initialize NMF. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. rng (numpy.random.Generator, optional): Random number generator. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ n_basis = self.n_basis n_sources = self.n_sources n_bins, n_frames = self.n_bins, self.n_frames flooring_fn = choose_flooring_fn(flooring_fn, method=self) if rng is None: rng = np.random.default_rng() if self.partitioning: if not hasattr(self, "basis"): T = rng.random((n_bins, n_basis)) T = flooring_fn(T) else: # To avoid overwriting. T = self.basis.copy() if not hasattr(self, "activation"): V = rng.random((n_basis, n_frames)) V = flooring_fn(V) else: # To avoid overwriting. V = self.activation.copy() if not hasattr(self, "latent"): Z = rng.random((n_sources, n_basis)) Z = Z / Z.sum(axis=0) Z = flooring_fn(Z) else: # To avoid overwriting. Z = self.latent.copy() self.basis, self.activation = T, V self.latent = Z else: if not hasattr(self, "basis"): T = rng.random((n_sources, n_bins, n_basis)) T = flooring_fn(T) else: # To avoid overwriting. T = self.basis.copy() if not hasattr(self, "activation"): V = rng.random((n_sources, n_basis, n_frames)) V = flooring_fn(V) else: # To avoid overwriting. 
V = self.activation.copy() self.basis, self.activation = T, V def separate(self, input: np.ndarray) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Subclasses must override this method. Raises: NotImplementedError: Always raised by this base class. """ raise NotImplementedError("Implement 'separate' method.") def reconstruct_nmf( self, basis: np.ndarray, activation: np.ndarray, latent: Optional[np.ndarray] = None, ) -> np.ndarray: r"""Reconstruct single-channel NMF. Args: basis (numpy.ndarray): Basis matrix. The shape is (n_bins, n_basis) if latent is given. Otherwise, (n_sources, n_bins, n_basis). activation (numpy.ndarray): Activation matrix. The shape is (n_basis, n_frames) if latent is given. Otherwise, (n_sources, n_basis, n_frames). latent (numpy.ndarray, optional): Latent variable that determines number of bases per source. The shape is (n_sources, n_basis). Returns: numpy.ndarray of reconstructed single-channel NMF. The shape is (n_sources, n_bins, n_frames). """ if latent is None: T, V = basis, activation Lamb = T @ V else: Z = latent T, V = basis, activation TV = T[:, :, np.newaxis] * V[np.newaxis, :, :] Lamb = np.sum(Z[:, np.newaxis, :, np.newaxis] * TV[np.newaxis, :, :, :], axis=2) return Lamb class MNMF(MNMFBase): r"""Multichannel nonnegative matrix factorization (MNMF) with spatial properties stored in ``self.spatial``. Args: Same as :class:`MNMFBase`. """ def __init__( self, n_basis: int, n_sources: Optional[int] = None, partitioning: bool = False, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[Union[Callable[["MNMF"], None], List[Callable[["MNMF"], None]]]] = None, normalization: bool = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__( n_basis, n_sources=n_sources, partitioning=partitioning, flooring_fn=flooring_fn, callbacks=callbacks, normalization=normalization, record_loss=record_loss, reference_id=reference_id, rng=rng, ) def _init_nmf(self, rng: Optional[np.random.Generator] = None) -> None: r"""Initialize NMF. Args: rng (numpy.random.Generator, optional): Random number generator. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``.
""" if rng is None: rng = np.random.default_rng() super()._init_nmf(rng=rng) n_sources, n_channels = self.n_sources, self.n_channels n_bins = self.n_bins if not hasattr(self, "spatial"): H = np.eye(n_channels, dtype=self.input.dtype) trace = np.trace(H, axis1=-2, axis2=-1) H = H / np.real(trace) H = np.tile(H, reps=(n_sources, n_bins, 1, 1)) else: # To avoid overwriting. H = self.spatial.copy() self.spatial = H def reconstruct_mnmf( self, basis: np.ndarray, activation: np.ndarray, spatial: np.ndarray, latent: Optional[np.ndarray] = None, ) -> np.ndarray: r"""Reconstruct multichannel NMF. Args: basis (numpy.ndarray): Basis matrix with shape of (n_bins, n_basis). activation (numpy.ndarray): Activation matrix with shape of (n_basis, n_frames). spatial (numpy.ndarray): Spatial property with shape of (n_sources, n_bins, n_channels, n_channels). latent (numpy.ndarray, optional): Latent variables with shape of (n_sources, n_basis). Returns: numpy.ndarray of reconstructed multichannel NMF. The shape is (n_bins, n_frames, n_channels, n_channels). """ T, V = basis, activation H = spatial if latent is None: Lamb = self.reconstruct_nmf(T, V) else: Lamb = self.reconstruct_nmf(T, V, latent=latent) R_n = Lamb[:, :, :, np.newaxis, np.newaxis] * H[:, :, np.newaxis, :, :] R = np.sum(R_n, axis=0) return R def normalize(self, axis1=-2, axis2=-1) -> None: r"""Ensure unit trace of spatial property of MNMF.""" H = self.spatial n_dims = H.ndim axis1 = n_dims + axis1 if axis1 < 0 else axis1 axis2 = n_dims + axis2 if axis2 < 0 else axis2 assert axis1 == 2 and axis2 == 3 trace = np.trace(H, axis1=axis1, axis2=axis2) trace = np.real(trace) H = H / trace[..., np.newaxis, np.newaxis] if self.partitioning: # When self.partitioning=True, # normalization may change value of cost function pass else: T = self.basis T = trace[:, :, np.newaxis] * T self.basis = T self.spatial = H class FastMNMFBase(MNMFBase): r"""Base class of fast multichannel nonnegative matrix factorization (Fast MNMF). 
Args: n_basis (int): Number of NMF bases. n_sources (int, optional): Number of sources to be separated. If ``None`` is given, ``n_sources`` is determined by number of channels in input spectrogram. Default: ``None``. partitioning (bool): Whether to use partitioning function. Default: ``False``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. normalization (bool or str): Normalization of diagonalizers and diagonal elements of spatial covariance matrices. Only power-based normalization (``"power"``) is supported, and ``True`` is treated as ``"power"``. Default: ``True``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel in multichannel Wiener filter. Default: ``0``. rng (numpy.random.Generator, optional): Random number generator. This is mainly used to randomly initialize NMF parameters and diagonal elements of spatial covariance matrices. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``.
""" def __init__( self, n_basis: int, n_sources: Optional[int] = None, partitioning: bool = False, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["FastMNMFBase"], None], List[Callable[["FastMNMFBase"], None]]] ] = None, normalization: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__( n_basis, n_sources=n_sources, partitioning=partitioning, flooring_fn=flooring_fn, callbacks=callbacks, normalization=normalization, record_loss=record_loss, reference_id=reference_id, rng=rng, ) def __repr__(self) -> str: s = "FastMNMF(" s += "n_basis={n_basis}" if self.n_sources is not None: s += ", n_sources={n_sources}" if hasattr(self, "n_channels"): s += ", n_channels={n_channels}" s += ", partitioning={partitioning}" s += ", normalization={normalization}" s += ", record_loss={record_loss}" s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", **kwargs, ) -> None: r"""Reset attributes by given keyword arguments. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. kwargs: Keyword arguments to set as attributes of MNMF. """ assert self.input is not None, "Specify data!" 
flooring_fn = choose_flooring_fn(flooring_fn, method=self) for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input n_sources = self.n_sources n_channels, n_bins, n_frames = X.shape if n_sources is None: n_sources = n_channels self.n_sources, self.n_channels = n_sources, n_channels self.n_bins, self.n_frames = n_bins, n_frames self._init_instant_covariance(flooring_fn=flooring_fn) self._init_nmf(flooring_fn=flooring_fn, rng=self.rng) self._init_diagonalizer(rng=self.rng) self._init_spatial(flooring_fn=flooring_fn, rng=self.rng) self.output = self.separate(X) def _init_diagonalizer(self, rng: Optional[np.random.Generator] = None) -> None: """Initialize diagonalizer. Args: rng (numpy.random.Generator, optional): Random number generator. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ n_channels = self.n_channels n_bins = self.n_bins if rng is None: rng = np.random.default_rng() if not hasattr(self, "diagonalizer"): Q = np.eye(n_channels, dtype=np.complex128) Q = np.tile(Q, reps=(n_bins, 1, 1)) else: # To avoid overwriting. Q = self.diagonalizer.copy() self.diagonalizer = Q def _init_spatial( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", rng: Optional[np.random.Generator] = None, ) -> None: """Initialize diagonal elements of spatial covariance matrices. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. rng (numpy.random.Generator, optional): Random number generator. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. 
""" n_sources, n_channels = self.n_sources, self.n_channels n_bins = self.n_bins flooring_fn = choose_flooring_fn(flooring_fn, method=self) if rng is None: rng = np.random.default_rng() if not hasattr(self, "spatial"): D = rng.random((n_bins, n_sources, n_channels)) D = flooring_fn(D) else: D = self.spatial self.spatial = D def normalize( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Normalize diagonalizers and diagonal elements of spatial covariance matrices. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ normalization = self.normalization flooring_fn = choose_flooring_fn(flooring_fn, method=self) assert normalization, "Set normalization." if type(normalization) is bool: # when normalization is True normalization = "power" if normalization == "power": self.normalize_by_power(flooring_fn=flooring_fn) else: raise NotImplementedError("Normalization {} is not implemented.".format(normalization)) def normalize_by_power( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Normalize diagonalizers and diagonal elements of spatial covariance matrices by power. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. Diagonalizers are normalized by .. math:: \boldsymbol{q}_{im} \leftarrow\frac{\boldsymbol{q}_{im}}{\psi_{im}}, where .. 
math:: \psi_{im} = \sqrt{\frac{1}{IJ}\sum_{i,j}|\boldsymbol{q}_{im}^{\mathsf{H}} \boldsymbol{x}_{ij}|^{2}}. For diagonal elements of spatial covariance matrices, .. math:: d_{inm} \leftarrow\frac{d_{inm}}{\psi_{im}^{2}}. """ X = self.input Q, D = self.diagonalizer, self.spatial flooring_fn = choose_flooring_fn(flooring_fn, method=self) QX = Q @ X.transpose(1, 0, 2) QX2 = np.mean(np.abs(QX) ** 2, axis=(0, 2)) psi = np.sqrt(QX2) psi = flooring_fn(psi) Q = Q / psi[np.newaxis, :, np.newaxis] D = D / (psi**2) self.diagonalizer, self.spatial = Q, D class GaussMNMF(MNMF): def __init__( self, n_basis: int, n_sources: Optional[int] = None, partitioning: bool = False, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), callbacks: Optional[ Union[Callable[["GaussMNMF"], None], List[Callable[["GaussMNMF"], None]]] ] = None, normalization: Union[bool, str] = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__( n_basis, n_sources=n_sources, partitioning=partitioning, flooring_fn=flooring_fn, callbacks=callbacks, normalization=normalization, record_loss=record_loss, reference_id=reference_id, rng=rng, ) def __repr__(self) -> str: s = "GaussMNMF(" s += "n_basis={n_basis}" if self.n_sources is not None: s += ", n_sources={n_sources}" if hasattr(self, "n_channels"): s += ", n_channels={n_channels}" s += ", partitioning={partitioning}" s += ", normalization={normalization}" s += ", record_loss={record_loss}" s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def separate(self, input: np.ndarray) -> np.ndarray: """Separate ``input`` using multichannel Wiener filter. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_sources, n_bins, n_frames). 
""" n_sources = self.n_sources reference_id = self.reference_id X = input T, V = self.basis, self.activation H = self.spatial if self.partitioning: Lamb = self.reconstruct_nmf(T, V, latent=self.latent) else: Lamb = self.reconstruct_nmf(T, V) R_n = Lamb[:, :, :, np.newaxis, np.newaxis] * H[:, :, np.newaxis, :, :] R = np.sum(R_n, axis=0) R = to_psd(R, flooring_fn=self.flooring_fn) R = np.tile(R, reps=(n_sources, 1, 1, 1, 1)) W_Hermite = solve(R, R_n) W = W_Hermite.transpose(0, 1, 2, 4, 3).conj() W_ref = W[:, :, :, reference_id, :] W_ref = W_ref.transpose(0, 3, 1, 2) Y = np.sum(W_ref * X, axis=1) return Y def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. Returns: Computed loss. """ XX = self.instant_covariance T, V = self.basis, self.activation H = self.spatial if self.partitioning: R = self.reconstruct_mnmf(T, V, H, latent=self.latent) else: R = self.reconstruct_mnmf(T, V, H) R = to_psd(R, flooring_fn=self.flooring_fn) XXR_inv = solve(R, XX) # Hermitian transpose of XX @ np.linalg.inv(R) trace = np.trace(XXR_inv, axis1=-2, axis2=-1) trace = np.real(trace) logdet = self.compute_logdet(R) loss = np.mean(trace + logdet, axis=-1) loss = loss.sum(axis=0) loss = loss.item() return loss def compute_logdet(self, reconstructed: np.ndarray) -> np.ndarray: r"""Compute log-determinant. Args: reconstructed: Reconstructed MNMF with shape of (\*, n_channels, n_channels). Returns: numpy.ndarray of computed log-determinant values. The shape is (\*). """ _, logdet = np.linalg.slogdet(reconstructed) return logdet def update_once( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update MNMF parameters once. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. 
If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ flooring_fn = choose_flooring_fn(flooring_fn, method=self) self.update_basis(flooring_fn=flooring_fn) self.update_activation(flooring_fn=flooring_fn) self.update_spatial(flooring_fn=flooring_fn) if self.normalization: # ensure unit trace of spatial property # before updates of latent variables in MNMF self.normalize(axis1=-2, axis2=-1) if self.partitioning: self.update_latent(flooring_fn=flooring_fn) def update_basis( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF bases by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ n_sources = self.n_sources n_frames = self.n_frames na = np.newaxis flooring_fn = choose_flooring_fn(flooring_fn, method=self) def _compute_traces( target: np.ndarray, reconstructed: np.ndarray, spatial: np.ndarray ) -> np.ndarray: RXX = solve(reconstructed, target) R = np.tile(reconstructed, reps=(n_sources, 1, 1, 1, 1)) H = np.tile(spatial[:, :, na, :, :], reps=(1, 1, n_frames, 1, 1)) RH = solve(R, H) trace_RXXRH = np.trace(RXX @ RH, axis1=-2, axis2=-1) trace_RXXRH = np.real(trace_RXXRH) trace_RH = np.trace(RH, axis1=-2, axis2=-1) trace_RH = np.real(trace_RH) return trace_RXXRH, trace_RH XX = self.instant_covariance T, V = self.basis, self.activation H = self.spatial if self.partitioning: Z = self.latent R = self.reconstruct_mnmf(T, V, H, latent=Z) R = to_psd(R, flooring_fn=flooring_fn) trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H) VRXXRH = np.sum(V[na, na, :] * trace_RXXRH[:, :, na], axis=-1) VRH = np.sum(V[na, na, :] * trace_RH[:, :, na], axis=-1) num = np.sum(Z[:, na, :] * VRXXRH, axis=0) 
denom = np.sum(Z[:, na, :] * VRH, axis=0) else: R = self.reconstruct_mnmf(T, V, H) R = to_psd(R, flooring_fn=flooring_fn) trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H) num = np.sum(V[:, na, :, :] * trace_RXXRH[:, :, na, :], axis=-1) denom = np.sum(V[:, na, :, :] * trace_RH[:, :, na, :], axis=-1) T = T * np.sqrt(num / denom) T = flooring_fn(T) self.basis = T def update_activation( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update NMF activations by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ n_sources = self.n_sources n_frames = self.n_frames na = np.newaxis flooring_fn = choose_flooring_fn(flooring_fn, method=self) def _compute_traces( target: np.ndarray, reconstructed: np.ndarray, spatial: np.ndarray ) -> np.ndarray: RXX = solve(reconstructed, target) R = np.tile(reconstructed, reps=(n_sources, 1, 1, 1, 1)) H = np.tile(spatial[:, :, na, :, :], reps=(1, 1, n_frames, 1, 1)) RH = solve(R, H) trace_RXXRH = np.trace(RXX @ RH, axis1=-2, axis2=-1) trace_RXXRH = np.real(trace_RXXRH) trace_RH = np.trace(RH, axis1=-2, axis2=-1) trace_RH = np.real(trace_RH) return trace_RXXRH, trace_RH XX = self.instant_covariance T, V = self.basis, self.activation H = self.spatial if self.partitioning: Z = self.latent R = self.reconstruct_mnmf(T, V, H, latent=Z) R = to_psd(R, flooring_fn=flooring_fn) trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H) TRXXRH = np.sum(T[na, :, :, na] * trace_RXXRH[:, :, na, :], axis=1) TRH = np.sum(T[na, :, :, na] * trace_RH[:, :, na, :], axis=1) num = np.sum(Z[:, :, na] * TRXXRH, axis=0) denom = np.sum(Z[:, :, na] * TRH, axis=0) else: R = self.reconstruct_mnmf(T, V, H) R = 
to_psd(R, flooring_fn=flooring_fn) trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H) num = np.sum(T[:, :, :, na] * trace_RXXRH[:, :, na, :], axis=1) denom = np.sum(T[:, :, :, na] * trace_RH[:, :, na, :], axis=1) V = V * np.sqrt(num / denom) V = flooring_fn(V) self.activation = V def update_spatial( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update spatial properties in NMF by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ na = np.newaxis flooring_fn = choose_flooring_fn(flooring_fn, method=self) XX = self.instant_covariance T, V = self.basis, self.activation H = self.spatial if self.partitioning: Z = self.latent Lamb = self.reconstruct_nmf(T, V, latent=Z) else: Lamb = self.reconstruct_nmf(T, V) R_n = Lamb[:, :, :, na, na] * H[:, :, na, :, :] R = np.sum(R_n, axis=0) R = to_psd(R, flooring_fn=flooring_fn) R_inverse = np.linalg.inv(R) RXXR = R_inverse @ XX @ R_inverse P = np.sum(Lamb[:, :, :, na, na] * R_inverse, axis=2) Q = np.sum(Lamb[:, :, :, na, na] * RXXR, axis=2) HQH = H @ Q @ H P = to_psd(P, flooring_fn=flooring_fn) HQH = to_psd(HQH, flooring_fn=flooring_fn) # geometric mean of P^(-1) and HQH H = gmeanmh(P, HQH, type=2) H = to_psd(H, flooring_fn=flooring_fn) self.spatial = H def update_latent( self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", ) -> None: r"""Update latent variables in NMF by MM algorithm. Args: flooring_fn (callable or str, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. 
If ``self`` is given as str, ``self.flooring_fn`` is used. Default: ``self``. """ n_sources = self.n_sources n_frames = self.n_frames na = np.newaxis flooring_fn = choose_flooring_fn(flooring_fn, method=self) def _compute_traces( target: np.ndarray, reconstructed: np.ndarray, spatial: np.ndarray ) -> np.ndarray: RXX = solve(reconstructed, target) R = np.tile(reconstructed, reps=(n_sources, 1, 1, 1, 1)) H = np.tile(spatial[:, :, na, :, :], reps=(1, 1, n_frames, 1, 1)) RH = solve(R, H) trace_RXXRH = np.trace(RXX @ RH, axis1=-2, axis2=-1) trace_RXXRH = np.real(trace_RXXRH) trace_RH = np.trace(RH, axis1=-2, axis2=-1) trace_RH = np.real(trace_RH) return trace_RXXRH, trace_RH XX = self.instant_covariance T, V = self.basis, self.activation H, Z = self.spatial, self.latent R = self.reconstruct_mnmf(T, V, H, latent=Z) R = to_psd(R, flooring_fn=flooring_fn) trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H) VRXXRH = np.sum(V[na, na, :] * trace_RXXRH[:, :, na], axis=-1) VRH = np.sum(V[na, na, :] * trace_RH[:, :, na], axis=-1) num = np.sum(T * VRXXRH, axis=1) denom = np.sum(T * VRH, axis=1) Z = Z * np.sqrt(num / denom) Z = Z / Z.sum(axis=0) self.latent = Z class FastGaussMNMF(FastMNMFBase): r"""Fast multichannel nonnegative matrix factorization on Gaussian distribution \ (Fast Gauss-MNMF). Args: n_basis (int): Number of NMF bases. n_sources (int, optional): Number of sources to be separated. If ``None`` is given, ``n_sources`` is determined by number of channels in input spectrogram. Default: ``None``. diagonalizer_algorithm (str): Algorithm for diagonalizers. Choose ``IP``, ``IP1``, or ``IP2``. Default: ``IP``. partitioning (bool): Whether to use partioning function. Default: ``False``. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. 
Default: ``functools.partial(max_flooring, eps=1e-10)``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel in multichannel Wiener filter. Default: ``0``. rng (numpy.random.Generator, optioinal): Random number generator. This is mainly used to randomly initialize PSDTF. If ``None`` is given, ``np.random.default_rng()`` is used. Default: ``None``. """ def __init__( self, n_basis: int, n_sources: Optional[int] = None, diagonalizer_algorithm: str = "IP", partitioning: bool = False, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None, callbacks: Optional[ Union[Callable[["FastGaussMNMF"], None], List[Callable[["FastGaussMNMF"], None]]] ] = None, normalization: bool = True, record_loss: bool = True, reference_id: int = 0, rng: Optional[np.random.Generator] = None, ) -> None: super().__init__( n_basis, n_sources=n_sources, partitioning=partitioning, flooring_fn=flooring_fn, callbacks=callbacks, normalization=normalization, record_loss=record_loss, reference_id=reference_id, rng=rng, ) assert diagonalizer_algorithm in diagonalizer_algorithms, "Not support {}.".format( diagonalizer_algorithm ) assert not partitioning, "partitioning function is not supported." 
self.diagonalizer_algorithm = diagonalizer_algorithm if pair_selector is None: if diagonalizer_algorithm == "IP2": self.pair_selector = sequential_pair_selector else: self.pair_selector = pair_selector def __repr__(self) -> str: s = "FastGaussMNMF(" s += "n_basis={n_basis}" if self.n_sources is not None: s += ", n_sources={n_sources}" if hasattr(self, "n_channels"): s += ", n_channels={n_channels}" s += ", diagonalizer_algorithm={diagonalizer_algorithm}" s += ", partitioning={partitioning}" s += ", record_loss={record_loss}" s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def separate(self, input: np.ndarray) -> np.ndarray: """Separate ``input`` using multichannel Wiener filter. Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_sources, n_bins, n_frames). """ na = np.newaxis n_sources = self.n_sources reference_id = self.reference_id X = input T, V = self.basis, self.activation Q, D = self.diagonalizer, self.spatial if self.partitioning: Lamb = self.reconstruct_nmf(T, V, latent=self.latent) else: Lamb = self.reconstruct_nmf(T, V) D = D.transpose(1, 0, 2) Q_inverse = np.linalg.inv(Q) Q_inverse_Hermite = Q_inverse.transpose(0, 2, 1).conj() QQ_Hermite = Q_inverse[:, :, :, na] * Q_inverse_Hermite[:, na, :, :] LambD = Lamb[:, :, :, na] * D[:, :, na, :] R_n = np.sum(LambD[:, :, :, na, :, na] * QQ_Hermite[:, na, :, :, :], axis=4) R = np.sum(R_n, axis=0) R = to_psd(R, flooring_fn=self.flooring_fn) R = np.tile(R, reps=(n_sources, 1, 1, 1, 1)) W_Hermite = solve(R, R_n) W = W_Hermite.transpose(0, 1, 2, 4, 3).conj() W_ref = W[:, :, :, reference_id, :] W_ref = W_ref.transpose(0, 3, 1, 2) Y = np.sum(W_ref * X, axis=1) return Y def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. :math:`\mathcal{L}` is defined as follows: .. 
math::
            \mathcal{L}
            &:=\frac{1}{J}\sum_{i,j}\left\{
            \mathrm{tr}\left(
            \boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}\boldsymbol{R}_{ij}^{-1}
            \right) + \log\det\boldsymbol{R}_{ij}
            \right\} \\
            &=\frac{1}{J}\sum_{i,j,m}\left\{
            \frac{|\boldsymbol{q}_{im}^{\mathsf{H}}\boldsymbol{x}_{ij}|^{2}}
            {\sum_{n}\lambda_{ijn}d_{inm}}
            + \log\sum_{n}\lambda_{ijn}d_{inm}\right\}
            - 2\sum_{i}\log|\det\boldsymbol{Q}_{i}|.

        The second form is what is evaluated: the log-determinant of
        :math:`\boldsymbol{Q}_{i}` comes from ``numpy.linalg.slogdet``.

        Returns:
            Computed loss.
        """
        X = self.input
        T, V = self.basis, self.activation
        Q, D = self.diagonalizer, self.spatial
        na = np.newaxis

        if self.partitioning:
            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)
        else:
            Lamb = self.reconstruct_nmf(T, V)

        D = D.transpose(1, 0, 2)
        # sum_n lambda_{ijn} d_{inm}, laid out as (n_bins, n_channels, n_frames).
        LambD = np.sum(Lamb[:, :, na, :] * D[:, :, :, na], axis=0)
        QX = Q @ X.transpose(1, 0, 2)
        QX2 = np.abs(QX) ** 2
        logdetQ = self.compute_logdet(Q)
        loss = np.sum(QX2 / LambD + np.log(LambD), axis=1)
        loss = np.mean(loss, axis=-1) - 2 * logdetQ
        loss = loss.sum(axis=0)
        loss = loss.item()

        return loss

    def compute_logdet(self, diagonalizer: np.ndarray) -> np.ndarray:
        r"""Compute log-determinant.

        ``numpy.linalg.slogdet`` is used, so the returned value is
        :math:`\log|\det\boldsymbol{Q}|` and the sign/phase is discarded.

        Args:
            diagonalizer:
                Diagonalizer with shape of (\*, n_channels, n_channels).

        Returns:
            numpy.ndarray of computed log-determinant values.
            The shape is (\*).
        """
        _, logdet = np.linalg.slogdet(diagonalizer)

        return logdet

    def update_once(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        r"""Update MNMF parameters, diagonalizers, and diagonal elements of \
        spatial covariance matrices once.

        Bases, activations, diagonalizers, and diagonal elements are updated in this order.
        Then :meth:`normalize` is called if ``self.normalization`` is truthy.

        Args:
            flooring_fn (callable or str, optional):
                A flooring function for numerical stability.
                This function is expected to return the same shape tensor as the input.
                If you explicitly set ``flooring_fn=None``,
                the identity function (``lambda x: x``) is used.
                If ``self`` is given as str, ``self.flooring_fn`` is used.
                Default: ``self``.
        """
        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        self.update_basis(flooring_fn=flooring_fn)
        self.update_activation(flooring_fn=flooring_fn)
        self.update_diagonalizer(flooring_fn=flooring_fn)
        # update_spatial takes no flooring_fn and applies no flooring.
        self.update_spatial()

        if self.normalization:
            self.normalize(flooring_fn=flooring_fn)

    def update_basis(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        r"""Update NMF bases by MM algorithm.

        Args:
            flooring_fn (callable or str, optional):
                A flooring function for numerical stability.
                This function is expected to return the same shape tensor as the input.
                If you explicitly set ``flooring_fn=None``,
                the identity function (``lambda x: x``) is used.
                If ``self`` is given as str, ``self.flooring_fn`` is used.
                Default: ``self``.

        Update :math:`t_{ikn}` as follows:

        .. math::
            t_{ikn}
            \leftarrow\left[
            \frac{\displaystyle\sum_{j,m}\frac{|\boldsymbol{q}_{im}^{\mathsf{H}}\boldsymbol{x}_{ij}|^{2}d_{inm}v_{kjn}}
            {\left(\sum_{k',n'}t_{ik'n'}v_{k'jn'}d_{in'm}\right)^{2}}}
            {\displaystyle\sum_{j,m}\dfrac{d_{inm}v_{kjn}}{\sum_{k',n'}t_{ik'n'}v_{k'jn'}d_{in'm}}}
            \right]^{\frac{1}{2}}t_{ikn}.

        Raises:
            AssertionError: If ``partitioning=True``.
        """
        assert not self.partitioning, "partitioning function is not supported."
        na = np.newaxis

        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        X = self.input
        T, V = self.basis, self.activation
        Q, D = self.diagonalizer, self.spatial

        # The partitioning branch is unreachable because of the assertion above.
        if self.partitioning:
            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)
        else:
            Lamb = self.reconstruct_nmf(T, V)

        D = D.transpose(1, 0, 2)
        LambD = Lamb[:, :, :, na] * D[:, :, na, :]
        # Denominator sum_{n'} lambda_{ijn'} d_{in'm}, laid out as (n_bins, n_frames, n_channels).
        LambD = np.sum(LambD, axis=0)

        QX = Q @ X.transpose(1, 0, 2)
        QX = np.abs(QX)
        QX = QX.transpose(0, 2, 1)
        QXLambD = (QX / LambD) ** 2

        DQXLambD = np.sum(D[:, :, na, :] * QXLambD, axis=-1)
        DLambD = np.sum(D[:, :, na, :] / LambD, axis=-1)

        num = np.sum(V[:, na, :] * DQXLambD[:, :, na], axis=-1)
        denom = np.sum(V[:, na, :] * DLambD[:, :, na], axis=-1)

        T = T * np.sqrt(num / denom)
        T = flooring_fn(T)

        self.basis = T

    def update_activation(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        r"""Update NMF activations by MM algorithm.

        Args:
            flooring_fn (callable or str, optional):
                A flooring function for numerical stability.
                This function is expected to return the same shape tensor as the input.
                If you explicitly set ``flooring_fn=None``,
                the identity function (``lambda x: x``) is used.
                If ``self`` is given as str, ``self.flooring_fn`` is used.
                Default: ``self``.

        Update :math:`v_{kjn}` as follows:

        .. math::
            v_{kjn}
            \leftarrow\left[
            \frac{\displaystyle\sum_{i,m}\frac{|\boldsymbol{q}_{im}^{\mathsf{H}}\boldsymbol{x}_{ij}|^{2}d_{inm}t_{ikn}}
            {\left(\sum_{k',n'}t_{ik'n'}v_{k'jn'}d_{in'm}\right)^{2}}}
            {\displaystyle\sum_{i,m}\dfrac{d_{inm}t_{ikn}}{\sum_{k',n'}t_{ik'n'}v_{k'jn'}d_{in'm}}}
            \right]^{\frac{1}{2}}v_{kjn}.

        Raises:
            AssertionError: If ``partitioning=True``.
        """
        assert not self.partitioning, "partitioning function is not supported."
        na = np.newaxis

        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        X = self.input
        T, V = self.basis, self.activation
        Q, D = self.diagonalizer, self.spatial

        # The partitioning branch is unreachable because of the assertion above.
        if self.partitioning:
            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)
        else:
            Lamb = self.reconstruct_nmf(T, V)

        D = D.transpose(1, 0, 2)
        LambD = Lamb[:, :, :, na] * D[:, :, na, :]
        LambD = np.sum(LambD, axis=0)

        QX = Q @ X.transpose(1, 0, 2)
        QX = np.abs(QX)
        QX = QX.transpose(0, 2, 1)
        QXLambD = (QX / LambD) ** 2

        DQXLambD = np.sum(D[:, :, na, :] * QXLambD, axis=-1)
        DLambD = np.sum(D[:, :, na, :] / LambD, axis=-1)

        # Same statistics as update_basis, but contracted with T over bins (axis=1).
        num = np.sum(T[:, :, :, na] * DQXLambD[:, :, na, :], axis=1)
        denom = np.sum(T[:, :, :, na] * DLambD[:, :, na, :], axis=1)

        V = V * np.sqrt(num / denom)
        V = flooring_fn(V)

        self.activation = V

    def update_diagonalizer(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        """Update diagonalizer.

        - If ``diagonalizer_algorithm`` is ``IP`` or ``IP1``,
          ``update_diagonalizer_ip1`` is called.
        - If ``diagonalizer_algorithm`` is ``IP2``,
          ``update_diagonalizer_ip2`` is called.

        Args:
            flooring_fn (callable or str, optional):
                A flooring function for numerical stability.
                This function is expected to return the same shape tensor as the input.
                If you explicitly set ``flooring_fn=None``,
                the identity function (``lambda x: x``) is used.
                If ``self`` is given as str, ``self.flooring_fn`` is used.
                Default: ``self``.

        Raises:
            NotImplementedError: If ``diagonalizer_algorithm`` is not supported.
        """
        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        if self.diagonalizer_algorithm in ["IP", "IP1"]:
            self.update_diagonalizer_ip1(flooring_fn=flooring_fn)
        elif self.diagonalizer_algorithm in ["IP2"]:
            self.update_diagonalizer_ip2(flooring_fn=flooring_fn)
        else:
            raise NotImplementedError("Not support {}.".format(self.diagonalizer_algorithm))

    def update_diagonalizer_ip1(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        r"""Update diagonalizer once using iterative projection.
Args:
            flooring_fn (callable or str, optional):
                A flooring function for numerical stability.
                This function is expected to return the same shape tensor as the input.
                If you explicitly set ``flooring_fn=None``,
                the identity function (``lambda x: x``) is used.
                If ``self`` is given as str, ``self.flooring_fn`` is used.
                Default: ``self``.

        Diagonalizers are updated sequentially for :math:`m=1,\ldots,M` as follows:

        .. math::
            \boldsymbol{q}_{im}
            &\leftarrow\left(\boldsymbol{Q}_{i}\boldsymbol{U}_{im}\right)^{-1}
            \boldsymbol{e}_{m}, \\
            \boldsymbol{q}_{im}
            &\leftarrow\frac{\boldsymbol{q}_{im}}
            {\sqrt{\boldsymbol{q}_{im}^{\mathsf{H}}\boldsymbol{U}_{im}\boldsymbol{q}_{im}}},

        where

        .. math::
            \boldsymbol{U}_{im}
            = \frac{1}{J}\sum_{j}
            \frac{\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}}
            {\sum_{n}\left(\sum_{k}z_{nk}t_{ik}v_{kj}\right)d_{inm}}

        if ``partitioning=True``, otherwise

        .. math::
            \boldsymbol{U}_{im}
            = \frac{1}{J}\sum_{j}
            \frac{\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}}
            {\sum_{n}\left(\sum_{k}t_{ikn}v_{kjn}\right)d_{inm}}.

        The row-wise projection itself is delegated to ``update_by_ip1``.

        Raises:
            AssertionError: If ``partitioning=True``.
        """
        assert not self.partitioning, "partitioning function is not supported."

        na = np.newaxis

        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        X = self.input
        T, V = self.basis, self.activation
        Q, D = self.diagonalizer, self.spatial

        if self.partitioning:
            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)
        else:
            Lamb = self.reconstruct_nmf(T, V)

        # Outer products x_ij x_ij^H, laid out as (n_bins, n_channels, n_channels, n_frames).
        XX_Hermite = X[:, na, :, :] * X[na, :, :, :].conj()
        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)
        Lamb = Lamb.transpose(1, 0, 2)
        LambD = np.sum(Lamb[:, :, na, :] * D[:, :, :, na], axis=1)
        varphi = 1 / LambD
        # One weighted covariance U_im per diagonal index m (axis 1), averaged over frames.
        varphi_XX = varphi[:, :, na, na, :] * XX_Hermite[:, na, :, :, :]
        U = np.mean(varphi_XX, axis=-1)

        self.diagonalizer = update_by_ip1(Q, U, flooring_fn=flooring_fn)

    def update_diagonalizer_ip2(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        r"""Update diagonalizer once using pairwise iterative projection.

        Args:
            flooring_fn (callable or str, optional):
                A flooring function for numerical stability.
                This function is expected to return the same shape tensor as the input.
                If you explicitly set ``flooring_fn=None``,
                the identity function (``lambda x: x``) is used.
                If ``self`` is given as str, ``self.flooring_fn`` is used.
                Default: ``self``.

        For :math:`m_{1}` and :math:`m_{2}` (:math:`m_{1}\neq m_{2}`),
        compute weighted covariance matrix as follows:

        .. math::
            \boldsymbol{U}_{im}
            = \frac{1}{J}\sum_{j}
            \frac{\boldsymbol{x}_{ij}\boldsymbol{x}_{ij}^{\mathsf{H}}}{\sum_{n}\lambda_{ijn}d_{inm}},

        :math:`\lambda_{ijn}` is computed by

        .. math::
            \lambda_{ijn}=\sum_{k}z_{nk}t_{ik}v_{kj}

        if ``partitioning=True``.
        Otherwise,

        .. math::
            \lambda_{ijn}=\sum_{k}t_{ikn}v_{kjn}.

        Using :math:`\boldsymbol{U}_{im_{1}}` and
        :math:`\boldsymbol{U}_{im_{2}}`, we compute generalized eigenvectors.

        .. math::
            \left({\boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}}^{\mathsf{H}}\boldsymbol{U}_{im_{1}}
            \boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}\right)\boldsymbol{h}_{i}
            = \mu_{i}
            \left({\boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}}^{\mathsf{H}}\boldsymbol{U}_{im_{2}}
            \boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}\right)\boldsymbol{h}_{i},

        where

        .. math::
            \boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}
            &= (\boldsymbol{Q}_{i}\boldsymbol{U}_{im_{1}})^{-1}
            (
            \begin{array}{cc}
                \boldsymbol{e}_{m_{1}} & \boldsymbol{e}_{m_{2}}
            \end{array}
            ), \\
            \boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}
            &= (\boldsymbol{Q}_{i}\boldsymbol{U}_{im_{2}})^{-1}
            (
            \begin{array}{cc}
                \boldsymbol{e}_{m_{1}} & \boldsymbol{e}_{m_{2}}
            \end{array}
            ).

        After that, we standardize two eigenvectors :math:`\boldsymbol{h}_{im_{1}}`
        and :math:`\boldsymbol{h}_{im_{2}}`.

        .. math::
            \boldsymbol{h}_{im_{1}}
            &\leftarrow\frac{\boldsymbol{h}_{im_{1}}}
            {\sqrt{\boldsymbol{h}_{im_{1}}^{\mathsf{H}}
            \left({\boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}}^{\mathsf{H}}\boldsymbol{U}_{im_{1}}
            \boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}\right)
            \boldsymbol{h}_{im_{1}}}}, \\
            \boldsymbol{h}_{im_{2}}
            &\leftarrow\frac{\boldsymbol{h}_{im_{2}}}
            {\sqrt{\boldsymbol{h}_{im_{2}}^{\mathsf{H}}
            \left({\boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}}^{\mathsf{H}}\boldsymbol{U}_{im_{2}}
            \boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}\right)
            \boldsymbol{h}_{im_{2}}}}.

        Then, update :math:`\boldsymbol{q}_{im_{1}}` and :math:`\boldsymbol{q}_{im_{2}}`
        simultaneously.

        .. math::
            \boldsymbol{q}_{im_{1}}
            &\leftarrow \boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}\boldsymbol{h}_{im_{1}} \\
            \boldsymbol{q}_{im_{2}}
            &\leftarrow \boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}\boldsymbol{h}_{im_{2}}

        At each iteration, we update pairs of :math:`m_{1}` and :math:`m_{2}`
        for :math:`m_{1}\neq m_{2}`. The pairs are given by ``self.pair_selector``,
        and the pairwise update itself is delegated to ``update_by_ip2``.

        Raises:
            AssertionError: If ``partitioning=True``.
        """
        assert not self.partitioning, "partitioning function is not supported."

        na = np.newaxis

        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        X = self.input
        T, V = self.basis, self.activation
        Q, D = self.diagonalizer, self.spatial

        if self.partitioning:
            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)
        else:
            Lamb = self.reconstruct_nmf(T, V)

        # Same weighted covariances U_im as in update_diagonalizer_ip1.
        XX_Hermite = X[:, na, :, :] * X[na, :, :, :].conj()
        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)
        Lamb = Lamb.transpose(1, 0, 2)
        LambD = np.sum(Lamb[:, :, na, :] * D[:, :, :, na], axis=1)
        varphi = 1 / LambD
        varphi_XX = varphi[:, :, na, na, :] * XX_Hermite[:, na, :, :, :]
        U = np.mean(varphi_XX, axis=-1)

        self.diagonalizer = update_by_ip2(
            Q, U, flooring_fn=flooring_fn, pair_selector=self.pair_selector
        )

    def update_spatial(self) -> None:
        r"""Update diagonal elements of spatial covariance matrix by MM algorithm.

        Update :math:`d_{inm}` as follows:

        .. math::
            d_{inm}\leftarrow\left[
            \dfrac{\displaystyle\sum_{j}\frac{\lambda_{ijn}|\boldsymbol{q}_{im}^{\mathsf{H}}\boldsymbol{x}_{ij}|^{2}}
            {\left(\sum_{n'}\lambda_{ijn'}d_{in'm}\right)^{2}}}
            {\displaystyle\sum_{j}\frac{\lambda_{ijn}}{\sum_{n'}\lambda_{ijn'}d_{in'm}}}
            \right]^{\frac{1}{2}}d_{inm}.

        Unlike the NMF updates, no flooring is applied to the result.

        Raises:
            AssertionError: If ``partitioning=True``.
        """
        assert not self.partitioning, "partitioning function is not supported."

        na = np.newaxis

        X = self.input
        T, V = self.basis, self.activation
        Q, D = self.diagonalizer, self.spatial

        if self.partitioning:
            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)
        else:
            Lamb = self.reconstruct_nmf(T, V)

        QX = Q @ X.transpose(1, 0, 2)
        QX = np.abs(QX)
        QX2 = QX**2
        Lamb = Lamb.transpose(1, 0, 2)
        # sum_{n'} lambda_{ijn'} d_{in'm}, laid out as (n_bins, n_channels, n_frames).
        LambD = np.sum(Lamb[:, :, na, :] * D[:, :, :, na], axis=1)
        LambD2 = LambD**2
        Lamb_LambD2 = Lamb[:, :, na] / LambD2[:, na, :]
        # Numerator and denominator are summed over frames (last axis).
        num = np.sum(Lamb_LambD2 * QX2[:, na, :, :], axis=-1)
        Lamb_LambD = Lamb[:, :, na] / LambD[:, na, :]
        denom = np.sum(Lamb_LambD, axis=-1)

        D = np.sqrt(num / denom) * D

        self.spatial = D
================================================
FILE: ssspy/bss/pdsbss.py
================================================
import warnings
from typing import Callable, List, Optional, Union

import numpy as np

from ..linalg import prox
from .proxbss import ProxBSSBase

EPS = 1e-10

__all__ = ["PDSBSS", "MaskingPDSBSS"]


class PDSBSSBase(ProxBSSBase):
    r"""Base class of blind source separation \
    via proximal splitting algorithm [#yatabe2018determined]_.

    Args:
        penalty_fn (callable):
            Penalty function that determines source model.
        prox_penalty (callable):
            Proximal operator of penalty function.
            Default: ``None``.
        callbacks (callable or list[callable], optional):
            Callback functions. Each function is called before separation and at each iteration.
            Default: ``None``.
        scale_restoration (bool or str):
            Technique to restore scale ambiguity.
            If ``scale_restoration=True``, the projection back technique is applied to
            estimated spectrograms. You can also specify ``projection_back`` explicitly.
            Default: ``True``.
record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. .. [#yatabe2018determined] K. Yatabe and D. Kitamura, "Determined blind source separation via proximal splitting algorithm," in *Proc. ICASSP*, 2018, pp. 776-780. """ def __repr__(self) -> str: s = "PDSBSS(" s += "n_penalties={n_penalties}".format(n_penalties=self.n_penalties) s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) class PDSBSS(PDSBSSBase): r"""Blind source separation via proximal splitting algorithm [#yatabe2018determined]_. Args: mu1 (float): Step size. Default: ``1``. mu2 (float): Step size. Default: ``1``. alpha (float): Relaxation parameter (deprecated). Set ``relaxation`` instead. relaxation (float): Relaxation parameter. Default: ``1``. penalty_fn (callable, optional): Penalty function that determines source model. prox_penalty (callable): Proximal operator of penalty function. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool, optional): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``None``. reference_id (int): Reference channel for projection back. Default: ``0``. 
""" def __init__( self, mu1: float = 1, mu2: float = 1, alpha: float = None, relaxation: float = 1, penalty_fn: Optional[Callable[[np.ndarray, np.ndarray], float]] = None, prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None, callbacks: Optional[ Union[Callable[["PDSBSS"], None], List[Callable[["PDSBSS"], None]]] ] = None, scale_restoration: bool = True, record_loss: Optional[bool] = None, reference_id: int = 0, ) -> None: super().__init__( penalty_fn=penalty_fn, prox_penalty=prox_penalty, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=record_loss, reference_id=reference_id, ) self.mu1, self.mu2 = mu1, mu2 if alpha is None: self.relaxation = relaxation else: assert relaxation == 1, "You cannot specify relaxation and alpha simultaneously." warnings.warn("alpha is deprecated. Set relaxation instead.", DeprecationWarning) self.relaxation = alpha def __call__(self, input, n_iter=100, initial_call: bool = True, **kwargs) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): Mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): Number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) # Call __call__ of PDSBSSBase's parent, i.e. 
__call__ of IterativeMethodBase super(PDSBSSBase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "PDSBSS(" s += "mu1={mu1}, mu2={mu2}" s += ", relaxation={relaxation}" s += ", n_penalties={n_penalties}".format(n_penalties=self.n_penalties) s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of PDSBSS. """ super()._reset(**kwargs) n_penalties = self.n_penalties n_sources = self.n_sources n_bins, n_frames = self.n_bins, self.n_frames if not hasattr(self, "dual"): dual = np.zeros((n_penalties, n_sources, n_bins, n_frames), dtype=np.complex128) else: if self.dual is None: dual = None else: # To avoid overwriting ``dual`` given by keyword arguments. dual = self.dual.copy() self.dual = dual def update_once(self) -> None: r"""Update demixing filters and dual parameters once.""" mu1, mu2 = self.mu1, self.mu2 alpha = self.relaxation Y = self.dual X, W = self.input, self.demix_filter Y_sum = Y.sum(axis=0) XY = Y_sum.transpose(1, 0, 2) @ X.transpose(1, 2, 0).conj() W_tilde = prox.neg_logdet(W - mu1 * mu2 * XY, step_size=mu1) XW = self.separate(X, demix_filter=2 * W_tilde - W) Y_tilde = [] for Y_q, prox_penalty in zip(Y, self.prox_penalty): Z_q = Y_q + XW Y_tilde_q = Z_q - prox_penalty(Z_q, step_size=1 / mu2) Y_tilde.append(Y_tilde_q) Y_tilde = np.stack(Y_tilde, axis=0) self.demix_filter = alpha * W_tilde + (1 - alpha) * W self.dual = alpha * Y_tilde + (1 - alpha) * Y class MaskingPDSBSS(PDSBSSBase): r"""Blind source separation via proximal splitting algorithm with masking [#yatabe2019time]_. Args: mu1 (float): Step size. 
Default: ``1``. mu2 (float): Step size. Default: ``1``. alpha (float): Relaxation parameter (deprecated). Set ``relaxation`` instead. relaxation (float): Relaxation parameter. Default: ``1``. penalty_fn (callable, optional): Penalty function that determines source model. mask_fn (callable): Proximal operator of penalty function. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``True``. reference_id (int): Reference channel for projection back. Default: ``0``. .. [#yatabe2019time] K. Yatabe and D. Kitamura, "Time-frequency-masking-based determined BSS with application to sparse IVA," in *Proc. ICASSP*, pp. 715-719, 2019. """ def __init__( self, mu1: float = 1, mu2: float = 1, alpha: float = None, relaxation: float = 1, penalty_fn: Optional[Callable[[np.ndarray, np.ndarray], float]] = None, mask_fn: Callable[[np.ndarray], float] = None, callbacks: Optional[ Union[Callable[["MaskingPDSBSS"], None], List[Callable[["MaskingPDSBSS"], None]]] ] = None, scale_restoration: bool = True, record_loss: Optional[bool] = None, reference_id: int = 0, ) -> None: super(ProxBSSBase, self).__init__( callbacks=callbacks, record_loss=record_loss, ) if penalty_fn is None: # Since penalty_fn is not necessarily written in closed form, # None is acceptable. if record_loss is None: record_loss = False assert not record_loss, "To record loss, set penalty_fn." else: assert callable(penalty_fn), "penalty_fn should be callable." 
if record_loss is None: record_loss = True if mask_fn is None: raise ValueError("Specify masking function.") else: assert callable(mask_fn), "mask_fn should be callable." self.penalty_fn = penalty_fn self.mask_fn = mask_fn self.input = None self.scale_restoration = scale_restoration if reference_id is None and scale_restoration: raise ValueError("Specify 'reference_id' if scale_restoration=True.") else: self.reference_id = reference_id self.mu1, self.mu2 = mu1, mu2 if alpha is None: self.relaxation = relaxation else: assert relaxation == 1, "You cannot specify relaxation and alpha simultaneously." warnings.warn("alpha is deprecated. Set relaxation instead.", DeprecationWarning) self.relaxation = alpha def __call__(self, input, n_iter=100, initial_call: bool = True, **kwargs) -> np.ndarray: r"""Separate a frequency-domain multichannel signal. Args: input (numpy.ndarray): Mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). n_iter (int): Number of iterations of demixing filter updates. Default: ``100``. initial_call (bool): If ``True``, perform callbacks (and computation of loss if necessary) before iterations. Returns: numpy.ndarray of the separated signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). """ self.input = input.copy() self._reset(**kwargs) # Call __call__ of PDSBSSBase's parent, i.e. __call__ of IterativeMethodBase super(PDSBSSBase, self).__call__(n_iter=n_iter, initial_call=initial_call) if self.scale_restoration: self.restore_scale() self.output = self.separate(self.input, demix_filter=self.demix_filter) return self.output def __repr__(self) -> str: s = "MaskingPDSBSS(" s += "mu1={mu1}, mu2={mu2}" s += ", relaxation={relaxation}" s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. 
Args: kwargs: Keyword arguments to set as attributes of MaskingPDSBSS. """ super()._reset(**kwargs) assert self.n_penalties == 1, "Number of penalty function should be one." n_sources = self.n_sources n_bins, n_frames = self.n_bins, self.n_frames if not hasattr(self, "dual"): dual = np.zeros((n_sources, n_bins, n_frames), dtype=np.complex128) else: if self.dual is None: dual = None else: # To avoid overwriting ``dual`` given by keyword arguments. dual = self.dual.copy() self.dual = dual @property def n_penalties(self): r"""Return number of penalty terms.""" return 1 def update_once(self) -> None: r"""Update demixing filters and dual parameters once.""" mu1, mu2 = self.mu1, self.mu2 alpha = self.relaxation Y = self.dual X, W = self.input, self.demix_filter XY = Y.transpose(1, 0, 2) @ X.transpose(1, 2, 0).conj() W_tilde = prox.neg_logdet(W - mu1 * mu2 * XY, step_size=mu1) XW = self.separate(X, demix_filter=2 * W_tilde - W) Z = Y + XW Y_tilde = Z - self.mask_fn(Z) * Z self.demix_filter = alpha * W_tilde + (1 - alpha) * W self.dual = alpha * Y_tilde + (1 - alpha) * Y ================================================ FILE: ssspy/bss/proxbss.py ================================================ from typing import Callable, List, Optional, Union import numpy as np from ..algorithm import ( MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS, PROJECTION_BACK_KEYWORDS, minimal_distortion_principle, projection_back, ) from .base import IterativeMethodBase EPS = 1e-10 class ProxBSSBase(IterativeMethodBase): """Base class of blind source separation via proximal gradient method. Args: penalty_fn (callable, optional): Penalty function that determines source model. prox_penalty (callable): Proximal operator of penalty function. Default: ``None``. callbacks (callable or list[callable], optional): Callback functions. Each function is called before separation and at each iteration. Default: ``None``. scale_restoration (bool or str): Technique to restore scale ambiguity. 
If ``scale_restoration=True``, the projection back technique is applied to estimated spectrograms. You can also specify ``projection_back`` explicitly. Default: ``True``. record_loss (bool, optional): Record the loss at each iteration of the update algorithm if ``record_loss=True``. Default: ``None``. reference_id (int): Reference channel for projection back. Default: ``0``. """ def __init__( self, penalty_fn: Optional[Callable[[np.ndarray, np.ndarray], float]] = None, prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None, callbacks: Optional[ Union[Callable[["ProxBSSBase"], None], List[Callable[["ProxBSSBase"], None]]] ] = None, scale_restoration: bool = True, record_loss: Optional[bool] = None, reference_id: int = 0, ) -> None: super().__init__( callbacks=callbacks, record_loss=record_loss, ) if penalty_fn is None: # Since penalty_fn is not necessarily written in closed form, # None is acceptable. if record_loss is None: record_loss = False assert not record_loss, "To record loss, set penalty_fn." else: if callable(penalty_fn): penalty_fn = [penalty_fn] if record_loss is None: record_loss = True if prox_penalty is None: raise ValueError("Specify proximal operator of penalty function.") else: if callable(prox_penalty): prox_penalty = [prox_penalty] self.penalty_fn = penalty_fn self.prox_penalty = prox_penalty if self.penalty_fn is not None: assert len(self.penalty_fn) == len( self.prox_penalty ), "Length of penalty_fn and prox_penalty are different." 
self.input = None self.scale_restoration = scale_restoration if reference_id is None and scale_restoration: raise ValueError("Specify 'reference_id' if scale_restoration=True.") else: self.reference_id = reference_id def __repr__(self) -> str: s = "ProxBSSBase(" s += "n_penalties={n_penalties}".format(n_penalties=self.n_penalties) s += ", scale_restoration={scale_restoration}" s += ", record_loss={record_loss}" if self.scale_restoration: s += ", reference_id={reference_id}" s += ")" return s.format(**self.__dict__) def _reset(self, **kwargs) -> None: r"""Reset attributes by given keyword arguments. Args: kwargs: Keyword arguments to set as attributes of ProxBSSBase. """ assert self.input is not None, "Specify data!" for key in kwargs.keys(): setattr(self, key, kwargs[key]) X = self.input n_channels, n_bins, n_frames = X.shape n_sources = n_channels # n_channels == n_sources self.n_sources, self.n_channels = n_sources, n_channels self.n_bins, self.n_frames = n_bins, n_frames if not hasattr(self, "demix_filter"): W = np.eye(n_sources, n_channels, dtype=np.complex128) W = np.tile(W, reps=(n_bins, 1, 1)) else: if self.demix_filter is None: W = None else: # To avoid overwriting ``demix_filter`` given by keyword arguments. W = self.demix_filter.copy() self.demix_filter = W self.output = self.separate(X, demix_filter=W) @property def n_penalties(self): r"""Return number of penalty terms.""" # asumption of len(self.penalty_fn) == len(self.prox_penalty) return len(self.prox_penalty) def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray: r"""Separate ``input`` using ``demixing_filter``. .. math:: \boldsymbol{y}_{ij} = \boldsymbol{W}_{i}\boldsymbol{x}_{ij} Args: input (numpy.ndarray): The mixture signal in frequency-domain. The shape is (n_channels, n_bins, n_frames). demix_filter (numpy.ndarray): The demixing filters to separate ``input``. The shape is (n_bins, n_sources, n_channels). Returns: numpy.ndarray of the separated signal in frequency-domain. 
The shape is (n_sources, n_bins, n_frames). """ X, W = input, demix_filter Y = W @ X.transpose(1, 0, 2) output = Y.transpose(1, 0, 2) return output def compute_loss(self) -> float: r"""Compute loss :math:`\mathcal{L}`. Returns: Computed loss. """ X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) # (n_sources, n_bins, n_frames) logdet = self.compute_logdet(W) # (n_bins,) penalty = 0 for penalty_fn in self.penalty_fn: penalty = penalty + penalty_fn(Y) loss = penalty - np.sum(logdet, axis=0) loss = loss.item() return loss def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray: r"""Compute log-determinant of demixing filter Args: demix_filter (numpy.ndarray): Demixing filters with shape of (n_bins, n_sources, n_channels). Returns: numpy.ndarray of computed log-determinant values. """ _, logdet = np.linalg.slogdet(demix_filter) # (n_bins,) return logdet def normalize_by_spectral_norm(self, input: np.ndarray, n_penalties: int = None) -> np.ndarray: r"""Spectral normalization. Args: input (numpy.ndarray): Input spectrogram with shape of (n_channels, n_bins, n_frames). n_penalties (int): Number of penalty functions, which determines coefficient of normalization. Returns: numpy.ndarray of normalized spectrogram with shape of (n_channels, n_bins, n_frames). """ if n_penalties is None: n_penalties = self.n_penalties norm = np.linalg.norm(input.transpose(1, 0, 2), ord=2, axis=(-2, -1)) norm = np.max(norm) return input / (np.sqrt(n_penalties) * norm) def restore_scale(self) -> None: r"""Restore scale ambiguity. If ``self.scale_restoration=projection_back``, we use projection back technique. """ scale_restoration = self.scale_restoration assert scale_restoration, "Set self.scale_restoration=True." 
if type(scale_restoration) is bool: scale_restoration = "projection_back" if scale_restoration in PROJECTION_BACK_KEYWORDS: self.apply_projection_back() elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS: self.apply_minimal_distortion_principle() else: raise ValueError("{} is not supported for scale restoration.".format(scale_restoration)) def apply_projection_back(self) -> None: r"""Apply projection back technique to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." X, W = self.input, self.demix_filter W_scaled = projection_back(W, reference_id=self.reference_id) Y_scaled = self.separate(X, demix_filter=W_scaled) self.output, self.demix_filter = Y_scaled, W_scaled def apply_minimal_distortion_principle(self) -> None: r"""Apply minimal distortion principle to estimated spectrograms.""" assert self.scale_restoration, "Set self.scale_restoration=True." X, W = self.input, self.demix_filter Y = self.separate(X, demix_filter=W) Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id) X = X.transpose(1, 0, 2) Y = Y_scaled.transpose(1, 0, 2) X_Hermite = X.transpose(0, 2, 1).conj() W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite) self.output, self.demix_filter = Y_scaled, W_scaled ================================================ FILE: ssspy/io/__init__.py ================================================ import struct from io import BufferedReader, BufferedWriter from typing import Optional, Tuple import numpy as np def wavread( path: str, frame_offset: int = 0, num_frames: Optional[int] = None, return_2d: Optional[bool] = None, channels_first: Optional[bool] = None, ) -> Tuple[np.ndarray, int]: with open(path, mode="rb") as f: riff = f.read(4) # ensure byte order is little endian if riff != b"RIFF": raise NotImplementedError(f"Not support {repr(riff)}.") # total file size _ = struct.unpack(" None: assert path[-4:] == ".wav", "Only wav file is supported." 
if waveform.ndim == 1: _waveform = waveform n_channels = 1 elif waveform.ndim == 2: if channels_first: _waveform = waveform.transpose(1, 0) else: _waveform = waveform n_channels = _waveform.shape[1] if n_channels < 1 or 2 < n_channels: raise ValueError(f"{n_channels}channel-input is not supported.") else: raise ValueError(f"waveform.ndim should be less or equal to 2, but given {waveform.ndim}.") if _waveform.dtype in ["f2", "f4", "f8", "f16"]: bits_per_sample = 16 # float to int _waveform = _waveform * 2 ** (bits_per_sample - 1) _waveform = _waveform.astype(" Tuple[int, int, int]: fmt_chunk_size = struct.unpack(" np.ndarray: data_chunk_size = struct.unpack("= 0: shape = (n_channels * num_frames,) end_frame = frame_offset + num_frames else: raise ValueError(f"Invalid num_frames={num_frames} is given. Set nonnegative integer.") if end_frame > max_frame: raise ValueError(f"num_frames={num_frames} exceeds maximum frame {max_frame}.") data = np.memmap(f, dtype=f" 1: data = data.reshape(-1, n_channels) if channels_first: data = data.transpose(1, 0) else: if return_2d: data = data.reshape(-1, n_channels) if channels_first: data = data.transpose(1, 0) vmax = 2 ** (8 * bytes_per_sample - 1) return data / vmax def _write_fmt_chunk( f: BufferedWriter, n_channels: int, sample_rate: int, byte_rate: int, block_align: int, bits_per_sample: int, ) -> None: data = b"fmt " f.write(data) data = struct.pack(" None: data = b"data" f.write(data) data_chunk_size = waveform.nbytes data = struct.pack("= version.parse("2") def solve(a: np.ndarray, b: np.ndarray) -> np.ndarray: requires_new_axis = IS_NUMPY_GE_2 and a.ndim == b.ndim + 1 if requires_new_axis: b = b[..., np.newaxis] x = np.linalg.solve(a, b) if requires_new_axis: x = x[..., 0] b = b[..., 0] return x ================================================ FILE: ssspy/linalg/cubic.py ================================================ import numpy as np def cbrt(x: np.ndarray) -> np.ndarray: """Return cube-root of an array. 
Args: x (np.ndarray): Values to compute cube-root. Complex value is available. Returns: np.ndarray of cube-root. """ if np.iscomplexobj(x): amplitude = np.abs(x) phase = np.angle(x) x_cbrt = np.cbrt(amplitude) * np.exp(1j * phase / 3) else: x_cbrt = np.cbrt(x) return x_cbrt ================================================ FILE: ssspy/linalg/eigh.py ================================================ from typing import Callable, Optional, Tuple, Union import numpy as np from .inv import inv2 def eigh( A: np.ndarray, B: Optional[np.ndarray] = None, type: Optional[int] = 1 ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: r"""Compute the (generalized) eigenvalues and eigenvectors of a complex Hermitian \ (conjugate symmetric) or a real symmetric matrix. If ``B`` is ``None``, solve :math:`\boldsymbol{A}\boldsymbol{z} = \lambda\boldsymbol{z}`. If ``B`` is given, solve :math:`\boldsymbol{A}\boldsymbol{z} = \lambda\boldsymbol{B}\boldsymbol{z}`. Args: A (numpy.ndarray): A complex Hermitian matrix with shape of (\*, n_channels, n_channels). B (numpy.ndarray, optional): A complex Hermitian matrix with shape of (\*, n_channels, n_channels). type (int): For the generalized eigenproblem, this value specifies the type of problem. Only ``1``, ``2``, and ``3`` are supported. - When ``type=1``, solve :math:`\boldsymbol{Az}=\lambda\boldsymbol{Bz}`. - When ``type=2``, solve :math:`\boldsymbol{ABz}=\lambda\boldsymbol{z}`. - When ``type=3``, solve :math:`\boldsymbol{BAz}=\lambda\boldsymbol{z}`. Returns: A tuple of (eigenvalues, eigenvectors) - Eigenvalues have shape of (\*, n_channels). - Eigenvectors have shape of (\*, n_channels, n_channels). .. note:: If ``B`` is given, we use cholesky decomposition to satisfy :math:`\boldsymbol{L}\boldsymbol{L}^{\mathsf{H}}=\boldsymbol{B}`. Then, solve :math:`\boldsymbol{C}\boldsymbol{y} = \lambda\boldsymbol{y}`, where :math:`\boldsymbol{C}=\boldsymbol{L}^{-1}\boldsymbol{A}\boldsymbol{L}^{-\mathsf{H}}`. 
The generalized eigenvalues of :math:`\boldsymbol{A}` and :math:`\boldsymbol{B}` are computed by :math:`\boldsymbol{L}^{-\mathsf{H}}\boldsymbol{y}`. Examples: .. code-block:: python >>> import numpy as np >>> from ssspy.linalg import eigh >>> A = np.array([[1, -2j], [2j, 3]]) >>> lamb, z = eigh(A) >>> lamb; z array([-0.23606798, 4.23606798]) array([[-0.85065081+0.j , -0.52573111+0.j ], [ 0. +0.52573111j, 0. -0.85065081j]]) >>> np.allclose(A @ z, lamb * z) True .. code-block:: python >>> import numpy as np >>> from ssspy.linalg import eigh >>> A = np.array([[1, -2j], [2j, 3]]) >>> B = np.array([[2, -3j], [3j, 5]]) >>> lamb, z = eigh(A, B) >>> lamb; z array([-1.61803399, 0.61803399]) array([[ 2.22703273+0.j , -0.20081142+0.j ], [ 0. -1.37638192j, 0. -0.3249197j ]]) >>> np.allclose(A @ z, lamb * (B @ z)) True """ if B is None: return np.linalg.eigh(A) lamb, z = _eigh(A, B, type=type, inv=np.linalg.inv) return lamb, z def eigh2( A: np.ndarray, B: Optional[np.ndarray] = None, type: Optional[int] = 1 ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: r"""Compute the (generalized) eigenvalues and eigenvectors of a 2x2 complex Hermitian \ (conjugate symmetric) or a real symmetric matrix. If ``B`` is ``None``, solve :math:`\boldsymbol{A}\boldsymbol{z} = \lambda\boldsymbol{z}`. If ``B`` is given, solve :math:`\boldsymbol{A}\boldsymbol{z} = \lambda\boldsymbol{B}\boldsymbol{z}`. Args: A (numpy.ndarray): A complex Hermitian matrix with shape of (\*, 2, 2). B (numpy.ndarray, optional): A complex Hermitian matrix with shape of (\*, 2, 2). type (int): For the generalized eigenproblem, this value specifies the type of problem. Only ``1``, ``2``, and ``3`` are supported. - When ``type=1``, solve :math:`\boldsymbol{Az}=\lambda\boldsymbol{Bz}`. - When ``type=2``, solve :math:`\boldsymbol{ABz}=\lambda\boldsymbol{z}`. - When ``type=3``, solve :math:`\boldsymbol{BAz}=\lambda\boldsymbol{z}`. Returns: A tuple of (eigenvalues, eigenvectors) - Eigenvalues have shape of (\*, 2). 
- Eigenvectors have shape of (\*, 2, 2). .. note:: If ``B`` is given, we use cholesky decomposition to satisfy :math:`\boldsymbol{L}\boldsymbol{L}^{\mathsf{H}}=\boldsymbol{B}`. Then, solve :math:`\boldsymbol{C}\boldsymbol{y} = \lambda\boldsymbol{y}`, where :math:`\boldsymbol{C}=\boldsymbol{L}^{-1}\boldsymbol{A}\boldsymbol{L}^{-\mathsf{H}}`. The generalized eigenvalues of :math:`\boldsymbol{A}` and :math:`\boldsymbol{B}` are computed by :math:`\boldsymbol{L}^{-\mathsf{H}}\boldsymbol{y}`. See also https://github.com/tky823/ssspy/issues/115 for this implementation. Examples: .. code-block:: python >>> import numpy as np >>> from ssspy.linalg import eigh2 >>> A = np.array([[1, -2j], [2j, 3]]) >>> lamb, z = eigh2(A) >>> lamb; z array([-0.23606798, 4.23606798]) array([[-0.85065081+0.j , -0.52573111+0.j ], [ 0. +0.52573111j, 0. -0.85065081j]]) >>> np.allclose(A @ z, lamb * z) True .. code-block:: python >>> import numpy as np >>> from ssspy.linalg import eigh2 >>> A = np.array([[1, -2j], [2j, 3]]) >>> B = np.array([[2, -3j], [3j, 5]]) >>> lamb, z = eigh2(A, B) >>> lamb; z array([-1.61803399, 0.61803399]) array([[ 2.22703273+0.j , -0.20081142+0.j ], [ 0. -1.37638192j, 0. 
-0.3249197j ]]) >>> np.allclose(A @ z, lamb * (B @ z)) True """ assert A.shape[-2:] == (2, 2), "2x2 matrix is expected, but given shape of {}.".format(A.shape) if B is None: return np.linalg.eigh(A) lamb, z = _eigh(A, B, type=type, inv=inv2) return lamb, z def _eigh( A: np.ndarray, B: np.ndarray, type: int = 1, inv: Callable[[np.ndarray], np.ndarray] = None, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: if inv is None: inv = np.linalg.inv L = np.linalg.cholesky(B) if type == 1: L_inv = inv(L) L_inv_Hermite = np.swapaxes(L_inv, -2, -1) if np.iscomplexobj(L_inv_Hermite): L_inv_Hermite = L_inv_Hermite.conj() C = L_inv @ A @ L_inv_Hermite elif type in [2, 3]: L_Hermite = np.swapaxes(L, -2, -1) if np.iscomplexobj(L_Hermite): L_Hermite = L_Hermite.conj() C = L_Hermite @ A @ L if type == 2: L_inv_Hermite = inv(L_Hermite) else: L_inv_Hermite = None else: raise ValueError("Invalid type={} is given.".format(type)) lamb, y = np.linalg.eigh(C) if type in [1, 2]: z = L_inv_Hermite @ y elif type == 3: z = L @ y else: raise ValueError("Invalid type={} is given.".format(type)) return lamb, z ================================================ FILE: ssspy/linalg/inv.py ================================================ import numpy as np def inv2(X: np.ndarray) -> np.ndarray: r"""Compute the (multiplicative) inverse of a 2x2 matrix. Args: X (numpy.ndarray): 2x2 matrix to be inverted. The shape is (\*, 2, 2). Returns: numpy.ndarray: (Multiplicative) inverse of the matrix X. Examples: .. code-block:: python >>> import numpy as np >>> from ssspy.linalg import inv2 >>> X = np.array([[0, 1], [2, 3]]) >>> X_inv = inv2(X) >>> np.allclose(X @ X_inv, np.eye(2)) True >>> np.allclose(X_inv @ X, np.eye(2)) True .. code-block:: python >>> import numpy as np >>> from ssspy.linalg import inv2 >>> X = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) >>> inv2(X) array([[[-1.5, 0.5], [ 1. , -0. ]], [[-3.5, 2.5], [ 3. , -2. 
]]]) """ shape = X.shape assert shape[-2:] == (2, 2), "2x2 matrix is expected, but given shape of {}.".format(shape) a = X[..., 0, 0] b = X[..., 0, 1] c = X[..., 1, 0] d = X[..., 1, 1] det = a * d - b * c X_adj = np.stack([d, -b, -c, a], axis=-1) X_adj = X_adj.reshape(shape[:-2] + (2, 2)) X_inv = X_adj / det[..., np.newaxis, np.newaxis] return X_inv ================================================ FILE: ssspy/linalg/lqpqm.py ================================================ import functools import warnings from typing import Callable, Optional, Union import numpy as np from ..special.flooring import identity, max_flooring from .cubic import cbrt EPS = 1e-10 def lqpqm2( H: np.ndarray, v: np.ndarray, z: np.ndarray, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), singular_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "flooring", max_iter: int = 10, ) -> None: r"""Solve of log-quadratically penelized quadratic minimization (type 2). .. math:: \check{\boldsymbol{q}}_{in} = \min_{\check{\boldsymbol{q}}_{in}} ~~\check{\boldsymbol{q}}_{in}^{\mathsf{H}}\check{\boldsymbol{q}}_{in} - \log\left((\check{\boldsymbol{q}}_{in}+\boldsymbol{v}_{in})^{\mathsf{H}} \boldsymbol{H}_{in}(\check{\boldsymbol{q}}_{in}+\boldsymbol{v}_{in}) + z_{in} \right) Args: H (numpy.ndarray): Positive semidefinite matrices of shape (n_bins, n_sources - 1, n_sources - 1). v (numpy.ndarray): Linear terms in LQPQM of shape (n_bins, n_sources - 1). z (numpy.ndarray): Constant terms in LQPQM of shape (n_bins,). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. singular_fn (callable, optional): A flooring function to return singular condition. 
This function is expected to return the same shape bool tensor as the input. If ``singular_fn=None``, ``lambda x: x == 0`` is used. Default: ``flooring``. max_iter (int): Maximum number of Newton-Raphson method. Default: ``10``. Returns: np.ndarray: Solutions of LQPQM type-2 of shape (n_bins, n_sources - 1). """ if flooring_fn is None: flooring_fn = identity if singular_fn is None: def _is_zero(x: np.ndarray) -> np.ndarray: return x == 0 singular_fn = _is_zero elif singular_fn == "flooring": def _is_lower_than_floor(x: np.ndarray) -> np.ndarray: return x < flooring_fn(0) singular_fn = _is_lower_than_floor else: assert callable(singular_fn), "singular_fn should be callable." phi, sigma = np.linalg.eigh(H) norm = np.linalg.norm(v, axis=-1) is_singular = singular_fn(norm) # when v = 0 phi_singular = phi[is_singular] sigma_singular = sigma[is_singular] z_singular = z[is_singular] phi_max_singular = phi_singular[:, -1] sigma_max_singular = sigma_singular[:, -1] lamb_singular = np.maximum(z_singular, phi_max_singular) scale = (lamb_singular - z_singular) / phi_max_singular scale = np.maximum(scale, 0) scale = np.sqrt(scale) y_singular = scale[..., np.newaxis] * sigma_max_singular # when v != 0 phi_non_singular = phi[~is_singular] sigma_non_singular = sigma[~is_singular] v_non_singular = v[~is_singular] z_non_singular = z[~is_singular] v_tilde_non_singular = np.sum( sigma_non_singular.conj() * v_non_singular[:, :, np.newaxis], axis=-2 ) lamb_non_singular = solve_equation( phi_non_singular, v_tilde_non_singular, z_non_singular, flooring_fn=flooring_fn, max_iter=max_iter, normalization=True, ) num = phi_non_singular * v_tilde_non_singular denom = lamb_non_singular[..., np.newaxis] - phi_non_singular v_nonsingular = num / denom y_non_singular = np.sum(sigma_non_singular * v_nonsingular[:, np.newaxis, :], axis=-1) y = np.zeros_like(v) y[is_singular] = y_singular y[~is_singular] = y_non_singular return y def solve_equation( phi: np.ndarray, v: np.ndarray, z: np.ndarray, 
flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), max_iter: int = 10, normalization: bool = True, ): r"""Find largest root of :math:`f(\lambda_{in})`, where .. math:: f(\lambda_{in}) = \lambda_{in}^{2}\sum_{n'} \frac{\phi_{inn'}|\tilde{v}_{inn'}|^{2}}{(\lambda_{in}-\phi_{inn'})^{2}} - \lambda_{in} + z_{in} Args: phi (numpy.ndarray): Eigen values defined in LQPQM of shape (n_bins, n_sources). v (numpy.ndarray): Linear term defined in LQPQM of shape (n_bins, n_sources). z (numpy.ndarray): Constant term defined in LQPQM of shape (n_bins,). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. max_iter (int): Maximum iteration of Newton-Raphson method. Default: ``10``. normalization (bool): If ``True``, coefficients are normalized by ``phi_max``. Returns: numpy.ndarray of largest root of :math:`f(\lambda_{in})`. The shape is (n_bins,). """ if flooring_fn is None: flooring_fn = identity n_bins, n_sources = phi.shape non_zero_mask = phi * np.abs(v) ** 2 >= flooring_fn(0) phi = non_zero_mask * phi v = non_zero_mask * v max_index = np.argmax(phi, axis=-1) + np.arange(0, n_bins * n_sources, n_sources) phi_flatten = phi.flatten() v_flatten = v.flatten() phi_max = phi_flatten[max_index] v_max = v_flatten[max_index] phi_max = flooring_fn(phi_max) if normalization: phi_max_original = phi_max phi = phi / phi_max[:, np.newaxis] v = v / phi_max[:, np.newaxis] v_max = v_max / phi_max z = z / phi_max phi_max = phi_max / phi_max # i.e. 
phi_max = 1 else: phi_max_original = None # Find largest root of cubic polynomial for initialization A = -(phi_max * np.abs(v_max) ** 2 + 2 * phi_max + z) B = (phi_max + 2 * z) * phi_max C = -(phi_max**2) * z lamb = _find_largest_root(A, B, C) is_valid = lamb > phi_max lamb[~is_valid] = phi_max[~is_valid] + flooring_fn(0) lamb = np.maximum(lamb, z) for iter_idx in range(max_iter): f = _fn(lamb, phi, v, z) is_convergence = np.abs(f) <= flooring_fn(0) if np.all(is_convergence): break df = _d_fn(lamb, phi, v, z) mu = lamb - f / df lamb = np.where(mu > phi_max, mu, (phi_max + lamb) / 2) if iter_idx == max_iter - 1: f = _fn(lamb, phi, v, z) is_convergence = np.abs(f) <= flooring_fn(0) if not np.all(is_convergence): warnings.warn( f"Newton-Raphson method did not converge in {max_iter} iterations.", UserWarning ) if normalization: lamb = lamb * phi_max_original return lamb def _find_largest_root(A: np.ndarray, B: np.ndarray, C: np.ndarray) -> np.ndarray: r"""Find largest (real) roots of the following cubic equations: .. math:: x^{3} + Ax^{2} + Bx + C = 0. Args: A (numpy.ndarray): Coefficients of quadratic terms with shape of (\*). B (numpy.ndarray): Coefficients of linear terms with shape of (\*). C (numpy.ndarray): Coefficients of constant terms with shape of (\*). Returns: numpy.ndarray of largest (real) roots. .. note:: :math:`x^{3} + Ax^{2} + Bx + C = 0` can be transformed into :math:`t^{3} + pt + q = 0` by :math:`t=x+\frac{A}{3}`. When :math:`p<0` and :math:`\frac{q^{2}}{4}+\frac{p^{3}}{27}\leq 0`, there exists three real solutions: :math:`t=u-\frac{p}{3u}`, :math:`x=u\omega-\frac{p\omega^{*}}{3u}`, and :math:`x=u\omega^{*}-\frac{p\omega}{3u}`, where .. math:: u &=\sqrt[3]{-\frac{q}{2}+\sqrt{\frac{q^{2}}{4} + \frac{p^{3}}{27}}}, \\ \omega &=\frac{-1+j\sqrt{3}}{2}. When :math:`p<0` and :math:`\frac{q^{2}}{4}+\frac{p^{3}}{27}>0`, :math:`t=u-p/(3u)` is a unique real solution. When :math:`p>0`, :math:`t=u-p/(3u)` is a unique real solution. 
Otherwise (when :math:`p=0`), :math:`t=\sqrt[3]{-q}` is a unique real solution. """ P = -(A**2) / 3 + B Q = (2 * A**3) / 27 - (A * B) / 3 + C omega = (-1 + 1j * np.sqrt(3)) / 2 omega_conj = (-1 - 1j * np.sqrt(3)) / 2 discriminant = (Q / 2) ** 2 + (P / 3) ** 3 discriminant = discriminant.astype(np.complex128) U = cbrt(-Q / 2 + np.sqrt(discriminant)) # When U = 0, P is always 0 in real coefficients cases. is_singular = U == 0 U = np.where(is_singular, 1, U) V = -P / (3 * U) X1 = U + V X1 = np.where(is_singular, cbrt(-Q), X1) X2 = np.real(U * omega + V * omega_conj) X3 = np.real(U * omega_conj + V * omega) roots = np.stack([X1, X2, X3], axis=-1) roots = np.real(roots) is_monotonic = P >= 0 is_unique = np.array([True, False, False]) imaginary_mask = is_monotonic[..., np.newaxis] & ~is_unique roots = np.where(imaginary_mask, -float("inf"), roots) imaginary_mask = ~is_monotonic[..., np.newaxis] & ~is_unique is_positive = discriminant > 0 roots = np.where(imaginary_mask & is_positive[..., np.newaxis], -float("inf"), roots) root = np.max(roots, axis=-1) root = root - A / 3 return root def _fn(lamb: np.ndarray, phi: np.ndarray, v: np.ndarray, z: np.ndarray) -> np.ndarray: r"""Compute values of :math:`f(\lambda_{in})`, where .. math:: f(\lambda_{in}) = \lambda_{in}^{2}\sum_{n'} \frac{\phi_{inn'}|\tilde{v}_{inn'}|^{2}}{(\lambda_{in}-\phi_{inn'})^{2}} - \lambda_{in} + z_{in} Args: lamb (numpy.ndarray): Argument of :math:`f(\lambda_{in})` with shape of (n_bins,). phi (numpy.ndarray): Eigen values defined in LQPQM of shape (n_bins, n_sources). v (numpy.ndarray): Linear term defined in LQPQM of shape (n_bins, n_sources). z (numpy.ndarray): Constant term defined in LQPQM of shape (n_bins,). Returns: numpy.ndarray of values :math:`f(\lambda_{in})` of shape (n_bins,). 
""" num = phi * np.abs(v) ** 2 denom = (lamb[..., np.newaxis] - phi) ** 2 f = lamb**2 * np.sum(num / denom, axis=-1) - lamb + z return f def _d_fn( lamb: np.ndarray, phi: np.ndarray, v: np.ndarray, z: Optional[np.ndarray] = None, ): r"""Compute values of :math:`f'(\lambda_{in})`, where .. math:: f'(\lambda_{in}) = -2\lambda_{in}\sum_{n'} \frac{\phi_{inn'}^{2}|\tilde{v}_{inn'}|^{2}}{(\lambda_{in}-\phi_{inn'})^{3}} - 1 Args: lamb (numpy.ndarray): Argument of :math:`f'(\lambda_{in})` with shape of (n_bins,). phi (numpy.ndarray): Eigen values defined in LQPQM of shape (n_bins, n_sources). v (numpy.ndarray): Linear term defined in LQPQM of shape (n_bins, n_sources). z (numpy.ndarray, optional): Constant term defined in LQPQM of shape (n_bins,). This argument is not used in this funtion. Returns: numpy.ndarray of values :math:`f'(\lambda_{in})` of shape (n_bins,). """ num = (phi * np.abs(v)) ** 2 denom = (lamb[..., np.newaxis] - phi) ** 3 df = -2 * lamb * np.sum(num / denom, axis=-1) - 1 return df ================================================ FILE: ssspy/linalg/mean.py ================================================ import numpy as np from .eigh import eigh def gmeanmh(A: np.ndarray, B: np.ndarray, type: int = 1) -> np.ndarray: r"""Compute the geometric mean of complex Hermitian \ (conjugate symmetric) or real symmetric matrices. The geometric mean of :math:`\boldsymbol{A}` and :math:`\boldsymbol{B}` is defined as follows [#bhatia2009positive]_: .. math:: \boldsymbol{A}\#\boldsymbol{B} &= \boldsymbol{A}^{1/2} (\boldsymbol{A}^{-1/2}\boldsymbol{B}\boldsymbol{A}^{-1/2})^{1/2} \boldsymbol{A}^{1/2} \\ &= \boldsymbol{A}(\boldsymbol{A}^{-1}\boldsymbol{B})^{1/2} \\ &= (\boldsymbol{A}\boldsymbol{B}^{-1})^{1/2}\boldsymbol{B}. This is a solution of the following equation for complex Hermitian or real symmetric matrices, :math:`\boldsymbol{A}`, :math:`\boldsymbol{B}`, and :math:`\boldsymbol{X}`: .. math:: \boldsymbol{X}\boldsymbol{A}^{-1}\boldsymbol{X} = \boldsymbol{B}. .. 
note:: In this toolkit, :math:`\boldsymbol{A}\#\boldsymbol{B}` is computed by :math:`\boldsymbol{B}(\boldsymbol{B}^{-1}\boldsymbol{A})^{1/2}` in terms of computational speed. Note that :math:`\boldsymbol{A}\#\boldsymbol{B}` is equal to :math:`\boldsymbol{B}\#\boldsymbol{A}`. For comparison of computational time, see https://github.com/tky823/ssspy/issues/210. .. note:: :math:`(\boldsymbol{B}^{-1}\boldsymbol{A})^{1/2}` is computed by generalized eigendecomposition. Let :math:`\lambda` and :math:`z` be the eigenvalue and eigenvector of the generalized eigenproblem :math:`\boldsymbol{Az}=\lambda\boldsymbol{Bz}`. Then, :math:`(\boldsymbol{B}^{-1}\boldsymbol{A})^{1/2}` is computed by :math:`\boldsymbol{Z}\boldsymbol{\Lambda}^{1/2}\boldsymbol{Z}^{-1}`, where the main diagonals of :math:`\boldsymbol{\Lambda}` are :math:`\lambda` s and the columns of :math:`\boldsymbol{Z}` are :math:`\boldsymbol{z}` s. Args: A (numpy.ndarray): A complex Hermitian matrix with shape of (\*, n_channels, n_channels). B (numpy.ndarray): A complex Hermitian matrix with shape of (\*, n_channels, n_channels). type (int): This value specifies the type of geometric mean. Only ``1``, ``2``, and ``3`` are supported. - When ``type=1``, return :math:`\boldsymbol{A}\#\boldsymbol{B}`. - When ``type=2``, return :math:`\boldsymbol{A}^{-1}\#\boldsymbol{B}`. - When ``type=3``, return :math:`\boldsymbol{A}\#\boldsymbol{B}^{-1}`. Returns: Geometric mean of matrices with shape of (\*, n_channels, n_channels). .. [#bhatia2009positive] R. Bhatia, "Positive definite matrices," Princeton university press, 2009. 
""" # noqa: W605 lamb, Z = eigh(A, B, type=type) lamb = np.sqrt(lamb) Lamb = lamb[..., np.newaxis] * np.eye(Z.shape[-1]) ZLZ = Z @ Lamb @ np.linalg.inv(Z) if type == 1: BA = ZLZ G = B @ BA elif type == 2: AB = ZLZ G = np.linalg.inv(A) @ AB elif type == 3: BA = ZLZ G = np.linalg.inv(B) @ BA else: raise ValueError("Invalid type={} is given.".format(type)) return G ================================================ FILE: ssspy/linalg/polynomial.py ================================================ from typing import Optional import numpy as np from numpy.linalg import LinAlgError from .cubic import cbrt def solve_cubic( A: np.ndarray, B: np.ndarray, C: np.ndarray, D: Optional[np.ndarray] = None, all: bool = True, ) -> np.ndarray: r"""Find roots of cubic equations. Args: A (numpy.ndarray): Coefficients of cubic or quadratic terms. B (numpy.ndarray): Coefficients of quadratic or linear terms. C (numpy.ndarray): Coefficients of linear or constant terms. D (numpy.ndarray, optional): Constant terms. all (bool): If ``all=True``, returns all roots. Otherwise, returns one of them. Default: ``True``. Returns: numpy.ndarray: All roots of cuadratic equations of shape (3, \*) if ``all=True``. Otherwise, (\*). This function solves the following equations if ``D`` is given: .. math:: Ax^{3} + Bx^{2} + Cx + D = 0. If ``D`` is not given, solves .. math:: x^{3} + Ax^{2} + Bx + C = 0. """ if D is None: P = -(A**2) / 3 + B Q = (2 * A**3) / 27 - (A * B) / 3 + C X = _find_cubic_roots(P, Q) x = X - A / 3 return x if all else x[0] else: if np.any(A == 0): raise LinAlgError("Coefficients include zero.") return solve_cubic(B / A, C / A, D / A, all=all) def _find_cubic_roots(P: np.ndarray, Q: np.ndarray) -> np.ndarray: r"""Find roots of the following cubic equations: .. math:: x^{3} + px + q = 0 Args: P (np.ndarray): Coefficients of cubic equation. Q (np.ndarray): Coefficients of cubic equation. Returns: numpy.ndarray of the three roots. The shape is (3, \*). 
""" P = P.astype(np.complex128) Q = Q.astype(np.complex128) omega = (-1 + 1j * np.sqrt(3)) / 2 omega_conj = (-1 - 1j * np.sqrt(3)) / 2 discriminant = (Q / 2) ** 2 + (P / 3) ** 3 U = cbrt(-Q / 2 + np.sqrt(discriminant)) # U = 0, when P = 0. is_singular = P == 0 U = np.where(is_singular, 1, U) V = -P / (3 * U) X1 = U + V X1 = np.where(is_singular, cbrt(-Q), X1) X2 = U * omega + V * omega_conj X2 = np.where(is_singular, X1 * omega, X2) X3 = U * omega_conj + V * omega X3 = np.where(is_singular, X1 * omega_conj, X3) return np.stack([X1, X2, X3], axis=0) ================================================ FILE: ssspy/linalg/prox.py ================================================ import numpy as np __all__ = ["l21", "neg_log", "neg_logdet"] def l1(x, step_size: float = 1) -> np.ndarray: norm = np.abs(x) # to suppress warning RuntimeWarning norm = np.where(norm < step_size, step_size, norm) return np.maximum(1 - step_size / norm, 0) * x def l21(x: np.ndarray, step_size: float = 1, axis1: int = -2, axis2: int = -1): r"""Proximal operator of L21 norm. Args: x (numpy.ndarray): Input tensor. step_size (float): Step size parameter. Returns: numpy.ndarray: Output tensor. The shape is same as input. """ norm = np.linalg.norm(x, axis=axis2, keepdims=True) # to suppress warning RuntimeWarning norm = np.where(norm < step_size, step_size, norm) return np.maximum(1 - step_size / norm, 0) * x def neg_log(x: np.ndarray, step_size: float = 1): r"""Proximal operator of negative logarithm function. Proximal operator of :math:`-\log(x)` is defined as follows: .. math:: \mathrm{prox}_{-\mu\log}(x) = \frac{x + \sqrt{x^{2} + 4\mu}}{2} Args: x (np.ndarray): Shape is (n_bins, n_sources, n_channels). step_size (float): Step size parameter. Default: 1. Returns: np.ndarray: Proximal operator of negative logarithm function. 
""" assert np.all(x >= 0) output = (x + np.sqrt(x**2 + 4 * step_size)) / 2 return output def neg_logdet(X: np.ndarray, step_size=1): r"""Proximal operator of negative log-determinant. :math:`X\in\mathbb{C}^{N\times M}` .. math:: \mathrm{prox}_{-\mu\log}(\boldsymbol{X}) &= \boldsymbol{U}\tilde{\boldsymbol{\Sigma}}\boldsymbol{V}^{\mathsf{H}} \\ \tilde{\boldsymbol{\Sigma}} &= \mathrm{diag}(\mathrm{prox}_{-\mu\log}(\sigma_{1}), \ldots,\mathrm{prox}_{-\mu\log}(\sigma_{M})) Args: X (np.ndarray): Shape is (n_bins, n_sources, n_channels). step_size (float): Step size parameter. Default: 1. Returns: np.ndarray: Proximal operator of log-determinant. """ n_channels = X.shape[-1] U, Sigma, V = np.linalg.svd(X) Sigma = neg_log(Sigma, step_size=step_size) Sigma = Sigma[..., np.newaxis] * np.eye(n_channels) USV = U @ Sigma @ V return USV ================================================ FILE: ssspy/linalg/quadratic.py ================================================ import numpy as np def quadratic(X: np.ndarray, A: np.ndarray) -> np.ndarray: r"""Compute values of quadratic forms. Args: X (np.ndarray): Input vectors with shape of (\*, n_channels). A (np.ndarray): Input matrices with shape of (\*, n_channels, n_channels). Returns: Computed values of quadratic forms. The shape is (\*,). """ if np.iscomplexobj(X): X_Hermite = X.conj() else: X_Hermite = X Y = X_Hermite[..., np.newaxis, :] @ A @ X[..., np.newaxis] Y = Y[..., 0, 0] return Y ================================================ FILE: ssspy/linalg/sqrtm.py ================================================ from typing import Callable, Optional import numpy as np from .eigh import eigh def sqrtmh(X: np.ndarray) -> np.ndarray: r"""Compute square root of a positive semidefinite Hermitian or symmetric matrix. Args: X (numpy.ndarray): A complex Hermitian or symmetric matrix with shape of (\*, n_channels, n_channels). Returns: numpy.ndarray of square root. The shape is same as that of input. 
""" Lamb, P = eigh(X) P_Hermite = P.swapaxes(-2, -1) if np.iscomplexobj(X): P_Hermite = P_Hermite.conj() Lamb = np.sqrt(Lamb)[..., np.newaxis] * np.eye(Lamb.shape[-1]) return P @ Lamb @ P_Hermite def invsqrtmh( X: np.ndarray, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None, ) -> np.ndarray: r"""Compute inversion of square root for a positive definite Hermitian or symmetric matrix. Args: X (numpy.ndarray): A complex Hermitian matrix with shape of (\*, n_channels, n_channels). flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to receive and return the same shape as that of X. By default, the identity function (``lambda x: x``) is used. Returns: numpy.ndarray of inversion of square root. The shape is same as that of input. """ def _identity(x): return x if flooring_fn is None: flooring_fn = _identity Lamb, P = eigh(X) P_Hermite = P.swapaxes(-2, -1) if np.iscomplexobj(X): P_Hermite = P_Hermite.conj() Lamb = 1 / flooring_fn(np.sqrt(Lamb)) Lamb = Lamb[..., np.newaxis] * np.eye(Lamb.shape[-1]) return P @ Lamb @ P_Hermite ================================================ FILE: ssspy/special/__init__.py ================================================ from .flooring import add_flooring, identity, max_flooring from .logsumexp import logsumexp from .psd import to_psd from .softmax import softmax __all__ = ["add_flooring", "max_flooring", "identity", "to_psd", "logsumexp", "softmax"] ================================================ FILE: ssspy/special/flooring.py ================================================ import numpy as np EPS = 1e-10 def identity(input: np.ndarray) -> np.ndarray: r"""Identity function.""" return input def max_flooring(input: np.ndarray, eps: float = EPS) -> np.ndarray: r"""Max flooring operation.""" return np.maximum(input, eps) def add_flooring(input: np.ndarray, eps: float = EPS) -> np.ndarray: r"""Add flooring operation.""" return input + eps 
================================================ FILE: ssspy/special/logsumexp.py ================================================ import numpy as np def logsumexp(X: np.ndarray, axis: int = None, keepdims: bool = False) -> np.ndarray: r"""Compute log-sum-exp values. Args: X (np.ndarray): Elements to compute log-sum-exp. axis (int or tuple[int], optional): Axis or axes over which the sum is performed. Default: ``None``. keepdims (bool): If ``True`` is given, ``axis`` dimension(s) is reduced. Default: ``False``. Returns: np.ndarray of log-sum-exp values. Examples: .. code-block:: python >>> import numpy as np >>> X = np.array([[1, 2, 3], [4, 5, 6]]) >>> logsumexp(X, axis=0) array([4.04858735, 5.04858735, 6.04858735]) >>> logsumexp(X, axis=1) array([3.40760596, 6.40760596]) """ vmax = np.max(X, axis=axis, keepdims=True) exp = np.exp(X - vmax) sum_exp = exp.sum(axis=axis, keepdims=True) v = np.log(sum_exp) + vmax if not keepdims: v = np.squeeze(v, axis=axis) return v ================================================ FILE: ssspy/special/psd.py ================================================ import functools from typing import Callable, Optional import numpy as np from ..special.flooring import identity, max_flooring EPS = 1e-10 def to_psd( X: np.ndarray, axis1: int = -2, axis2: int = -1, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial( max_flooring, eps=EPS ), ) -> np.ndarray: r"""Ensure matrix to be positive semidefinite. Args: X (np.ndarray): A complex Hermitian matrix. axis1 (int): Axis to be used as first axis of 2D sub-arrays. axis2 (int): Axis to be used as second axis of 2D sub-arrays. flooring_fn (callable, optional): A flooring function for numerical stability. This function is expected to return the same shape tensor as the input. If you explicitly set ``flooring_fn=None``, the identity function (``lambda x: x``) is used. Default: ``functools.partial(max_flooring, eps=1e-10)``. Returns: Positive semidefinite matrix. 
""" if flooring_fn is None: flooring_fn = identity shape = X.shape n_dims = len(shape) axis1 = n_dims + axis1 if axis1 < 0 else axis1 axis2 = n_dims + axis2 if axis2 < 0 else axis2 assert axis1 == n_dims - 2 and axis2 == n_dims - 1, "axis1 == -2 and axis2 == -1" if np.iscomplexobj(X): X = (X + X.swapaxes(axis1, axis2).conj()) / 2 else: X = (X + X.swapaxes(axis1, axis2)) / 2 Lamb, P = np.linalg.eigh(X) P_Hermite = P.swapaxes(-2, -1) if np.iscomplexobj(X): P_Hermite = P_Hermite.conj() Lamb = flooring_fn(Lamb) Lamb = Lamb[..., np.newaxis] * np.eye(Lamb.shape[-1]) X = P @ Lamb @ P_Hermite if np.iscomplexobj(X): X = (X + X.swapaxes(axis1, axis2).conj()) / 2 else: X = (X + X.swapaxes(axis1, axis2)) / 2 return X ================================================ FILE: ssspy/special/softmax.py ================================================ import numpy as np def softmax(X: np.ndarray, axis: int = None) -> np.ndarray: r"""Compute softmax values. Args: X (np.ndarray): Elements to compute softmax. axis (int or tuple[int], optional): Axis or axes over which the sum is performed. Default: ``None``. Returns: np.ndarray of softmax values. Examples: .. 
code-block:: python >>> import numpy as np >>> X = np.array([[1, 2, 3], [4, 5, 6]]) >>> softmax(X, axis=0) array([[0.04742587, 0.04742587, 0.04742587], [0.95257413, 0.95257413, 0.95257413]]) >>> softmax(X, axis=1) array([[0.09003057, 0.24472847, 0.66524096], [0.09003057, 0.24472847, 0.66524096]]) """ vmax = np.max(X, axis=axis, keepdims=True) Y = X - vmax exp = np.exp(Y) v = exp / np.sum(exp, axis=axis, keepdims=True) return v ================================================ FILE: ssspy/transform/__init__.py ================================================ from .pca import pca from .whiten import whiten __all__ = ["pca", "whiten"] ================================================ FILE: ssspy/transform/pca.py ================================================ import numpy as np def pca(input: np.ndarray, ascend: bool = True) -> np.ndarray: r"""Apply principal component analysis (PCA). Args: input (numpy.ndarray): Input tensor. ascend (bool): If ``ascend=True``, first channel corresponds to first principle component. \ Otherwise, last channel corresponds to first principle component. Returns: numpy.ndarray: Output tensor. The type (real or complex) and shape are same as input. .. note:: - If ``input`` is 2D real tensor, it is regarded as (n_channels, n_samples). - If ``input`` is 3D complex tensor, it is regarded as (n_channels, n_bins, n_frames). - If ``input`` is 3D real tensor, it is regarded as (batch_size, n_channels, n_samples). - If ``input`` is 4D complex tensor, it is regarded as \ (batch_size, n_channels, n_bins, n_frames). Examples: .. code-block:: python >>> import numpy as np >>> from ssspy.transform import pca >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> n_sources = n_channels >>> rng = np.random.default_rng(42) >>> spectrogram_mix = \ ... rng.standard_normal((n_channels, n_bins, n_frames)) \ ... 
+ 1j * rng.standard_normal((n_channels, n_bins, n_frames)) >>> spectrogram_mix_ortho = pca(spectrogram_mix) >>> spectrogram_mix_ortho.shape (2, 2049, 128) """ if input.ndim == 2: if np.iscomplexobj(input): raise ValueError("Real tensor is expected, but given complex tensor.") else: X = input.transpose(1, 0) covariance = np.mean(X[:, :, np.newaxis] * X[:, np.newaxis, :], axis=0) _, V = np.linalg.eigh(covariance) if ascend: V = V[..., ::-1] Y = X @ V output = Y.transpose(1, 0) elif input.ndim == 3: if np.iscomplexobj(input): X = input.transpose(1, 2, 0) covariance = np.mean(X[:, :, :, np.newaxis] * X[:, :, np.newaxis, :].conj(), axis=1) _, V = np.linalg.eigh(covariance) if ascend: V = V[..., ::-1] Y = X @ V.conj() output = Y.transpose(2, 0, 1) else: X = input.transpose(0, 2, 1) covariance = np.mean(X[:, :, :, np.newaxis] * X[:, :, np.newaxis, :], axis=1) _, V = np.linalg.eigh(covariance) if ascend: V = V[..., ::-1] Y = X @ V output = Y.transpose(0, 2, 1) elif input.ndim == 4: if np.iscomplexobj(input): X = input.transpose(0, 2, 3, 1) covariance = np.mean( X[:, :, :, :, np.newaxis] * X[:, :, :, np.newaxis, :].conj(), axis=2 ) _, V = np.linalg.eigh(covariance) if ascend: V = V[..., ::-1] Y = X @ V.conj() output = Y.transpose(0, 3, 1, 2) else: raise ValueError("Complex tensor is expected, but given real tensor.") else: raise ValueError( "The dimension of input is expected 3 or 4, but given {}.".format(input.ndim) ) return output ================================================ FILE: ssspy/transform/whiten.py ================================================ import numpy as np def whiten(input: np.ndarray) -> np.ndarray: r"""Apply whitening (a.k.a sphering). Args: input (numpy.ndarray): Input tensor to be whitened. Returns: numpy.ndarray of Whitened tensor. The type (real or complex) and shape are same as input. .. note:: - If ``input`` is 2D real tensor, it is regarded as (n_channels, n_samples). 
- If ``input`` is 3D complex tensor, it is regarded as (n_channels, n_bins, n_frames). - If ``input`` is 3D real tensor, it is regarded as (batch_size, n_channels, n_samples). - If ``input`` is 4D complex tensor, it is regarded as (batch_size, n_channels, n_bins, n_frames). Examples: .. code-block:: python >>> import numpy as np >>> from ssspy.transform import whiten >>> n_channels, n_bins, n_frames = 2, 2049, 128 >>> n_sources = n_channels >>> rng = np.random.default_rng(42) >>> spectrogram_mix = \ ... rng.standard_normal((n_channels, n_bins, n_frames)) \ ... + 1j * rng.standard_normal((n_channels, n_bins, n_frames)) >>> spectrogram_mix_whitened = whiten(spectrogram_mix) >>> spectrogram_mix_whitened.shape (2, 2049, 128) """ if input.ndim == 2: if np.iscomplexobj(input): raise ValueError("Real tensor is expected, but given complex tensor.") else: n_channels = input.shape[0] X = input.transpose(1, 0) covariance = np.mean(X[:, :, np.newaxis] * X[:, np.newaxis, :], axis=0) W, V = np.linalg.eigh(covariance) D_diag = 1 / np.sqrt(W) D_diag = np.diag(D_diag) V_transpose = V.transpose(1, 0) output = D_diag @ V_transpose @ X.transpose(1, 0) elif input.ndim == 3: if np.iscomplexobj(input): n_channels = input.shape[0] X = input.transpose(1, 2, 0) covariance = np.mean(X[:, :, :, np.newaxis] * X[:, :, np.newaxis, :].conj(), axis=1) W, V = np.linalg.eigh(covariance) D_diag = 1 / np.sqrt(W) D_diag = D_diag[:, :, np.newaxis] D_diag = D_diag * np.eye(n_channels) V_Hermite = V.transpose(0, 2, 1).conj() Y = D_diag @ V_Hermite @ X.transpose(0, 2, 1) output = Y.transpose(1, 0, 2) else: n_channels = input.shape[1] X = input.transpose(0, 2, 1) covariance = np.mean(X[:, :, :, np.newaxis] * X[:, :, np.newaxis, :], axis=1) W, V = np.linalg.eigh(covariance) D_diag = 1 / np.sqrt(W) D_diag = D_diag[:, :, np.newaxis] D_diag = D_diag * np.eye(n_channels) V_transpose = V.transpose(0, 2, 1) output = D_diag @ V_transpose @ X.transpose(0, 2, 1) elif input.ndim == 4: if np.iscomplexobj(input): 
n_channels = input.shape[1] X = input.transpose(0, 2, 3, 1) covariance = np.mean( X[:, :, :, :, np.newaxis] * X[:, :, :, np.newaxis, :].conj(), axis=2 ) W, V = np.linalg.eigh(covariance) D_diag = 1 / np.sqrt(W) D_diag = D_diag[:, :, :, np.newaxis] D_diag = D_diag * np.eye(n_channels) V_Hermite = V.transpose(0, 1, 3, 2).conj() Y = D_diag @ V_Hermite @ X.transpose(0, 1, 3, 2) output = Y.transpose(0, 2, 1, 3) else: raise ValueError("Complex tensor is expected, but given real tensor.") else: raise ValueError( "The dimension of input is expected 2, 3, or 4, but given {}.".format(input.ndim) ) return output ================================================ FILE: ssspy/utils/__init__.py ================================================ ================================================ FILE: ssspy/utils/dataset/__init__.py ================================================ from typing import Tuple import numpy as np from .mird import download as download_mird from .sisec2010 import download as download_sisec2010 __all__ = ["download_sample_speech_data"] sisec2010_tags = ["dev1_female3", "dev1_female4"] def download_sample_speech_data( sisec2010_root: str = ".data/SiSEC2010", mird_root: str = ".data/MIRD", n_sources: int = 3, sisec2010_tag: str = "dev1_female3", max_duration: float = 10, reverb_duration: float = 0.16, conv: bool = True, ) -> Tuple[np.ndarray, int]: r"""Download sample speech data to test sepration methods. This function returns source images of sample speech data. Args: sisec2010_root (str): Path to save SiSEC2010 dataset. Default: ".data/SiSEC2010". mird_root (str): Path to save MIRD dataset. Default: ".data/MIRD". n_sources (int): Number of sources included in sample data. sisec2010_tag (str): Tag of SiSEC 2010 data. Choose ``dev1_female3`` or ``dev1_female4``. Default: ``dev1_female3``. max_duration (float): Maximum duration. Default: ``160000``. reverb_duration (float): Duration of reverberation in MIRD. Choose ``0.16``, ``0.36``, ``0.61``. 
Default: ``0.16``. conv (bool): Convolutive mixture or not. Defalt: ``True``. Returns: Tuple of source images and sampling rate. The source images is numpy.ndarry with shape of (n_channels, n_sources, n_samples). """ assert sisec2010_tag in sisec2010_tags, "Choose sisec2010_tag from {}".format(sisec2010_tags) sample_rate = 16000 # Only 16khz is supported. max_samples = int(sample_rate * max_duration) sisec2010_npz_path = download_sisec2010( root=sisec2010_root, n_sources=n_sources, tag=sisec2010_tag ) sisec2010_npz = np.load(sisec2010_npz_path) assert sample_rate == sisec2010_npz["sample_rate"].item(), "Invalid sampling rate is detected." if conv: mird_npz_path = download_mird( root=mird_root, n_sources=n_sources, reverb_duration=reverb_duration ) mird_npz = np.load(mird_npz_path) assert sample_rate == mird_npz["sample_rate"].item(), "Invalid sampling rate is detected." waveform_src_img = [] for src_idx in range(n_sources): key = "src_{}".format(src_idx + 1) waveform_src = sisec2010_npz[key][:max_samples] n_samples = len(waveform_src) _waveform_src_img = [] for waveform_rir in mird_npz[key]: waveform_conv = np.convolve(waveform_src, waveform_rir)[:n_samples] _waveform_src_img.append(waveform_conv) _waveform_src_img = np.stack(_waveform_src_img, axis=0) # (n_channels, n_samples) waveform_src_img.append(_waveform_src_img) waveform_src_img = np.stack(waveform_src_img, axis=1) # (n_channels, n_sources, n_samples) else: waveform_src_img = [] rng = np.random.default_rng(seed=42) mixing = rng.standard_normal((n_sources, n_sources)) for src_idx in range(n_sources): key = "src_{}".format(src_idx + 1) _mixing = mixing[:, src_idx] waveform_src = sisec2010_npz[key][:max_samples] _waveform_src_img = _mixing[:, np.newaxis] * waveform_src waveform_src_img.append(_waveform_src_img) waveform_src_img = np.stack(waveform_src_img, axis=1) # (n_channels, n_sources, n_samples) return waveform_src_img, sample_rate ================================================ FILE: 
ssspy/utils/dataset/mird.py ================================================ import os import shutil import urllib.request import numpy as np reverb_durations = [0.16, 0.36, 0.61] def download(root: str = ".data/MIRD", n_sources: int = 3, reverb_duration: float = 0.16) -> str: assert reverb_duration in reverb_durations, "reverb_duration should be chosen from {}.".format( reverb_durations ) filename = ( "Impulse_response_Acoustic_Lab_Bar-Ilan_University__" "Reverberation_{reverb_duration:.3f}s__3-3-3-8-3-3-3.zip" ) filename = filename.format(reverb_duration=reverb_duration) url = ( "https://www.iks.rwth-aachen.de/fileadmin/user_upload/downloads/" "forschung/tools-downloads/{filename}" ) url = url.format(filename=filename) zip_path = os.path.join(root, filename) degrees = [30, 345, 0, 60, 315] channels = [3, 4, 2, 5, 1, 6, 0, 7] sample_rate = 16000 duration = reverb_duration degrees = degrees[:n_sources] channels = channels[:n_sources] n_channels = len(channels) n_samples = int(sample_rate * duration) template_rir_name = ( "Impulse_response_Acoustic_Lab_Bar-Ilan_University_" "(Reverberation_{:.3f}s)_3-3-3-8-3-3-3_1m_{:03d}.mat" ) os.makedirs(root, exist_ok=True) if not os.path.exists(zip_path): urllib.request.urlretrieve(url, zip_path) rir_path = os.path.join(root, template_rir_name.format(reverb_duration, 0)) if not os.path.exists(rir_path): shutil.unpack_archive(zip_path, root) npz_path = os.path.join(root, "MIRD-{}ch.npz".format(n_channels)) assert n_channels == n_sources, "Mixing system should be determined." 
if not os.path.exists(npz_path): rirs = {} for src_idx, degree in enumerate(degrees): rir_path = os.path.join(root, template_rir_name.format(duration, degree)) rir = resample_mird_rir(rir_path, sample_rate_out=sample_rate) rirs["src_{}".format(src_idx + 1)] = rir[channels, :n_samples] np.savez( npz_path, sample_rate=sample_rate, n_sources=n_sources, n_channels=n_channels, **rirs ) return npz_path def resample_mird_rir(rir_path: str, sample_rate_out: int) -> np.ndarray: import scipy.signal as ss from scipy.io import loadmat sample_rate_in = 48000 rir_mat = loadmat(rir_path) rir = rir_mat["impulse_response"] rir_resampled = ss.resample_poly(rir, sample_rate_out, sample_rate_in, axis=0) return rir_resampled.T ================================================ FILE: ssspy/utils/dataset/sisec2010.py ================================================ import os import shutil import urllib.request import numpy as np from ...io import wavread def download(root: str = ".data/SiSEC2010", n_sources: int = 3, tag: str = "dev1_female3") -> str: filename = "dev1.zip" url = "http://www.irisa.fr/metiss/SiSEC10/underdetermined/{}".format(filename) zip_path = os.path.join(root, filename) os.makedirs(root, exist_ok=True) if not os.path.exists(zip_path): urllib.request.urlretrieve(url, zip_path) if not os.path.exists(os.path.join(root, "{}_inst_matrix.mat".format(tag))): shutil.unpack_archive(zip_path, root) source_paths = [] for src_idx in range(n_sources): source_path = os.path.join(root, "{}_src_{}.wav".format(tag, src_idx + 1)) source_paths.append(source_path) channels = [3, 4, 2, 5] sample_rate = 16000 source_paths = source_paths[:n_sources] channels = channels[:n_sources] n_channels = len(channels) npz_path = os.path.join(root, "SiSEC2010-{}ch.npz".format(n_channels)) assert n_channels == n_sources, "Mixing system should be determined." 
if not os.path.exists(npz_path): dry_sources = {} for src_idx, source_path in enumerate(source_paths): data, _ = wavread(source_path, return_2d=False) dry_sources["src_{}".format(src_idx + 1)] = data np.savez( npz_path, sample_rate=sample_rate, n_sources=n_sources, n_channels=n_channels, **dry_sources, ) return npz_path ================================================ FILE: ssspy/utils/flooring.py ================================================ from typing import Any, Callable, Optional, Union import numpy as np from ..special.flooring import identity def choose_flooring_fn( flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self", method: Optional[Any] = None, ) -> Callable[[np.ndarray], np.ndarray]: if flooring_fn is None: assert method is None, "method is given, but flooring function is not specified." flooring_fn = identity elif type(flooring_fn) is str and flooring_fn == "self": if method is None or not hasattr(method, "flooring_fn"): flooring_fn = identity else: flooring_fn = method.flooring_fn assert callable(flooring_fn), "flooring_fn should be callable." return flooring_fn ================================================ FILE: ssspy/utils/select_pair.py ================================================ import itertools from typing import Iterable, Optional, Tuple def sequential_pair_selector( n_sources: int, stop: Optional[int] = None, step: int = 1, sort: bool = False ) -> Iterable[Tuple[int, int]]: r"""Select pair in pairwise update. Args: n_sources (int): Number of sources. step (int): This parameter determines step size. For instance, if ``sequential_pair_selector(n_sources=6, step=2, sort=False)``, this function yields ``0, 1``, ``2, 3``, ``4, 5``, ``0, 1``, ``2, 3``, ``4, 5``. Default: ``1``. sort (bool): Sort pair to ensure :math:`m>> for m, n in combination_pair_selector(4): ... 
print(m, n) 0 1 1 2 2 3 3 0 """ if stop is None: stop = n_sources for m in range(0, stop, step): m, n = m % n_sources, (m + 1) % n_sources if sort: m, n = (n, m) if m > n else (m, n) yield m, n def combination_pair_selector(n_sources: int, sort: bool = False) -> Iterable[Tuple[int, int]]: r"""Select pair in pairwise update. Args: n_sources (int): Number of sources. sort (bool): Sort pair to ensure :math:`m>> for m, n in combination_pair_selector(4): ... print(m, n) 0 1 0 2 0 3 1 2 1 3 2 3 """ for m, n in itertools.combinations(range(n_sources), 2): if sort: m, n = (n, m) if m > n else (m, n) yield m, n ================================================ FILE: tests/conftest.py ================================================ # conftest.py is based on # https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option # and # https://docs.pytest.org/en/latest/deprecations.html#pytest-namespace import pytest def pytest_addoption(parser): parser.addoption( "--run-redundant", action="store_true", default=False, help="Run redandant tests." 
) def pytest_configure(): pytest.run_redundant = False def pytest_collection_modifyitems(config, items): if config.getoption("--run-redundant"): pytest.run_redundant = True ================================================ FILE: tests/dummy/callback.py ================================================ def dummy_function(_) -> None: pass class DummyCallback: def __init__(self) -> None: pass def __call__(self, _) -> None: pass ================================================ FILE: tests/dummy/io.py ================================================ import os import struct import numpy as np def save_invalid_wavfile( path: str, invalid_riff: bool = False, invalid_ftype: bool = False, invalid_fmt_chunk_marker: bool = False, invalid_fmt_chunk_size: bool = False, invalid_fmt: bool = False, invalid_byte_rate: bool = False, invalid_data_chunk_marker: bool = False, ) -> None: os.makedirs(os.path.dirname(path), exist_ok=True) n_channels = 1 sample_rate = 16000 bits_per_sample = 16 duration = 5 byte_rate = (bits_per_sample * sample_rate * n_channels) // 8 block_align = byte_rate // sample_rate total_file_size = byte_rate * duration + 44 rng = np.random.default_rng(42) num_frames = sample_rate * duration bytes_per_sample = block_align // n_channels vmax = 2 ** (bits_per_sample - 1) valid_file_size = 0 with open(path, mode="wb") as f: if invalid_riff: data = b"RIFX" else: data = b"RIFF" f.write(data) valid_file_size += 4 data = struct.pack(" Tuple[np.ndarray, int]: hash = hashlib.sha256(sisec2010_root.encode("utf-8")).hexdigest() hash += hashlib.sha256(mird_root.encode("utf-8")).hexdigest() hash += hashlib.sha256(str(n_sources).encode("utf-8")).hexdigest() hash += hashlib.sha256(sisec2010_tag.encode("utf-8")).hexdigest() hash += hashlib.sha256(str(max_duration).encode("utf-8")).hexdigest() hash += hashlib.sha256(str(conv).encode("utf-8")).hexdigest() # because concatenated hash is too long hash = hashlib.sha256(hash.encode("utf-8")).hexdigest() npz_path = os.path.join(cache_dir, 
"{}.npz".format(hash)) if os.path.exists(npz_path): npz = np.load(npz_path) waveform_src_img, sample_rate = npz["waveform_src_img"], npz["sample_rate"] sample_rate = sample_rate.item() else: waveform_src_img, sample_rate = _download( sisec2010_root=sisec2010_root, mird_root=mird_root, n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=conv, ) os.makedirs(cache_dir, exist_ok=True) np.savez(npz_path, waveform_src_img=waveform_src_img, sample_rate=sample_rate) return waveform_src_img, sample_rate def download_ssspy_data(path: str, filename: Optional[str] = None, branch: str = "main") -> None: """Download file from https://github.com/tky823/ssspy-data. Args: path (str): Path to file in https://github.com/tky823/ssspy-data. filename (str, optional): File name to save data. If ``None``, base name of ``path`` is used. branch (str, optional): Branch name of https://github.com/tky823/ssspy-data. """ url = f"https://github.com/tky823/ssspy-data/raw/{branch}/{path}" if filename is None: filename = os.path.basename(url) root = os.path.dirname(filename) if root: os.makedirs(root, exist_ok=True) if not os.path.exists(filename): urllib.request.urlretrieve(url, filename) def load_regression_data(root: str, filenames: Optional[List[str]] = None) -> Tuple: """Load regression data. Args: root (str): Root to save regression data, where url.json is placed. filenames (str, optional): Filenames to download. Returns: tuple: Tuple containing data of specified filenames. 
""" url_json_path = os.path.join(root, "url.json") with open(url_json_path) as f: urls = json.load(f) if filenames is None: warnings.warn("It is recommended to specify filenames to ensure order.", UserWarning) filenames = [] for file in urls["files"]: filename = file["filename"] filenames.append(filename) npz = {} for file in urls["files"]: filename = file["filename"] location = file["location"] if filename not in filenames: continue path = os.path.join(root, filename) download_ssspy_data(location, path) npz[filename] = np.load(path) sorted_npz = [] for filename in filenames: sorted_npz.append(npz[filename]) return tuple(sorted_npz) ================================================ FILE: tests/mock/regression/bss/cacgmm/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/cacgmm/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/fdica/aux_laplace_fdica/IP1/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/fdica/aux_laplace_fdica/IP1/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/fdica/aux_laplace_fdica/IP2/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/fdica/aux_laplace_fdica/IP2/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/fdica/grad_laplace_fdica/holonomic/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/fdica/grad_laplace_fdica/holonomic/target.npz" } ] } 
================================================ FILE: tests/mock/regression/bss/fdica/grad_laplace_fdica/nonholonomic/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/fdica/grad_laplace_fdica/nonholonomic/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/fdica/natural_grad_laplace_fdica/holonomic/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/fdica/natural_grad_laplace_fdica/holonomic/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/fdica/natural_grad_laplace_fdica/nonholonomic/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/fdica/natural_grad_laplace_fdica/nonholonomic/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/IP1/ME/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/IP1/ME/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/IP1/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/IP1/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/IP2/ME/url.json ================================================ { "files": [ { "filename": 
"input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/IP2/ME/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/IP2/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/IP2/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/IPA/ME/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/IPA/ME/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/IPA/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/IPA/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/ISS1/ME/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/ISS1/ME/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/ISS1/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/ISS1/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/ISS2/ME/url.json 
================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/ISS2/ME/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/gauss_ilrma/ISS2/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/gauss_ilrma/ISS2/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/ggd_ilrma/IP1/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/ggd_ilrma/IP1/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/ggd_ilrma/IP2/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/ggd_ilrma/IP2/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/ggd_ilrma/ISS1/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/ggd_ilrma/ISS1/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/ggd_ilrma/ISS2/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/ggd_ilrma/ISS2/MM/target.npz" } ] } ================================================ FILE: 
tests/mock/regression/bss/ilrma/t_ilrma/IP1/ME/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/t_ilrma/IP1/ME/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/t_ilrma/IP1/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/t_ilrma/IP1/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/t_ilrma/IP2/ME/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/t_ilrma/IP2/ME/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/t_ilrma/IP2/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/t_ilrma/IP2/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/t_ilrma/ISS1/ME/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/t_ilrma/ISS1/ME/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/t_ilrma/ISS1/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/t_ilrma/ISS1/MM/target.npz" } ] } ================================================ FILE: 
tests/mock/regression/bss/ilrma/t_ilrma/ISS2/ME/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/t_ilrma/ISS2/ME/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ilrma/t_ilrma/ISS2/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ilrma/t_ilrma/ISS2/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ipsdta/gauss_ipsdta/VCD/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ipsdta/gauss_ipsdta/VCD/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/ipsdta/t_ipsdta/VCD/MM/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/ipsdta/t_ipsdta/VCD/MM/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/iva/aux_iva/IP1/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/aux_iva/IP1/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/iva/aux_iva/IP2/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/aux_iva/IP2/target.npz" } ] } ================================================ FILE: 
tests/mock/regression/bss/iva/aux_iva/IPA/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/aux_iva/IPA/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/iva/aux_iva/ISS1/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/aux_iva/ISS1/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/iva/aux_iva/ISS2/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/aux_iva/ISS2/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/iva/fast_iva/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/fast_iva/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/iva/grad_iva/holonomic/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/grad_iva/holonomic/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/iva/grad_iva/nonholonomic/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/grad_iva/nonholonomic/target.npz" } ] } ================================================ FILE: 
tests/mock/regression/bss/iva/natural_grad_iva/holonomic/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/natural_grad_iva/holonomic/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/iva/natural_grad_iva/nonholonomic/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/iva/natural_grad_iva/nonholonomic/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/mnmf/fast_gauss_mnmf/IP1/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/mnmf/fast_gauss_mnmf/IP1/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/mnmf/fast_gauss_mnmf/IP2/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/mnmf/fast_gauss_mnmf/IP2/target.npz" } ] } ================================================ FILE: tests/mock/regression/bss/mnmf/gauss_mnmf/url.json ================================================ { "files": [ { "filename": "input.npz", "location": "npz/canon_8k_reverbed.npz" }, { "filename": "target.npz", "location": "npz/bss/mnmf/gauss_mnmf/target.npz" } ] } ================================================ FILE: tests/package/algorithm/test_minimal_distortion_principle.py ================================================ from typing import Optional import numpy as np import pytest from ssspy.algorithm import minimal_distortion_principle parameters = [(2, 0), (3, 2), (2, None)] 
@pytest.mark.parametrize("n_sources, reference_id", parameters) def test_minimal_distortion_principle(n_sources: int, reference_id: Optional[int]): rng = np.random.default_rng(0) n_channels = n_sources n_bins, n_frames = 5, 8 spectrogram_mix = rng.standard_normal( (n_channels, n_bins, n_frames) ) + 1j * rng.standard_normal((n_channels, n_bins, n_frames)) demix_filter = rng.standard_normal((n_bins, n_sources, n_channels)) + 1j * rng.standard_normal( (n_bins, n_sources, n_channels) ) spectrogram_est = demix_filter @ spectrogram_mix.transpose(1, 0, 2) spectrogram_est = spectrogram_est.transpose(1, 0, 2) spectrogram_est_scaled = minimal_distortion_principle( spectrogram_est, spectrogram_mix, reference_id=reference_id ) if reference_id is None: for _spectrogram_est_scaled in spectrogram_est_scaled: assert spectrogram_mix.shape == _spectrogram_est_scaled.shape else: assert spectrogram_mix.shape == spectrogram_est.shape ================================================ FILE: tests/package/algorithm/test_permutation_alignment.py ================================================ import numpy as np import pytest from ssspy.algorithm.permutation_alignment import ( correlation_based_permutation_solver, score_based_permutation_solver, ) rng = np.random.default_rng(0) parameters_give_demixing_filter = [True, False] @pytest.mark.parametrize("give_demixing_filter", parameters_give_demixing_filter) def test_correlation_based_permutation_solver(give_demixing_filter: bool): n_sources = 3 n_channels = n_sources n_bins, n_frames = 4, 16 shape = (n_channels, n_bins, n_frames) mixture = rng.standard_normal(shape) + 1j * rng.standard_normal(shape) shape = (n_bins, n_sources, n_channels) demix_filter = rng.standard_normal(shape) + 1j * rng.standard_normal(shape) separated = demix_filter @ mixture.transpose(1, 0, 2) if give_demixing_filter: separated, demix_filter = correlation_based_permutation_solver(separated, demix_filter) assert demix_filter.shape == (n_bins, n_sources, n_channels) else: 
separated = correlation_based_permutation_solver(separated) assert separated.shape == (n_bins, n_sources, n_frames) @pytest.mark.parametrize("give_demixing_filter", parameters_give_demixing_filter) def test_score_based_permutation_solver(give_demixing_filter: bool): n_sources = 3 n_channels = n_sources n_bins, n_frames = 4, 16 shape = (n_channels, n_bins, n_frames) mixture = rng.standard_normal(shape) + 1j * rng.standard_normal(shape) shape = (n_bins, n_sources, n_channels) demix_filter = rng.standard_normal(shape) + 1j * rng.standard_normal(shape) separated = demix_filter @ mixture.transpose(1, 0, 2) if give_demixing_filter: separated, demix_filter = score_based_permutation_solver(separated, demix_filter) assert demix_filter.shape == (n_bins, n_sources, n_channels) else: separated = correlation_based_permutation_solver(separated) assert separated.shape == (n_bins, n_sources, n_frames) ================================================ FILE: tests/package/algorithm/test_projection_back.py ================================================ from typing import Optional import numpy as np import pytest from ssspy.algorithm import projection_back parameters = [(2, 0), (3, 2), (2, None)] @pytest.mark.parametrize("n_sources, reference_id", parameters) def test_projection_back_demix_filter(n_sources: int, reference_id: Optional[int]): np.random.seed(111) n_channels = n_sources n_bins, n_frames = 17, 10 spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) + 1j * np.random.randn( n_channels, n_bins, n_frames ) demix_filter = np.random.randn(n_bins, n_sources, n_channels) + 1j * np.random.randn( n_bins, n_sources, n_channels ) demix_filter_scaled = projection_back(demix_filter, reference_id=reference_id) spectrogram_est = demix_filter_scaled @ spectrogram_mix.transpose(1, 0, 2) if reference_id is None: spectrogram_est = spectrogram_est.transpose(0, 2, 1, 3) for _spectrogram_est in spectrogram_est: assert spectrogram_mix.shape == _spectrogram_est.shape else: 
spectrogram_est = spectrogram_est.transpose(1, 0, 2) assert spectrogram_mix.shape == spectrogram_est.shape @pytest.mark.parametrize("n_sources, reference_id", parameters) def test_projection_back_output(n_sources: int, reference_id: Optional[int]): np.random.seed(111) n_channels = n_sources n_bins, n_frames = 17, 10 spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) + 1j * np.random.randn( n_channels, n_bins, n_frames ) demix_filter = np.random.randn(n_bins, n_sources, n_channels) + 1j * np.random.randn( n_bins, n_sources, n_channels ) spectrogram_est = demix_filter @ spectrogram_mix.transpose(1, 0, 2) spectrogram_est = spectrogram_est.transpose(1, 0, 2) spectrogram_est_scaled = projection_back( spectrogram_est, reference=spectrogram_mix, reference_id=reference_id ) if reference_id is None: for _spectrogram_est_scaled in spectrogram_est_scaled: assert spectrogram_mix.shape == _spectrogram_est_scaled.shape else: assert spectrogram_mix.shape == spectrogram_est.shape ================================================ FILE: tests/package/bss/test_admmbss.py ================================================ import math from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import pytest import scipy.signal as ss from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.admmbss import ADMMBSS, ADMMBSSBase, MaskingADMMBSS max_duration = 0.5 n_fft = 2048 hop_length = 1024 n_bins = n_fft // 2 + 1 n_iter = 5 parameters_admmbss = [ (2, None, {}), ( 3, dummy_function, {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), (2, [DummyCallback(), dummy_function], {}), ( 2, None, { # n_frames=9 "auxiliary1": np.ones((n_bins, 2, 2), dtype=np.complex128), "auxiliary2": np.zeros((1, 2, n_bins, 9), dtype=np.complex128), "dual1": np.ones((n_bins, 2, 2), dtype=np.complex128), "dual2": np.zeros((1, 2, n_bins, 9), dtype=np.complex128), }, ), ] def 
contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray of the shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def penalty_fn(y: np.ndarray) -> float: loss = contrast_fn(y) loss = np.sum(loss.mean(axis=-1)) return loss def prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray: r"""Proximal operator of penalty function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). step_size (float): Step size. Default: 1. Returns: np.ndarray of the shape is (n_sources, n_bins, n_frames). """ norm = np.linalg.norm(y, axis=1, keepdims=True) # to suppress warning RuntimeWarning norm = np.where(norm < step_size, step_size, norm) return y * np.maximum(1 - step_size / norm, 0) def test_admmbss_base(): admmbss = ADMMBSSBase(penalty_fn=penalty_fn, prox_penalty=prox_penalty) print(admmbss) @pytest.mark.parametrize("n_sources, callbacks, reset_kwargs", parameters_admmbss) def test_admmbss( n_sources: int, callbacks: Optional[Union[Callable[[ADMMBSS], None], List[Callable[[ADMMBSS], None]]]], reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( sisec2010_root="./tests/.data/SiSEC2010", mird_root="./tests/.data/MIRD", n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) admmbss = ADMMBSS(penalty_fn=penalty_fn, prox_penalty=prox_penalty, callbacks=callbacks) spectrogram_mix_normalized = admmbss.normalize_by_spectral_norm(spectrogram_mix) spectrogram_est = admmbss(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape print(admmbss) @pytest.mark.parametrize("n_sources, callbacks, reset_kwargs", parameters_admmbss) def 
test_masking_admmbss( n_sources: int, callbacks: Optional[Union[Callable[[ADMMBSS], None], List[Callable[[ADMMBSS], None]]]], reset_kwargs: Dict[Any, Any], ) -> None: np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( sisec2010_root="./tests/.data/SiSEC2010", mird_root="./tests/.data/MIRD", n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) def hva_mask_fn(y: np.ndarray, mask_iter: int = 2) -> np.ndarray: """Masking function to emphasize harmonic components. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray of mask. The shape is (n_sources, n_bins, n_frames). """ n_sources, n_bins, _ = y.shape gamma = 1 / n_sources y = np.maximum(np.abs(y), 1e-10) zeta = np.log(y) zeta_mean = zeta.mean(axis=1, keepdims=True) rho = zeta - zeta_mean nu = np.fft.irfft(rho, axis=1, norm="backward") nu = nu[:, :n_bins] varsigma = np.minimum(1, nu) for _ in range(mask_iter): varsigma = (1 - np.cos(math.pi * varsigma)) / 2 xi = np.fft.irfft(varsigma * nu, axis=1, norm="forward") xi = xi[:, :n_bins] varrho = xi + zeta_mean v = np.exp(2 * varrho) mask = (v / v.sum(axis=0)) ** gamma return mask admmbss = MaskingADMMBSS(mask_fn=hva_mask_fn, callbacks=callbacks) spectrogram_mix_normalized = admmbss.normalize_by_spectral_norm(spectrogram_mix) if "auxiliary2" in reset_kwargs: auxiliary2 = reset_kwargs.pop("auxiliary2") if auxiliary2.ndim == 4: auxiliary2 = auxiliary2.squeeze(axis=0) reset_kwargs["auxiliary2"] = auxiliary2 if "dual2" in reset_kwargs: dual2 = reset_kwargs.pop("dual2") if dual2.ndim == 4: dual2 = dual2.squeeze(axis=0) reset_kwargs["dual2"] = dual2 spectrogram_est = admmbss(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape print(admmbss) 
================================================ FILE: tests/package/bss/test_base.py ================================================ from typing import Callable, List, Optional, Union import pytest from dummy.callback import DummyCallback, dummy_function from ssspy.bss.base import IterativeMethodBase n_iter = 3 parameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]] parameters_record_loss = [True, False] @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("record_loss", parameters_record_loss) def test_iterative_method_base( callbacks: Optional[ Union[Callable[[IterativeMethodBase], None], List[Callable[[IterativeMethodBase], None]]] ], record_loss: bool, ): method = IterativeMethodBase(callbacks=callbacks, record_loss=record_loss) with pytest.raises(NotImplementedError) as exc_info: method(n_iter=n_iter) assert exc_info.type is NotImplementedError ================================================ FILE: tests/package/bss/test_cacgmm.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import pytest import scipy.signal as ss from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.cacgmm import CACGMM max_duration = 0.5 window = "hann" n_fft = 512 hop_length = 256 n_bins = n_fft // 2 + 1 n_iter = 3 rng = np.random.default_rng(42) parameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]] parameters_permutation_alignment = [ "posterior_score", "amplitude_score", "amplitude_correlation", ] parameters_cacgmm = [(2, 2, {}), (3, 2, {})] @pytest.mark.parametrize("n_sources, n_channels, reset_kwargs", parameters_cacgmm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("permutation_alignment", parameters_permutation_alignment) def test_cacgmm( n_sources: int, n_channels: int, callbacks: 
Optional[Union[Callable[[CACGMM], None], List[Callable[[CACGMM], None]]]], permutation_alignment: bool, reset_kwargs: Dict[str, Any], ): if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) waveform_mix = waveform_mix[:n_channels] _, _, spectrogram_mix = ss.stft( waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length ) cacgmm = CACGMM( n_sources=n_sources, callbacks=callbacks, permutation_alignment=permutation_alignment, rng=rng, ) spectrogram_est = cacgmm(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_est.shape == (n_sources,) + spectrogram_mix.shape[-2:] assert type(cacgmm.loss[-1]) is float # when posterior is not given _spectrogram_est = cacgmm.separate(spectrogram_mix) assert np.allclose(_spectrogram_est, spectrogram_est) print(cacgmm) def test_cacgmm_zero_norm() -> None: """Test input with zero norm.""" n_channels, n_sources, n_samples = 2, 3, 10 * 8000 waveform_src_img = rng.standard_normal((n_channels, n_sources, n_samples)) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) waveform_mix = waveform_mix[:n_channels] _, _, spectrogram_mix = ss.stft( waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length ) # set 0 at most grids in 0th frequency bin spectrogram_mix[:, 0, 1:-1] = 0 assert np.linalg.norm(spectrogram_mix, axis=0).any() cacgmm = CACGMM(n_sources=n_sources, rng=rng) spectrogram_est = cacgmm(spectrogram_mix, n_iter=n_iter) assert spectrogram_est.shape == (n_sources,) + spectrogram_mix.shape[-2:] assert type(cacgmm.loss[-1]) is float # when posterior is not given _spectrogram_est = cacgmm.separate(spectrogram_mix) assert np.allclose(_spectrogram_est, 
spectrogram_est) ================================================ FILE: tests/package/bss/test_fdica.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import pytest import scipy.signal as ss from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.fdica import ( AuxFDICA, AuxLaplaceFDICA, GradFDICA, GradFDICABase, GradLaplaceFDICA, NaturalGradFDICA, NaturalGradLaplaceFDICA, ) max_duration = 0.5 n_fft = 512 hop_length = 256 n_bins = n_fft // 2 + 1 n_iter = 3 parameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]] parameters_is_holonomic = [True, False] parameters_scale_restoration = [True, False, "projection_back", "minimal_distortion_principle"] parameters_spatial_algorithm = ["IP", "IP1", "IP2"] parameters_grad_fdica = [ (2, {}), ( 3, {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), ] parameters_aux_fdica = [ (2, {}), ( 3, {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), ] @pytest.mark.parametrize("callbacks", parameters_callbacks) def test_grad_fdica_base( callbacks: Optional[Union[Callable[[GradFDICA], None], List[Callable[[GradFDICA], None]]]], ): np.random.seed(111) def contrast_fn(y): return 2 * np.abs(y) def score_fn(y): denominator = np.maximum(np.abs(y), 1e-10) return y / denominator fdica = GradFDICABase(contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks) print(fdica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_fdica) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_fdica( n_sources: int, callbacks: Optional[Union[Callable[[GradFDICA], None], List[Callable[[GradFDICA], None]]]], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = 
download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) def contrast_fn(y): return 2 * np.abs(y) def score_fn(y): denominator = np.maximum(np.abs(y), 1e-10) return y / denominator fdica = GradFDICA( contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic ) spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(fdica.loss[-1]) is float print(fdica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_fdica) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_natural_grad_fdica( n_sources: int, callbacks: Optional[ Union[Callable[[NaturalGradFDICA], None], List[Callable[[NaturalGradFDICA], None]]] ], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) def contrast_fn(y): return 2 * np.abs(y) def score_fn(y): denominator = np.maximum(np.abs(y), 1e-10) return y / denominator fdica = NaturalGradFDICA( contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic ) spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(fdica.loss[-1]) is float print(fdica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_aux_fdica) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) 
@pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_aux_fdica( n_sources: int, spatial_algorithm: str, callbacks: Optional[Union[Callable[[AuxFDICA], None], List[Callable[[AuxFDICA], None]]]], scale_restoration: Union[str, bool], reset_kwargs: Dict[Any, Any], ): if spatial_algorithm in ["IP"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) def contrast_fn(y): return 2 * np.abs(y) def d_contrast_fn(y): return 2 * np.ones_like(y) fdica = AuxFDICA( spatial_algorithm=spatial_algorithm, contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, callbacks=callbacks, scale_restoration=scale_restoration, ) spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(fdica.loss[-1]) is float print(fdica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_fdica) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_laplace_fdica( n_sources: int, callbacks: Optional[ Union[Callable[[GradLaplaceFDICA], None], List[Callable[[GradLaplaceFDICA], None]]] ], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) fdica = 
GradLaplaceFDICA(callbacks=callbacks, is_holonomic=is_holonomic) spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(fdica.loss[-1]) is float print(fdica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_fdica) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_natural_grad_laplace_fdica( n_sources: int, callbacks: Optional[ Union[ Callable[[NaturalGradLaplaceFDICA], None], List[Callable[[NaturalGradLaplaceFDICA], None]], ] ], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) fdica = NaturalGradLaplaceFDICA(callbacks=callbacks, is_holonomic=is_holonomic) spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(fdica.loss[-1]) is float print(fdica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_aux_fdica) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_aux_laplace_fdica( n_sources: int, spatial_algorithm: str, callbacks: Optional[ Union[Callable[[AuxLaplaceFDICA], None], List[Callable[[AuxLaplaceFDICA], None]]] ], scale_restoration: Union[str, bool], reset_kwargs: Dict[Any, Any], ): if spatial_algorithm in ["IP"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, 
sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) fdica = AuxLaplaceFDICA( spatial_algorithm=spatial_algorithm, callbacks=callbacks, scale_restoration=scale_restoration, ) spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(fdica.loss[-1]) is float print(fdica) ================================================ FILE: tests/package/bss/test_hva.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import pytest import scipy.signal as ss from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.hva import HVA, MaskingADMMHVA, MaskingPDSHVA max_duration = 0.5 n_fft = 2048 hop_length = 1024 n_bins = n_fft // 2 + 1 n_iter = 5 parameters_hva = [ (2, None, {}), ( 3, dummy_function, {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), (2, [DummyCallback(), dummy_function], {}), ] @pytest.mark.parametrize("n_sources, callbacks, reset_kwargs", parameters_hva) def test_masking_pdshva( n_sources: int, callbacks: Optional[ Union[Callable[[MaskingPDSHVA], None], List[Callable[[MaskingPDSHVA], None]]] ], reset_kwargs: Dict[Any, Any], ): waveform_src_img, _ = download_sample_speech_data( sisec2010_root="./tests/.data/SiSEC2010", mird_root="./tests/.data/MIRD", n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) hva = MaskingPDSHVA(callbacks=callbacks) spectrogram_mix_normalized = 
hva.normalize_by_spectral_norm(spectrogram_mix) spectrogram_est = hva(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape print(hva) @pytest.mark.parametrize("n_sources, callbacks, reset_kwargs", parameters_hva) def test_masking_admmhva( n_sources: int, callbacks: Optional[ Union[Callable[[MaskingADMMHVA], None], List[Callable[[MaskingADMMHVA], None]]] ], reset_kwargs: Dict[Any, Any], ): waveform_src_img, _ = download_sample_speech_data( sisec2010_root="./tests/.data/SiSEC2010", mird_root="./tests/.data/MIRD", n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) hva = MaskingADMMHVA(callbacks=callbacks) spectrogram_mix_normalized = hva.normalize_by_spectral_norm(spectrogram_mix) spectrogram_est = hva(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape print(hva) @pytest.mark.parametrize("n_sources, callbacks, reset_kwargs", parameters_hva) def test_hva( n_sources: int, callbacks: Optional[Union[Callable[[HVA], None], List[Callable[[HVA], None]]]], reset_kwargs: Dict[Any, Any], ): waveform_src_img, _ = download_sample_speech_data( sisec2010_root="./tests/.data/SiSEC2010", mird_root="./tests/.data/MIRD", n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) hva = HVA(callbacks=callbacks) spectrogram_mix_normalized = hva.normalize_by_spectral_norm(spectrogram_mix) spectrogram_est = hva(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape print(hva) 
================================================ FILE: tests/package/bss/test_ica.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import pytest from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.ica import ( FastICA, GradICA, GradICABase, GradLaplaceICA, NaturalGradICA, NaturalGradLaplaceICA, ) max_duration = 0.5 n_iter = 3 parameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]] parameters_is_holonomic = [True, False] parameters_grad_ica = [ (2, {}), (3, {"demix_filter": -np.eye(3)}), ] parameters_fast_ica = [ (2, {}), (3, {"demix_filter": -np.eye(3)}), ] @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_ica) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_ica_base( n_sources: int, callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): def contrast_fn(x): return np.log(1 + np.exp(x)) def score_fn(x): return 1 / (1 + np.exp(-x)) ica = GradICABase(contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks) print(ica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_ica) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_ica( n_sources: int, callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=False, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) def contrast_fn(x): return np.log(1 + np.exp(x)) def score_fn(x): return 1 / (1 + 
np.exp(-x)) ica = GradICA( contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic ) waveform_est = ica(waveform_mix, n_iter=n_iter) assert waveform_mix.shape == waveform_est.shape assert type(ica.loss[-1]) is float print(ica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_ica) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_natural_grad_ica( n_sources: int, callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=False, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) def contrast_fn(x): return np.log(1 + np.exp(x)) def score_fn(x): return 1 / (1 + np.exp(-x)) ica = NaturalGradICA( contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic ) waveform_est = ica(waveform_mix, n_iter=n_iter) assert waveform_mix.shape == waveform_est.shape assert type(ica.loss[-1]) is float print(ica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_ica) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_laplace_ica( n_sources: int, callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=False, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) ica = GradLaplaceICA(callbacks=callbacks, is_holonomic=is_holonomic) waveform_est = ica(waveform_mix, n_iter=n_iter) assert waveform_mix.shape == waveform_est.shape assert 
type(ica.loss[-1]) is float print(ica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_ica) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_natural_grad_laplace_ica( n_sources: int, callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=False, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) ica = NaturalGradLaplaceICA(callbacks=callbacks, is_holonomic=is_holonomic) waveform_est = ica(waveform_mix, n_iter=n_iter) assert waveform_mix.shape == waveform_est.shape assert type(ica.loss[-1]) is float print(ica) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_fast_ica) @pytest.mark.parametrize("callbacks", parameters_callbacks) def test_fast_ica( n_sources: int, callbacks: Optional[Union[Callable[[FastICA], None], List[Callable[[FastICA], None]]]], reset_kwargs: Dict[Any, Any], ): waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=False, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) def contrast_fn(x): return np.log(1 + np.exp(x)) def score_fn(x): return 1 / (1 + np.exp(-x)) def d_score_fn(x): s = 1 / (1 + np.exp(-x)) return s * (1 - s) ica = FastICA( contrast_fn=contrast_fn, score_fn=score_fn, d_score_fn=d_score_fn, callbacks=callbacks ) waveform_est = ica(waveform_mix, n_iter=n_iter) assert waveform_mix.shape == waveform_est.shape assert type(ica.loss[-1]) is float print(ica) ================================================ FILE: tests/package/bss/test_ilrma.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import 
pytest import scipy.signal as ss from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.ilrma import GGDILRMA, TILRMA, GaussILRMA, ILRMABase max_duration = 0.5 n_fft = 512 hop_length = 256 n_bins = n_fft // 2 + 1 n_iter = 3 rng = np.random.default_rng(42) parameters_dof = [100] parameters_beta = [0.5, 1.5] parameters_spatial_algorithm = ["IP", "IP1", "IP2", "ISS", "ISS1", "ISS2", "IPA"] parameters_source_algorithm = ["MM", "ME"] parameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]] parameters_scale_restoration = [True, False, "projection_back", "minimal_distortion_principle"] parameters_ilrma_base = [2] parameters_ilrma_latent = [ ( 2, 4, 2, { "demix_filter": np.tile(np.eye(2, dtype=np.complex128), (n_bins, 1, 1)), "latent": rng.random((2, 4)), "basis": rng.random((n_bins, 4)), }, ), (3, 3, 1, {}), ] parameters_ilrma_wo_latent = [ ( 2, 2, 2, { "demix_filter": np.tile(np.eye(2, dtype=np.complex128), (n_bins, 1, 1)), "basis": rng.random((2, n_bins, 2)), }, ), ( 3, 1, 1, {}, ), ] parameters_normalization_latent = [True, False, "power"] parameters_normalization_wo_latent = [True, False, "power", "projection_back"] @pytest.mark.parametrize( "n_basis", parameters_ilrma_base, ) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_ilrma_base( n_basis: int, callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]], scale_restoration: Union[str, bool], ): ilrma = ILRMABase( n_basis, partitioning=True, callbacks=callbacks, scale_restoration=scale_restoration, rng=np.random.default_rng(42), ) print(ilrma) @pytest.mark.parametrize( "n_sources, n_basis, domain, reset_kwargs", parameters_ilrma_latent, ) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) 
@pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("normalization", parameters_normalization_latent) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_gauss_ilrma_latent( n_sources: int, n_basis: int, spatial_algorithm: str, source_algorithm: str, domain: float, callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]], normalization: Optional[Union[str, bool]], scale_restoration: Union[str, bool], reset_kwargs: Dict[str, Any], ): if spatial_algorithm in ["IP", "ISS"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) kwargs = { "spatial_algorithm": spatial_algorithm, "source_algorithm": source_algorithm, "domain": domain, "partitioning": True, "callbacks": callbacks, "normalization": normalization, "scale_restoration": scale_restoration, "rng": np.random.default_rng(42), } if source_algorithm == "ME" and domain != 2: with pytest.raises(AssertionError) as e: ilrma = GaussILRMA(n_basis, **kwargs) assert str(e.value) == "domain parameter should be 2 when you specify ME algorithm." 
else: ilrma = GaussILRMA(n_basis, **kwargs) spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape assert type(ilrma.loss[-1]) is float if spatial_algorithm in ["ISS", "ISS1", "ISS2"]: assert ilrma.demix_filter is None print(ilrma) @pytest.mark.parametrize( "n_sources, n_basis, domain, reset_kwargs", parameters_ilrma_wo_latent, ) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("normalization", parameters_normalization_wo_latent) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_gauss_ilrma_wo_latent( n_sources: int, n_basis: int, spatial_algorithm: str, source_algorithm: str, domain: float, callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]], normalization: Optional[Union[str, bool]], scale_restoration: Union[str, bool], reset_kwargs: Dict[str, Any], ): if spatial_algorithm in ["IP", "ISS"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) kwargs = { "spatial_algorithm": spatial_algorithm, "source_algorithm": source_algorithm, "domain": domain, "partitioning": False, "callbacks": callbacks, "normalization": normalization, "scale_restoration": scale_restoration, "rng": np.random.default_rng(42), } if source_algorithm == "ME" and 
domain != 2: with pytest.raises(AssertionError) as e: ilrma = GaussILRMA(n_basis, **kwargs) assert str(e.value) == "domain parameter should be 2 when you specify ME algorithm." else: ilrma = GaussILRMA(n_basis, **kwargs) spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape assert type(ilrma.loss[-1]) is float if spatial_algorithm in ["ISS", "ISS1", "ISS2"]: assert ilrma.demix_filter is None print(ilrma) @pytest.mark.parametrize( "n_sources, n_basis, domain, reset_kwargs", parameters_ilrma_latent, ) @pytest.mark.parametrize("dof", parameters_dof) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("normalization", parameters_normalization_latent) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_t_ilrma_latent( n_sources: int, n_basis: int, dof: float, spatial_algorithm: str, source_algorithm: str, domain: float, callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]], normalization: Optional[Union[str, bool]], scale_restoration: Union[str, bool], reset_kwargs: Dict[str, Any], ): if spatial_algorithm in ["IP", "ISS"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) kwargs = { "dof": dof, "spatial_algorithm": spatial_algorithm, 
"source_algorithm": source_algorithm, "domain": domain, "partitioning": True, "callbacks": callbacks, "normalization": normalization, "scale_restoration": scale_restoration, "rng": np.random.default_rng(42), } if spatial_algorithm == "IPA": with pytest.raises(ValueError) as e: ilrma = TILRMA(n_basis, **kwargs) assert str(e.value) == "IPA is not supported for t-ILRMA." elif source_algorithm == "ME" and domain != 2: with pytest.raises(AssertionError) as e: ilrma = TILRMA(n_basis, **kwargs) assert str(e.value) == "domain parameter should be 2 when you specify ME algorithm." else: ilrma = TILRMA(n_basis, **kwargs) spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape assert type(ilrma.loss[-1]) is float if spatial_algorithm in ["ISS", "ISS1", "ISS2"]: assert ilrma.demix_filter is None print(ilrma) @pytest.mark.parametrize( "n_sources, n_basis, domain, reset_kwargs", parameters_ilrma_wo_latent, ) @pytest.mark.parametrize("dof", parameters_dof) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("normalization", parameters_normalization_wo_latent) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_t_ilrma_wo_latent( n_sources: int, n_basis: int, dof: float, spatial_algorithm: str, source_algorithm: str, domain: float, callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]], normalization: Optional[Union[str, bool]], scale_restoration: Union[str, bool], reset_kwargs: Dict[str, Any], ): if spatial_algorithm in ["IP", "ISS"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be 
less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) kwargs = { "dof": dof, "spatial_algorithm": spatial_algorithm, "source_algorithm": source_algorithm, "domain": domain, "partitioning": False, "callbacks": callbacks, "normalization": normalization, "scale_restoration": scale_restoration, "rng": np.random.default_rng(42), } if spatial_algorithm == "IPA": with pytest.raises(ValueError) as e: ilrma = TILRMA(n_basis, **kwargs) assert str(e.value) == "IPA is not supported for t-ILRMA." elif source_algorithm == "ME" and domain != 2: with pytest.raises(AssertionError) as e: ilrma = TILRMA(n_basis, **kwargs) assert str(e.value) == "domain parameter should be 2 when you specify ME algorithm." else: ilrma = TILRMA(n_basis, **kwargs) spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape assert type(ilrma.loss[-1]) is float if spatial_algorithm in ["ISS", "ISS1", "ISS2"]: assert ilrma.demix_filter is None print(ilrma) @pytest.mark.parametrize( "n_sources, n_basis, domain, reset_kwargs", parameters_ilrma_latent, ) @pytest.mark.parametrize("beta", parameters_beta) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("normalization", parameters_normalization_latent) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_ggd_ilrma_latent( n_sources: int, n_basis: int, beta: float, spatial_algorithm: str, source_algorithm: str, domain: float, callbacks: Optional[Union[Callable[[GaussILRMA], None], 
List[Callable[[GaussILRMA], None]]]], normalization: Optional[Union[str, bool]], scale_restoration: Union[str, bool], reset_kwargs: Dict[str, Any], ): if spatial_algorithm in ["IP", "ISS"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) kwargs = { "beta": beta, "spatial_algorithm": spatial_algorithm, "source_algorithm": source_algorithm, "domain": domain, "partitioning": True, "callbacks": callbacks, "normalization": normalization, "scale_restoration": scale_restoration, "rng": np.random.default_rng(42), } if source_algorithm == "ME": with pytest.raises(AssertionError) as e: ilrma = GGDILRMA(n_basis, **kwargs) assert str(e.value) == "Not support {}.".format(source_algorithm) elif spatial_algorithm == "IPA": with pytest.raises(ValueError) as e: ilrma = GGDILRMA(n_basis, **kwargs) assert str(e.value) == "IPA is not supported for GGD-ILRMA." 
else: ilrma = GGDILRMA(n_basis, **kwargs) spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape assert type(ilrma.loss[-1]) is float if spatial_algorithm in ["ISS", "ISS1", "ISS2"]: assert ilrma.demix_filter is None print(ilrma) @pytest.mark.parametrize( "n_sources, n_basis, domain, reset_kwargs", parameters_ilrma_wo_latent, ) @pytest.mark.parametrize("beta", parameters_beta) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("normalization", parameters_normalization_wo_latent) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_ggd_ilrma_wo_latent( n_sources: int, n_basis: int, beta: float, spatial_algorithm: str, source_algorithm: str, domain: float, callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]], normalization: Optional[Union[str, bool]], scale_restoration: Union[str, bool], reset_kwargs: Dict[str, Any], ): if spatial_algorithm in ["IP", "ISS"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) kwargs = { "beta": beta, "spatial_algorithm": spatial_algorithm, "source_algorithm": source_algorithm, "domain": domain, "partitioning": False, "callbacks": callbacks, "normalization": normalization, "scale_restoration": 
scale_restoration, "rng": np.random.default_rng(42), } if source_algorithm == "ME": with pytest.raises(AssertionError) as e: ilrma = GGDILRMA(n_basis, **kwargs) assert str(e.value) == "Not support {}.".format(source_algorithm) elif spatial_algorithm == "IPA": with pytest.raises(ValueError) as e: ilrma = GGDILRMA(n_basis, **kwargs) assert str(e.value) == "IPA is not supported for GGD-ILRMA." else: ilrma = GGDILRMA(n_basis, **kwargs) spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape assert type(ilrma.loss[-1]) is float if spatial_algorithm in ["ISS", "ISS1", "ISS2"]: assert ilrma.demix_filter is None print(ilrma) ================================================ FILE: tests/package/bss/test_ipsdta.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import pytest import scipy.signal as ss from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.ipsdta import TIPSDTA, BlockDecompositionIPSDTABase, GaussIPSDTA, IPSDTABase max_duration = 0.1 n_fft = 256 hop_length = 128 window = "hann" n_bins = n_fft // 2 + 1 n_iter = 3 rng = np.random.default_rng(42) parameters_dof = [100] parameters_spatial_algorithm = ["FPI", "VCD"] parameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]] parameters_source_normalization = [True, False] parameters_scale_restoration = [True, False, "projection_back", "minimal_distortion_principle"] parameters_ipsdta_base = [2] parameters_block_decomposition_ipsdta_base = [4] parameters_ipsdta = [ ( 2, 2, 43, { "demix_filter": np.tile(np.eye(2, dtype=np.complex128), (n_bins, 1, 1)), }, ), (3, 2, 64, {}), ] @pytest.mark.parametrize( "n_basis", parameters_ipsdta_base, ) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) 
def test_ipsdta_base( n_basis: int, callbacks: Optional[Union[Callable[[IPSDTABase], None], List[Callable[[IPSDTABase], None]]]], scale_restoration: Union[str, bool], ): ipsdta = IPSDTABase( n_basis, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=False, rng=rng, ) print(ipsdta) @pytest.mark.parametrize( "n_basis", parameters_ipsdta_base, ) @pytest.mark.parametrize( "n_blocks", parameters_block_decomposition_ipsdta_base, ) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_block_decomposition_ipsdta_base( n_basis: int, n_blocks: int, callbacks: Optional[ Union[ Callable[[BlockDecompositionIPSDTABase], None], List[Callable[[BlockDecompositionIPSDTABase], None]], ] ], scale_restoration: Union[str, bool], ): ipsdta = BlockDecompositionIPSDTABase( n_basis, n_blocks, callbacks=callbacks, scale_restoration=scale_restoration, record_loss=False, rng=rng, ) print(ipsdta) @pytest.mark.parametrize( "n_sources, n_basis, n_blocks, reset_kwargs", parameters_ipsdta, ) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("source_normalization", parameters_source_normalization) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_gauss_ipsdta( n_sources: int, n_basis: int, n_blocks: int, spatial_algorithm: str, callbacks: Optional[Union[Callable[[GaussIPSDTA], None], List[Callable[[GaussIPSDTA], None]]]], source_normalization: Optional[Union[str, bool]], scale_restoration: Union[str, bool], reset_kwargs: Dict[str, Any], ): if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix 
= np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length ) ipsdta = GaussIPSDTA( n_basis, n_blocks, spatial_algorithm=spatial_algorithm, callbacks=callbacks, source_normalization=source_normalization, scale_restoration=scale_restoration, rng=rng, ) if spatial_algorithm == "FPI": with pytest.raises(NotImplementedError) as e: spectrogram_est = ipsdta(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert str(e.value) == "IPSDTA with fixed-point iteration is not supported." else: spectrogram_est = ipsdta(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape assert type(ipsdta.loss[-1]) is float print(ipsdta) @pytest.mark.parametrize( "n_sources, n_basis, n_blocks, reset_kwargs", parameters_ipsdta, ) @pytest.mark.parametrize("dof", parameters_dof) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("source_normalization", parameters_source_normalization) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_t_ipsdta( n_sources: int, n_basis: int, n_blocks: int, dof: float, spatial_algorithm: str, callbacks: Optional[Union[Callable[[GaussIPSDTA], None], List[Callable[[GaussIPSDTA], None]]]], source_normalization: Optional[Union[str, bool]], scale_restoration: Union[str, bool], reset_kwargs: Dict[str, Any], ): if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - 
hop_length ) ipsdta = TIPSDTA( n_basis, n_blocks, dof=dof, spatial_algorithm=spatial_algorithm, callbacks=callbacks, source_normalization=source_normalization, scale_restoration=scale_restoration, rng=rng, ) if spatial_algorithm != "VCD": with pytest.raises(NotImplementedError) as e: spectrogram_est = ipsdta(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert str(e.value) == "Not support {}.".format(spatial_algorithm) else: spectrogram_est = ipsdta(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape assert type(ipsdta.loss[-1]) is float print(ipsdta) ================================================ FILE: tests/package/bss/test_iterative_methods.py ================================================ import numpy as np from ssspy.bss.base import IterativeMethodBase from ssspy.bss.cacgmm import CACGMM from ssspy.bss.fdica import ( AuxFDICA, AuxLaplaceFDICA, GradFDICA, GradLaplaceFDICA, NaturalGradFDICA, NaturalGradLaplaceFDICA, ) from ssspy.bss.ica import FastICA, GradICA, GradLaplaceICA, NaturalGradICA, NaturalGradLaplaceICA from ssspy.bss.ilrma import GGDILRMA, TILRMA, GaussILRMA from ssspy.bss.ipsdta import TIPSDTA, GaussIPSDTA from ssspy.bss.iva import ( PDSIVA, AuxGaussIVA, AuxIVA, AuxLaplaceIVA, FasterIVA, FastIVA, GradGaussIVA, GradIVA, GradLaplaceIVA, NaturalGradGaussIVA, NaturalGradIVA, NaturalGradLaplaceIVA, ) from ssspy.bss.mnmf import FastGaussMNMF, GaussMNMF from ssspy.bss.pdsbss import PDSBSS def test_grad_ica_inheritance() -> None: def contrast_fn(x): return np.log(1 + np.exp(x)) def score_fn(x): return 1 / (1 + np.exp(-x)) ica = GradICA(contrast_fn=contrast_fn, score_fn=score_fn) assert isinstance(ica, IterativeMethodBase) ica = GradLaplaceICA() assert isinstance(ica, IterativeMethodBase) def test_natural_grad_ica_inheritance() -> None: def contrast_fn(x): return np.log(1 + np.exp(x)) def score_fn(x): return 1 / (1 + np.exp(-x)) ica = NaturalGradICA(contrast_fn=contrast_fn, score_fn=score_fn) assert 
isinstance(ica, IterativeMethodBase) ica = NaturalGradLaplaceICA() assert isinstance(ica, IterativeMethodBase) def test_fast_ica_inheritance() -> None: def contrast_fn(x): return np.log(1 + np.exp(x)) def score_fn(x): return 1 / (1 + np.exp(-x)) def d_score_fn(x): s = 1 / (1 + np.exp(-x)) return s * (1 - s) ica = FastICA(contrast_fn=contrast_fn, score_fn=score_fn, d_score_fn=d_score_fn) assert isinstance(ica, IterativeMethodBase) def test_grad_fdica_inheritance() -> None: def contrast_fn(y): return 2 * np.abs(y) def score_fn(y): denominator = np.maximum(np.abs(y), 1e-10) return y / denominator fdica = GradFDICA(contrast_fn=contrast_fn, score_fn=score_fn) assert isinstance(fdica, IterativeMethodBase) fdica = GradLaplaceFDICA() assert isinstance(fdica, IterativeMethodBase) def test_natural_grad_fdica_inheritance() -> None: def contrast_fn(y): return 2 * np.abs(y) def score_fn(y): denominator = np.maximum(np.abs(y), 1e-10) return y / denominator fdica = NaturalGradFDICA(contrast_fn=contrast_fn, score_fn=score_fn) assert isinstance(fdica, IterativeMethodBase) fdica = NaturalGradLaplaceFDICA() assert isinstance(fdica, IterativeMethodBase) def test_aux_fdica_inheritance() -> None: def contrast_fn(y): return 2 * np.abs(y) def d_contrast_fn(y): return 2 * np.ones_like(y) fdica = AuxFDICA( contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, ) assert isinstance(fdica, IterativeMethodBase) fdica = AuxLaplaceFDICA() assert isinstance(fdica, IterativeMethodBase) def test_grad_iva_inheritance() -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def score_fn(y) -> np.ndarray: r"""Score function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_bins, n_frames). 
""" norm = np.linalg.norm(y, axis=1, keepdims=True) norm = np.maximum(norm, 1e-10) return y / norm iva = GradIVA(contrast_fn=contrast_fn, score_fn=score_fn) assert isinstance(iva, IterativeMethodBase) iva = GradLaplaceIVA() assert isinstance(iva, IterativeMethodBase) iva = GradGaussIVA() assert isinstance(iva, IterativeMethodBase) def test_natural_grad_iva_inheritance() -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def score_fn(y): r"""Score function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_bins, n_frames). """ norm = np.linalg.norm(y, axis=1, keepdims=True) norm = np.maximum(norm, 1e-10) return y / norm iva = NaturalGradIVA(contrast_fn=contrast_fn, score_fn=score_fn) assert isinstance(iva, IterativeMethodBase) iva = NaturalGradLaplaceIVA() assert isinstance(iva, IterativeMethodBase) iva = NaturalGradGaussIVA() assert isinstance(iva, IterativeMethodBase) def test_fast_iva_inheritance() -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y) -> np.ndarray: r"""Derivative of contrast function. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.ones_like(y) def dd_contrast_fn(y) -> np.ndarray: r"""Second order derivative of contrast function. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). 
""" return 2 * np.zeros_like(y) iva = FastIVA( contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, dd_contrast_fn=dd_contrast_fn, ) assert isinstance(iva, IterativeMethodBase) def test_faster_iva_inheritance() -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y) -> np.ndarray: r"""Derivative of contrast function. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.ones_like(y) iva = FasterIVA(contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn) assert isinstance(iva, IterativeMethodBase) def test_aux_iva_inheritance() -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y) -> np.ndarray: r"""Derivative of contrast function. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). 
""" return 2 * np.ones_like(y) iva = AuxIVA( contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, ) assert isinstance(iva, IterativeMethodBase) iva = AuxLaplaceIVA() assert isinstance(iva, IterativeMethodBase) iva = AuxGaussIVA() assert isinstance(iva, IterativeMethodBase) def test_pds_iva_inheritance() -> None: iva = PDSIVA( contrast_fn=None, prox_penalty=None, ) assert isinstance(iva, IterativeMethodBase) def test_ilrma_inheritance() -> None: n_basis = 2 ilrma = GaussILRMA(n_basis=n_basis) assert isinstance(ilrma, IterativeMethodBase) ilrma = TILRMA(n_basis=n_basis, dof=1000) assert isinstance(ilrma, IterativeMethodBase) ilrma = GGDILRMA(n_basis=n_basis, beta=1.95) assert isinstance(ilrma, IterativeMethodBase) def test_ipsdta_inheritance() -> None: n_basis = 2 n_blocks = 2 ipsdta = GaussIPSDTA(n_basis=n_basis, n_blocks=n_blocks) assert isinstance(ipsdta, IterativeMethodBase) ipsdta = TIPSDTA(n_basis=n_basis, n_blocks=n_blocks, dof=1000) assert isinstance(ipsdta, IterativeMethodBase) def test_mnmf_inheritance() -> None: n_basis = 2 mnmf = GaussMNMF(n_basis=n_basis) assert isinstance(mnmf, IterativeMethodBase) mnmf = FastGaussMNMF(n_basis=n_basis) assert isinstance(mnmf, IterativeMethodBase) def test_pdsbss_inheritance() -> None: def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray of the shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def penalty_fn(y: np.ndarray) -> float: loss = contrast_fn(y) loss = np.sum(loss.mean(axis=-1)) return loss def prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray: r"""Proximal operator of penalty function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). step_size (float): Step size. Default: 1. Returns: np.ndarray of the shape is (n_sources, n_bins, n_frames). 
""" norm = np.linalg.norm(y, axis=1, keepdims=True) return y * np.maximum(1 - step_size / norm, 0) pdsbss = PDSBSS(penalty_fn=penalty_fn, prox_penalty=prox_penalty) assert isinstance(pdsbss, IterativeMethodBase) def test_cacgmm_inheritance() -> None: cacgmm = CACGMM() assert isinstance(cacgmm, IterativeMethodBase) ================================================ FILE: tests/package/bss/test_iva.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import pytest import scipy.signal as ss from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.iva import ( PDSIVA, AuxGaussIVA, AuxIVA, AuxIVABase, AuxLaplaceIVA, FasterIVA, FastIVA, FastIVABase, GradGaussIVA, GradIVA, GradIVABase, GradLaplaceIVA, IVABase, NaturalGradGaussIVA, NaturalGradIVA, NaturalGradLaplaceIVA, ) max_duration = 0.5 n_fft = 512 hop_length = 256 n_bins = n_fft // 2 + 1 n_iter = 3 parameters_spatial_algorithm = ["IP", "IP1", "IP2", "ISS", "ISS1", "ISS2", "IPA"] parameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]] parameters_is_holonomic = [True, False] parameters_scale_restoration = [True, False, "projection_back", "minimal_distortion_principle"] parameters_grad_iva = [ (2, {}), ( 3, {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), ] parameters_fast_iva = [ (2, "dev1_female3", {}), ( 3, "dev1_female3", {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), (2, "dev1_female3", {"demix_filter": None}), ] parameters_aux_iva = [ (2, "dev1_female3", {}), ( 3, "dev1_female3", {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), (2, "dev1_female3", {"demix_filter": None}), ( 3, "dev1_female3", {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), (4, "dev1_female4", {"demix_filter": None}), ] 
parameters_pds_iva = [ (2, "dev1_female3", {}), ( 3, "dev1_female3", {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), (4, "dev1_female4", {}), ] @pytest.mark.parametrize("callbacks", parameters_callbacks) def test_iva_base( callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]], ): iva = IVABase(callbacks=callbacks) print(iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) def test_fast_iva_base( callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]], ): np.random.seed(111) iva = FastIVABase(callbacks=callbacks) print(iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_iva_base( callbacks: Optional[Union[Callable[[GradIVA], None], List[Callable[[GradIVA], None]]]], is_holonomic: bool, ): np.random.seed(111) def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def score_fn(y) -> np.ndarray: r"""Score function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_bins, n_frames). 
""" norm = np.linalg.norm(y, axis=1, keepdims=True) norm = np.maximum(norm, 1e-10) return y / norm iva = GradIVABase( contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic ) print(iva) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_iva( n_sources: int, callbacks: Optional[Union[Callable[[GradIVA], None], List[Callable[[GradIVA], None]]]], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def score_fn(y) -> np.ndarray: r"""Score function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_bins, n_frames). 
""" norm = np.linalg.norm(y, axis=1, keepdims=True) norm = np.maximum(norm, 1e-10) return y / norm iva = GradIVA( contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float print(iva) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_natural_grad_iva( n_sources: int, callbacks: Optional[ Union[Callable[[NaturalGradIVA], None], List[Callable[[NaturalGradIVA], None]]] ], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def score_fn(y): r"""Score function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_bins, n_frames). 
""" norm = np.linalg.norm(y, axis=1, keepdims=True) norm = np.maximum(norm, 1e-10) return y / norm iva = NaturalGradIVA( contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float print(iva) @pytest.mark.parametrize("n_sources, sisec2010_tag, reset_kwargs", parameters_fast_iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) def test_fast_iva( n_sources: int, sisec2010_tag: str, callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]], reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y) -> np.ndarray: r"""Derivative of contrast function. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.ones_like(y) def dd_contrast_fn(y) -> np.ndarray: r"""Second order derivative of contrast function. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). 
""" return 2 * np.zeros_like(y) iva = FastIVA( contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, dd_contrast_fn=dd_contrast_fn, callbacks=callbacks, ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float print(iva) @pytest.mark.parametrize("n_sources, sisec2010_tag, reset_kwargs", parameters_fast_iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) def test_faster_iva( n_sources: int, sisec2010_tag: str, callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]], reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y) -> np.ndarray: r"""Derivative of contrast function. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). 
""" return 2 * np.ones_like(y) iva = FasterIVA(contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, callbacks=callbacks) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float print(iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_aux_iva_base( callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]], scale_restoration: Union[str, bool], ): np.random.seed(111) def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y) -> np.ndarray: r"""Derivative of contrast function. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). 
""" return 2 * np.ones_like(y) iva = AuxIVABase( contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, callbacks=callbacks, scale_restoration=scale_restoration, ) print(iva) @pytest.mark.parametrize("n_sources, sisec2010_tag, reset_kwargs", parameters_aux_iva) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_aux_iva( n_sources: int, sisec2010_tag: str, spatial_algorithm: str, callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]], scale_restoration: Union[str, bool], reset_kwargs: Dict[Any, Any], ): if spatial_algorithm in ["IP", "ISS"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y) -> np.ndarray: r"""Derivative of contrast function. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). 
""" return 2 * np.ones_like(y) iva = AuxIVA( spatial_algorithm=spatial_algorithm, contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, callbacks=callbacks, scale_restoration=scale_restoration, ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float if spatial_algorithm in ["ISS", "ISS1", "ISS2"]: assert iva.demix_filter is None print(iva) @pytest.mark.parametrize("n_sources, sisec2010_tag, reset_kwargs", parameters_pds_iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_pds_iva( n_sources: int, sisec2010_tag: str, callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]], scale_restoration: Union[str, bool], reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) iva = PDSIVA( contrast_fn=None, prox_penalty=None, callbacks=callbacks, scale_restoration=scale_restoration, ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float print(iva) @pytest.mark.parametrize("specify_contrast_fn", [True, False]) def test_iva_insufficient_fn(specify_contrast_fn: bool): def _contrast_fn(y: np.ndarray) -> np.ndarray: return np.linalg.norm(y, axis=1) def _prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray: norm = np.linalg.norm(y, axis=1, keepdims=True) return y * np.maximum(1 - step_size / norm, 0) if specify_contrast_fn: contrast_fn = _contrast_fn prox_penalty = None else: contrast_fn = None prox_penalty = _prox_penalty with 
pytest.raises(ValueError) as e: _ = PDSIVA( contrast_fn=contrast_fn, prox_penalty=prox_penalty, ) if specify_contrast_fn: assert str(e.value) == "Set prox_penalty." else: assert str(e.value) == "Set contrast_fn." @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_laplace_iva( n_sources: int, callbacks: Optional[ Union[Callable[[GradLaplaceIVA], None], List[Callable[[GradLaplaceIVA], None]]] ], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) iva = GradLaplaceIVA(callbacks=callbacks, is_holonomic=is_holonomic) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float print(iva) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_gauss_iva( n_sources: int, callbacks: Optional[ Union[Callable[[GradGaussIVA], None], List[Callable[[GradGaussIVA], None]]] ], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) iva = GradGaussIVA(callbacks=callbacks, is_holonomic=is_holonomic) spectrogram_est = 
iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float print(iva) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_natural_grad_laplace_iva( n_sources: int, callbacks: Optional[ Union[ Callable[[NaturalGradLaplaceIVA], None], List[Callable[[NaturalGradLaplaceIVA], None]] ] ], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) iva = NaturalGradLaplaceIVA(callbacks=callbacks, is_holonomic=is_holonomic) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float print(iva) @pytest.mark.parametrize("n_sources, reset_kwargs", parameters_grad_iva) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_natural_grad_gauss_iva( n_sources: int, callbacks: Optional[ Union[Callable[[NaturalGradGaussIVA], None], List[Callable[[NaturalGradGaussIVA], None]]] ], is_holonomic: bool, reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) iva = NaturalGradGaussIVA(callbacks=callbacks, is_holonomic=is_holonomic) spectrogram_est = 
iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float print(iva) @pytest.mark.parametrize("n_sources, sisec2010_tag, reset_kwargs", parameters_aux_iva) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_aux_laplace_iva( n_sources: int, sisec2010_tag: str, spatial_algorithm: str, callbacks: Optional[ Union[ Callable[[NaturalGradLaplaceIVA], None], List[Callable[[NaturalGradLaplaceIVA], None]] ] ], scale_restoration: Union[str, bool], reset_kwargs: Dict[Any, Any], ): if spatial_algorithm in ["IP", "ISS"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) iva = AuxLaplaceIVA( spatial_algorithm=spatial_algorithm, callbacks=callbacks, scale_restoration=scale_restoration, ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float if spatial_algorithm in ["ISS", "ISS1", "ISS2"]: assert iva.demix_filter is None print(iva) @pytest.mark.parametrize("n_sources, sisec2010_tag, reset_kwargs", parameters_aux_iva) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("scale_restoration", parameters_scale_restoration) def test_aux_gauss_iva( n_sources: int, sisec2010_tag: str, spatial_algorithm: str, callbacks: Optional[ Union[Callable[[NaturalGradGaussIVA], None], 
List[Callable[[NaturalGradGaussIVA], None]]] ], scale_restoration: Union[str, bool], reset_kwargs: Dict[Any, Any], ): if spatial_algorithm in ["IP", "ISS"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) iva = AuxGaussIVA( spatial_algorithm=spatial_algorithm, callbacks=callbacks, scale_restoration=scale_restoration, ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) assert spectrogram_mix.shape == spectrogram_est.shape assert type(iva.loss[-1]) is float if spatial_algorithm in ["ISS", "ISS1", "ISS2"]: assert iva.demix_filter is None print(iva) ================================================ FILE: tests/package/bss/test_mnmf.py ================================================ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import pytest import scipy.signal as ss from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.mnmf import FastGaussMNMF, FastMNMFBase, GaussMNMF, MNMFBase max_duration = 0.1 n_fft = 256 hop_length = 128 window = "hann" n_bins = n_fft // 2 + 1 n_iter = 3 rng = np.random.default_rng(42) parameters_diagonalizer_algorithm = ["IP", "IP1", "IP2"] parameters_partitioning = [True, False] parameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]] parameters_normalization = [True, False] parameters_mnmf_base = [2] parameters_mnmf = [ (2, 2, 2, {}), (3, 2, 3, {}), ] @pytest.mark.parametrize( "n_basis", parameters_mnmf_base, ) @pytest.mark.parametrize("callbacks", parameters_callbacks) def test_mnmf_base( n_basis: int, callbacks: 
Optional[Union[Callable[[MNMFBase], None], List[Callable[[MNMFBase], None]]]], ): ipsdta = MNMFBase( n_basis, callbacks=callbacks, record_loss=False, rng=rng, ) print(ipsdta) @pytest.mark.parametrize( "n_basis", parameters_mnmf_base, ) @pytest.mark.parametrize("callbacks", parameters_callbacks) def test_fast_mnmf_base( n_basis: int, callbacks: Optional[ Union[Callable[[FastMNMFBase], None], List[Callable[[FastMNMFBase], None]]] ], ): ipsdta = FastMNMFBase( n_basis, callbacks=callbacks, record_loss=False, rng=rng, ) print(ipsdta) @pytest.mark.parametrize( "n_sources, n_channels, n_basis, reset_kwargs", parameters_mnmf, ) @pytest.mark.parametrize( "partitioning", parameters_partitioning, ) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("normalization", parameters_normalization) def test_gauss_mnmf( n_sources: int, n_channels: int, n_basis: int, partitioning: bool, callbacks: Optional[Union[Callable[[GaussMNMF], None], List[Callable[[GaussMNMF], None]]]], normalization: Optional[Union[str, bool]], reset_kwargs: Dict[str, Any], ): if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img[:n_channels], axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length ) mnmf = GaussMNMF( n_basis, n_sources=n_sources, partitioning=partitioning, callbacks=callbacks, normalization=normalization, rng=rng, ) spectrogram_est = mnmf(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_est.shape == (n_sources,) + spectrogram_mix.shape[1:] assert type(mnmf.loss[-1]) is float print(mnmf) @pytest.mark.parametrize( "n_sources, n_channels, n_basis, reset_kwargs", parameters_mnmf, ) 
@pytest.mark.parametrize("diagonalizer_algorithm", parameters_diagonalizer_algorithm) @pytest.mark.parametrize( "partitioning", parameters_partitioning, ) @pytest.mark.parametrize("callbacks", parameters_callbacks) @pytest.mark.parametrize("normalization", parameters_normalization) def test_fast_gauss_mnmf( n_sources: int, n_channels: int, n_basis: int, diagonalizer_algorithm: str, partitioning: bool, callbacks: Optional[Union[Callable[[GaussMNMF], None], List[Callable[[GaussMNMF], None]]]], normalization: Optional[Union[str, bool]], reset_kwargs: Dict[str, Any], ): if diagonalizer_algorithm in ["IP"] and not pytest.run_redundant: pytest.skip(reason="Need --run-redundant option to run.") if n_sources < 4: sisec2010_tag = "dev1_female3" elif n_sources == 4: sisec2010_tag = "dev1_female4" else: raise ValueError("n_sources should be less than 5.") waveform_src_img, _ = download_sample_speech_data( n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img[:n_channels], axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length ) if partitioning: with pytest.raises(AssertionError) as e: mnmf = FastGaussMNMF( n_basis, n_sources=n_sources, diagonalizer_algorithm=diagonalizer_algorithm, partitioning=partitioning, callbacks=callbacks, normalization=normalization, rng=rng, ) assert str(e.value) == "partitioning function is not supported." 
else: mnmf = FastGaussMNMF( n_basis, n_sources=n_sources, diagonalizer_algorithm=diagonalizer_algorithm, partitioning=partitioning, callbacks=callbacks, normalization=normalization, rng=rng, ) spectrogram_est = mnmf(spectrogram_mix, n_iter=n_iter, **reset_kwargs) assert spectrogram_est.shape == (n_sources,) + spectrogram_mix.shape[1:] assert type(mnmf.loss[-1]) is float print(mnmf) ================================================ FILE: tests/package/bss/test_pair_selector.py ================================================ import pytest from ssspy.bss._select_pair import combination_pair_selector, sequential_pair_selector parameters_n_sources = [4] parameters_step = [1, 2] parameters_ascend = [True, False] @pytest.mark.parametrize("n_sources", parameters_n_sources) @pytest.mark.parametrize("step", parameters_step) @pytest.mark.parametrize("ascend", parameters_ascend) def test_sequential_pair_selector(n_sources: int, step: int, ascend: bool): with pytest.warns(UserWarning) as record: for m, n in sequential_pair_selector(n_sources, step=step, sort=ascend): if ascend: assert m < n assert len(record) == 1 assert str(record[0].message) == "Use ssspy.utils.select_pair.sequential_pair_selector instead." @pytest.mark.parametrize("n_sources", parameters_n_sources) @pytest.mark.parametrize("ascend", parameters_ascend) def test_combination_pair_selector(n_sources: int, ascend: bool): with pytest.warns(UserWarning) as record: for m, n in combination_pair_selector(n_sources, sort=ascend): if ascend: assert m < n assert len(record) == 1 assert ( str(record[0].message) == "Use ssspy.utils.select_pair.combination_pair_selector instead." 
) ================================================ FILE: tests/package/bss/test_pdsbss.py ================================================ import functools from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import pytest import scipy.signal as ss from dummy.callback import DummyCallback, dummy_function from dummy.utils.dataset import download_sample_speech_data from ssspy.bss.pdsbss import PDSBSS, MaskingPDSBSS, PDSBSSBase max_duration = 0.5 n_fft = 2048 hop_length = 1024 n_bins = n_fft // 2 + 1 n_iter = 5 parameters_pdsbss = [ (2, None, {}), ( 3, dummy_function, {"demix_filter": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))}, ), (2, [DummyCallback(), dummy_function], {}), ] parameters_set_panalty_fn = [True, False] def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray of the shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def penalty_fn(y: np.ndarray) -> float: loss = contrast_fn(y) loss = np.sum(loss.mean(axis=-1)) return loss def prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray: r"""Proximal operator of penalty function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). step_size (float): Step size. Default: 1. Returns: np.ndarray of the shape is (n_sources, n_bins, n_frames). """ norm = np.linalg.norm(y, axis=1, keepdims=True) return y * np.maximum(1 - step_size / norm, 0) def mask_fn(y: np.ndarray, step_size: float = 1) -> np.ndarray: r"""Masking function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). step_size (float): Step size. Default: 1. Returns: np.ndarray of the shape is (n_sources, n_bins, n_frames). 
""" norm = np.linalg.norm(y, axis=1, keepdims=True) mask = np.maximum(1 - step_size / norm, 0) mask = np.tile(mask, (1, y.shape[1], 1)) return mask def test_pds_base(): pdsbss = PDSBSSBase(penalty_fn=penalty_fn, prox_penalty=prox_penalty) print(pdsbss) @pytest.mark.parametrize("n_sources, callbacks, reset_kwargs", parameters_pdsbss) @pytest.mark.parametrize("set_panalty_fn", parameters_set_panalty_fn) def test_pdsbss( n_sources: int, callbacks: Optional[Union[Callable[[PDSBSS], None], List[Callable[[PDSBSS], None]]]], reset_kwargs: Dict[Any, Any], set_panalty_fn: bool, ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( sisec2010_root="./tests/.data/SiSEC2010", mird_root="./tests/.data/MIRD", n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) if set_panalty_fn: pdsbss = PDSBSS(penalty_fn=penalty_fn, prox_penalty=prox_penalty, callbacks=callbacks) else: pdsbss = PDSBSS(prox_penalty=prox_penalty, callbacks=callbacks) spectrogram_mix_normalized = pdsbss.normalize_by_spectral_norm(spectrogram_mix) spectrogram_est = pdsbss(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape print(pdsbss) @pytest.mark.parametrize("n_sources, callbacks, reset_kwargs", parameters_pdsbss) def test_masking_pdsbss( n_sources: int, callbacks: Optional[ Union[Callable[[MaskingPDSBSS], None], List[Callable[[MaskingPDSBSS], None]]] ], reset_kwargs: Dict[Any, Any], ): np.random.seed(111) waveform_src_img, _ = download_sample_speech_data( sisec2010_root="./tests/.data/SiSEC2010", mird_root="./tests/.data/MIRD", n_sources=n_sources, sisec2010_tag="dev1_female3", max_duration=max_duration, conv=True, ) waveform_mix = np.sum(waveform_src_img, axis=1) # (n_channels, n_samples) _, _, spectrogram_mix = 
ss.stft( waveform_mix, window="hann", nperseg=n_fft, noverlap=n_fft - hop_length ) pdsbss = MaskingPDSBSS(mask_fn=functools.partial(mask_fn, step_size=1), callbacks=callbacks) spectrogram_mix_normalized = pdsbss.normalize_by_spectral_norm(spectrogram_mix) spectrogram_est = pdsbss(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs) assert spectrogram_mix.shape == spectrogram_est.shape print(pdsbss) ================================================ FILE: tests/package/bss/test_proxbss.py ================================================ import numpy as np from ssspy.bss.proxbss import ProxBSSBase def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray of the shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def penalty_fn(y: np.ndarray) -> float: loss = contrast_fn(y) loss = np.sum(loss.mean(axis=-1)) return loss def prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray: r"""Proximal operator of penalty function. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). step_size (float): Step size. Default: 1. Returns: np.ndarray of the shape is (n_sources, n_bins, n_frames). 
""" norm = np.linalg.norm(y, axis=1, keepdims=True) return y * np.maximum(1 - step_size / norm, 0) def test_proxbss_base() -> None: proxbss = ProxBSSBase(penalty_fn=penalty_fn, prox_penalty=prox_penalty) print(proxbss) ================================================ FILE: tests/package/bss/test_psd_legacy.py ================================================ from typing import Tuple import numpy as np import pytest from ssspy.bss._psd import to_psd from ssspy.special import add_flooring rng = np.random.default_rng(42) parameters_shape = [(5, 2, 2), (3, 3)] parameters_kwargs = [{}, {"flooring_fn": None}, {"flooring_fn": add_flooring}] @pytest.mark.parametrize("shape", parameters_shape) @pytest.mark.parametrize("kwargs", parameters_kwargs) def test_to_psd_real(shape: Tuple[int], kwargs): X = rng.standard_normal(shape) X = X @ X.swapaxes(-1, -2) X = to_psd(X, **kwargs) eigvals = np.linalg.eigvalsh(X) assert np.all(X == X.swapaxes(-1, -2)) assert np.min(eigvals) > 0 @pytest.mark.parametrize("shape", parameters_shape) @pytest.mark.parametrize("kwargs", parameters_kwargs) def test_to_psd_complex(shape: Tuple[int], kwargs): X = rng.standard_normal(shape) + 1j * rng.standard_normal(shape) X = X @ X.swapaxes(-1, -2).conj() X = to_psd(X, **kwargs) eigvals = np.linalg.eigvalsh(X) assert np.all(X == X.swapaxes(-1, -2).conj()) assert np.min(eigvals) > 0 ================================================ FILE: tests/package/bss/test_solve_permutation.py ================================================ import numpy as np import pytest from ssspy.bss._solve_permutation import correlation_based_permutation_solver rng = np.random.default_rng(0) parameters_give_demixing_filter = [True, False] @pytest.mark.parametrize("give_demixing_filter", parameters_give_demixing_filter) def test_correlation_based_permutation_solver(give_demixing_filter: bool): n_sources = 3 n_channels = n_sources n_bins, n_frames = 4, 16 shape = (n_channels, n_bins, n_frames) mixture = rng.standard_normal(shape) + 1j 
* rng.standard_normal(shape) shape = (n_bins, n_sources, n_channels) demix_filter = rng.standard_normal(shape) + 1j * rng.standard_normal(shape) separated = demix_filter @ mixture.transpose(1, 0, 2) with pytest.warns(UserWarning) as record: if give_demixing_filter: separated, demix_filter = correlation_based_permutation_solver(separated, demix_filter) assert demix_filter.shape == (n_bins, n_sources, n_channels) else: separated = correlation_based_permutation_solver(separated) assert separated.shape == (n_bins, n_sources, n_frames) assert len(record) == 1 assert ( str(record[0].message) == "Use ssspy.algorithm.permutation_alignment.correlation_based_permutation_solver instead." ) ================================================ FILE: tests/package/bss/test_update_spatial_model.py ================================================ from typing import Callable, Iterable, Optional, Tuple import numpy as np import pytest from ssspy.bss._update_spatial_model import ( _psd_inv, update_by_block_decomposition_vcd, update_by_ip1, update_by_ip2, update_by_ip2_one_pair, update_by_iss1, update_by_iss2, ) from ssspy.special import add_flooring, max_flooring from ssspy.utils.select_pair import combination_pair_selector, sequential_pair_selector def negative_pair_selector(n_sources): for m in range(n_sources): m, n = m % n_sources, (m + 1) % n_sources m, n = m - n_sources, n - n_sources yield m, n parameters = [(31, 20)] parameters_block_decomposition_vcd = [(15, 2, 20)] parameters_n_sources = [2, 3] parameters_flooring_fn = [max_flooring, add_flooring, None] parameters_overwrite = [True, False] parameters_singular_fn = [ lambda x: np.abs(x) < max_flooring(x), lambda x: np.abs(x) < add_flooring(x), None, ] parameters_pair_selector = [ sequential_pair_selector, combination_pair_selector, negative_pair_selector, None, ] @pytest.mark.parametrize("n_bins, n_frames", parameters) @pytest.mark.parametrize("n_sources", parameters_n_sources) @pytest.mark.parametrize("flooring_fn", 
parameters_flooring_fn) def test_update_by_ip1( n_bins: int, n_frames: int, n_sources: int, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]], ): n_channels = n_sources rng = np.random.default_rng(42) varphi = 1 / rng.random((n_sources, n_frames)) X = rng.standard_normal((n_channels, n_bins, n_frames)) real = rng.standard_normal((n_bins, n_sources, n_sources)) imag = rng.standard_normal((n_bins, n_sources, n_sources)) W = real + 1j * imag XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) GXX = varphi[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(GXX, axis=-1) W_updated = update_by_ip1(W, U, flooring_fn=flooring_fn) assert W_updated.shape == W.shape @pytest.mark.parametrize("n_bins, n_frames", parameters) @pytest.mark.parametrize("n_sources", parameters_n_sources) @pytest.mark.parametrize("flooring_fn", parameters_flooring_fn) @pytest.mark.parametrize("pair_selector", parameters_pair_selector) def test_update_by_ip2( n_bins: int, n_frames: int, n_sources: int, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]], pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]], ): n_channels = n_sources rng = np.random.default_rng(42) varphi = 1 / rng.random((n_sources, n_frames)) X = rng.standard_normal((n_channels, n_bins, n_frames)) real = rng.standard_normal((n_bins, n_sources, n_sources)) imag = rng.standard_normal((n_bins, n_sources, n_sources)) W = real + 1j * imag XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) GXX = varphi[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(GXX, axis=-1) W_updated = update_by_ip2(W, U, flooring_fn=flooring_fn, pair_selector=pair_selector) assert W_updated.shape == W.shape @pytest.mark.parametrize("n_bins, n_frames", parameters) @pytest.mark.parametrize("n_sources", parameters_n_sources) 
@pytest.mark.parametrize("flooring_fn", parameters_flooring_fn) def test_update_by_ip2_one_pair( n_bins: int, n_frames: int, n_sources: int, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]], ): n_channels = n_sources rng = np.random.default_rng(42) varphi = 1 / rng.random((2, n_bins, n_frames)) X = rng.standard_normal((n_channels, n_bins, n_frames)) real = rng.standard_normal((n_bins, n_sources, n_channels)) imag = rng.standard_normal((n_bins, n_sources, n_channels)) W = real + 1j * imag XX = X[:, np.newaxis] * X[np.newaxis, :].conj() GXX = np.mean(varphi[:, np.newaxis, np.newaxis, :, :] * XX[np.newaxis, :, :, :, :], axis=-1) GXX = GXX.transpose(3, 0, 1, 2) W_updated = update_by_ip2_one_pair(W, GXX, pair=(1, 0), flooring_fn=flooring_fn) assert W_updated.shape == (n_bins, 2, n_channels) @pytest.mark.parametrize("n_bins, n_frames", parameters) @pytest.mark.parametrize("n_sources", parameters_n_sources) @pytest.mark.parametrize("flooring_fn", parameters_flooring_fn) def test_update_by_iss1( n_bins: int, n_frames: int, n_sources: int, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]], ): rng = np.random.default_rng(42) varphi = 1 / rng.random((n_sources, n_bins, n_frames)) real = rng.standard_normal((n_sources, n_bins, n_frames)) imag = rng.standard_normal((n_sources, n_bins, n_frames)) Y = real + 1j * imag Y_updated = update_by_iss1(Y, varphi, flooring_fn=flooring_fn) assert Y_updated.shape == Y.shape @pytest.mark.parametrize("n_bins, n_frames", parameters) @pytest.mark.parametrize("n_sources", parameters_n_sources) @pytest.mark.parametrize("flooring_fn", parameters_flooring_fn) @pytest.mark.parametrize("pair_selector", parameters_pair_selector) def test_update_by_iss2( n_bins: int, n_frames: int, n_sources: int, flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]], pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]], ): rng = np.random.default_rng(42) varphi = 1 / rng.random((n_sources, n_bins, n_frames)) real = 
rng.standard_normal((n_sources, n_bins, n_frames)) imag = rng.standard_normal((n_sources, n_bins, n_frames)) Y = real + 1j * imag Y_updated = update_by_iss2(Y, varphi, flooring_fn=flooring_fn, pair_selector=pair_selector) assert Y_updated.shape == Y.shape @pytest.mark.parametrize("n_blocks, n_neighbors, n_frames", parameters_block_decomposition_vcd) @pytest.mark.parametrize("n_sources", parameters_n_sources) @pytest.mark.parametrize("singular_fn", parameters_singular_fn) @pytest.mark.parametrize("overwrite", parameters_overwrite) def test_update_by_block_decomposition_vcd( n_blocks: int, n_neighbors: int, n_frames: int, n_sources: int, singular_fn: Optional[Callable[[np.ndarray], np.ndarray]], overwrite: bool, ): na = np.newaxis n_channels = n_sources rng = np.random.default_rng(42) R = rng.random((n_blocks, n_neighbors, n_sources, n_channels, n_frames)) X = rng.standard_normal((n_channels, n_blocks, n_neighbors, n_frames)) real = rng.standard_normal((n_blocks, n_neighbors, n_sources, n_channels)) imag = rng.standard_normal((n_blocks, n_neighbors, n_sources, n_channels)) W = real + 1j * imag R = R[:, :, na, :, :, :] * np.eye(n_neighbors)[:, :, na, na, na] R = R[:, :, :, :, :, na, :] * np.eye(n_channels)[:, :, na] XX_Hermite = X[:, na, :, :, na] * X[na, :, :, na, :].conj() XX_Hermite = XX_Hermite.transpose(2, 3, 4, 0, 1, 5) RXX = np.mean(R * XX_Hermite[:, :, :, na], axis=-1) W_updated = update_by_block_decomposition_vcd( W, weighted_covariance=RXX, singular_fn=singular_fn, overwrite=overwrite ) assert W_updated.shape == W.shape def test_psd_inv() -> None: rng = np.random.default_rng(42) n_bins, n_frames = 129, 100 n_sources = n_channels = 4 varphi = 1 / rng.random((n_sources, n_frames)) X = rng.standard_normal((n_channels, n_bins, n_frames)) XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj() XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3) GXX = varphi[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :] U = np.mean(GXX, axis=-1) U_inv = 
_psd_inv(U) eye = np.eye(n_sources) assert np.allclose(U @ U_inv, eye) ================================================ FILE: tests/package/io/test_wavread.py ================================================ import os import tempfile import numpy as np import pytest from dummy.io import save_invalid_wavfile from dummy.utils.dataset import download_ssspy_data from scipy.io import wavfile from ssspy import wavread, wavwrite parameters_frame_offset = [0, 10] parameters_num_frames = [None, 100] parameters_channels_first = [True, False, None] parameters_float = [True, False] parameters_channels = [1, 2] @pytest.mark.parametrize("frame_offset", parameters_frame_offset) @pytest.mark.parametrize("num_frames", parameters_num_frames) @pytest.mark.parametrize("channels_first", parameters_channels_first) def test_wavread_monoral(frame_offset: int, num_frames: int, channels_first: bool): path = "audio/monoral_16k_5sec.wav" filename = "./tests/mock/{}".format(path) download_ssspy_data(path, filename=filename) if channels_first is not None: return_2d = True else: return_2d = False # load file using scipy sample_rate_scipy, waveform_scipy = wavfile.read(filename) waveform_scipy = waveform_scipy / 2**15 # load file using ssspy waveform_ssspy, sample_rate_ssspy = wavread( filename, frame_offset=frame_offset, num_frames=num_frames, return_2d=return_2d, channels_first=channels_first, ) assert sample_rate_scipy == sample_rate_ssspy if return_2d: if channels_first: waveform_ssspy = waveform_ssspy.squeeze(axis=0) else: waveform_ssspy = waveform_ssspy.squeeze(axis=1) if num_frames is not None: assert np.all(waveform_scipy[frame_offset : frame_offset + num_frames] == waveform_ssspy) else: assert np.all(waveform_scipy[frame_offset:] == waveform_ssspy) @pytest.mark.parametrize("frame_offset", parameters_frame_offset) @pytest.mark.parametrize("num_frames", parameters_num_frames) @pytest.mark.parametrize("channels_first", parameters_channels_first) def test_wavread_stereo(frame_offset: int, 
num_frames: int, channels_first: bool): path = "audio/stereo_16k_5sec.wav" filename = "./tests/mock/{}".format(path) download_ssspy_data(path, filename=filename) # load file using scipy sample_rate_scipy, waveform_scipy = wavfile.read(filename) waveform_scipy = waveform_scipy / 2**15 # load file using ssspy waveform_ssspy, sample_rate_ssspy = wavread( filename, frame_offset=frame_offset, num_frames=num_frames, channels_first=channels_first ) assert sample_rate_scipy == sample_rate_ssspy if channels_first: # same order as that of scipy waveform_ssspy = waveform_ssspy.transpose(1, 0) if num_frames is not None: assert np.all(waveform_scipy[frame_offset : frame_offset + num_frames] == waveform_ssspy) else: assert np.all(waveform_scipy[frame_offset:] == waveform_ssspy) @pytest.mark.parametrize("frame_offset", parameters_frame_offset) def test_wavread_invalid_monoral(frame_offset: int): path = "audio/monoral_16k_5sec.wav" filename = "./tests/mock/{}".format(path) download_ssspy_data(path, filename=filename) max_frame = 5 * 16000 valid_num_frames = max_frame - frame_offset # valid data wavread(filename, frame_offset=frame_offset, num_frames=valid_num_frames) # invalid memory size invalid_num_frames = valid_num_frames + 1 with pytest.raises(ValueError) as e: wavread(filename, frame_offset=frame_offset, num_frames=invalid_num_frames) assert str(e.value) == f"num_frames={invalid_num_frames} exceeds maximum frame {max_frame}." 
# NOTE(review): this span is a whitespace-collapsed extract of several test files.
# Inside test_wavio_1d, the text after `dtype=f"` appears to have been lost during
# extraction: the rest of that test, the following file header, and what is
# presumably the signature of gmeanmh_scipy (the function defining `_sqrtm` and
# reading `inverse`) are missing. The span is therefore left byte-identical.
# TODO: restore it from the upstream repository before reformatting.
# Also note: test_gmean names its parameter `type`, shadowing the builtin.
@pytest.mark.parametrize("frame_offset", parameters_frame_offset) def test_wavread_invalid_stereo(frame_offset: int): path = "audio/stereo_16k_5sec.wav" filename = "./tests/mock/{}".format(path) download_ssspy_data(path, filename=filename) max_frame = 5 * 16000 valid_num_frames = max_frame - frame_offset # valid data wavread(filename, frame_offset=frame_offset, num_frames=valid_num_frames) # invalid memory size invalid_num_frames = valid_num_frames + 1 with pytest.raises(ValueError) as e: wavread(filename, frame_offset=frame_offset, num_frames=invalid_num_frames) assert str(e.value) == f"num_frames={invalid_num_frames} exceeds maximum frame {max_frame}." @pytest.mark.parametrize("is_float", parameters_float) def test_wavio_1d(is_float: np.dtype): rng = np.random.default_rng(0) filename = "valid.wav" sample_rate = 16000 duration = 5 bits_per_sample = 16 bytes_per_sample = bits_per_sample // 8 num_frames = sample_rate * duration vmax = 2 ** (bits_per_sample - 1) waveform = rng.integers(-vmax, vmax, size=(num_frames,), dtype=f" np.ndarray: def _sqrtm(X) -> np.ndarray: return np.stack([sqrtm(x) for x in X], axis=0) if inverse == "left": AB = np.linalg.solve(A, B) G = A @ _sqrtm(AB) elif inverse == "right": AB = np.linalg.solve(B, A) AB = AB.swapaxes(-2, -1).conj() G = _sqrtm(AB) @ B else: raise ValueError(f"Invalid inverse={inverse} is given.") return G @pytest.mark.parametrize("type", parameters_type) def test_gmean(type: int): rng = np.random.default_rng(0) size = (16, 32, 4, 1) def create_psd(): x = rng.random(size) + 1j * rng.random(size) XX = x * x.transpose(0, 1, 3, 2).conj() return np.mean(XX, axis=0) A = create_psd() B = create_psd() G1 = gmeanmh(A, B, type=type) if type == 1: assert np.allclose(G1 @ np.linalg.inv(A) @ G1, B) elif type == 2: assert np.allclose(G1 @ A @ G1, B) elif type == 3: assert np.allclose(G1 @ np.linalg.inv(A) @ G1, np.linalg.inv(B)) else: raise ValueError("Invalid type={} is given.".format(type)) if type == 2: A = np.linalg.inv(A) elif
type == 3: B = np.linalg.inv(B) G2 = gmeanmh_scipy(A, B, inverse="left") G3 = gmeanmh_scipy(A, B, inverse="right") assert np.allclose(G1, G2) assert np.allclose(G1, G3) ================================================ FILE: tests/package/linalg/test_inv.py ================================================ import numpy as np import pytest from ssspy.linalg import inv2 parameters_sources = [2, 5] @pytest.mark.parametrize("n_sources", parameters_sources) def test_inv2(n_sources: int): np.random.seed(111) shape = (n_sources, 2, 2) A = np.random.randn(*shape) + 1j * np.random.randn(*shape) B = inv2(A) assert np.allclose(A @ B, np.eye(2)) ================================================ FILE: tests/package/linalg/test_lqpqm.py ================================================ import numpy as np from ssspy.linalg.lqpqm import _find_largest_root def test_find_largest_root(): alpha = np.array([-1, 1, 1, -1 + 1j]) beta = np.array([0, 1, 1, -1 - 1j]) gamma = np.array([1, 1, 2, 1]) A = -np.real(alpha + beta + gamma) B = np.real(alpha * beta + beta * gamma + gamma * alpha) C = -np.real(alpha * beta * gamma) X = _find_largest_root(A, B, C) assert np.allclose(X, gamma) ================================================ FILE: tests/package/linalg/test_polynomial.py ================================================ import numpy as np from ssspy.linalg.polynomial import _find_cubic_roots, solve_cubic def test_find_cubic_roots(): rng = np.random.default_rng(0) n_bins, n_channels = 3, 2 P = rng.standard_normal((n_bins, n_channels)) Q = rng.standard_normal((n_bins, n_channels)) X = _find_cubic_roots(P, Q) Y = X**3 + P * X + Q assert np.allclose(Y, 0) def test_solve_cubic(): rng = np.random.default_rng(0) n_bins, n_channels = 3, 2 # real coefficients A = rng.standard_normal((n_bins, n_channels)) B = rng.standard_normal((n_bins, n_channels)) C = rng.standard_normal((n_bins, n_channels)) D = rng.standard_normal((n_bins, n_channels)) X = solve_cubic(A, B, C) Y = X**3 + A * X**2 + B * X + C
assert np.allclose(Y, 0) X = solve_cubic(A, B, C, D) Y = A * X**3 + B * X**2 + C * X + D assert np.allclose(Y, 0) # corner case A = np.zeros_like(C) B = np.zeros_like(C) X = solve_cubic(A, B, C) Y = X**3 + A * X**2 + B * X + C assert np.allclose(Y, 0) # complex coefficients A = rng.standard_normal((n_bins, n_channels)) + 1j * rng.standard_normal((n_bins, n_channels)) B = rng.standard_normal((n_bins, n_channels)) + 1j * rng.standard_normal((n_bins, n_channels)) C = rng.standard_normal((n_bins, n_channels)) + 1j * rng.standard_normal((n_bins, n_channels)) D = rng.standard_normal((n_bins, n_channels)) + 1j * rng.standard_normal((n_bins, n_channels)) X = solve_cubic(A, B, C) Y = X**3 + A * X**2 + B * X + C assert np.allclose(Y, 0) X = solve_cubic(A, B, C, D) Y = A * X**3 + B * X**2 + C * X + D assert np.allclose(Y, 0) # corner case A = np.zeros_like(C) B = np.zeros_like(C) X = solve_cubic(A, B, C) Y = X**3 + A * X**2 + B * X + C assert np.allclose(Y, 0) ================================================ FILE: tests/package/linalg/test_sqrtm.py ================================================ import numpy as np import pytest from ssspy.linalg import invsqrtmh, sqrtmh parameters_sources = [2] parameters_channels = [3, 4] parameters_frames = [32] parameters_is_complex = [True, False] parameters_is_flooring = [True, False] @pytest.mark.parametrize("n_sources", parameters_sources) @pytest.mark.parametrize("n_channels", parameters_channels) @pytest.mark.parametrize("n_frames", parameters_frames) @pytest.mark.parametrize("is_complex", parameters_is_complex) def test_sqrtmh(n_sources: int, n_channels: int, n_frames: int, is_complex: bool): rng = np.random.default_rng(0) shape = (n_sources, n_channels, n_frames) if is_complex: x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape) X = np.mean(x[:, :, np.newaxis, :] * x[:, np.newaxis, :, :].conj(), axis=-1) else: x = rng.standard_normal(shape) X = np.mean(x[:, :, np.newaxis, :] * x[:, np.newaxis, :, :], axis=-1) X_sqrt =
sqrtmh(X) assert np.allclose(X, X_sqrt @ X_sqrt) @pytest.mark.parametrize("n_sources", parameters_sources) @pytest.mark.parametrize("n_channels", parameters_channels) @pytest.mark.parametrize("n_frames", parameters_frames) @pytest.mark.parametrize("is_complex", parameters_is_complex) @pytest.mark.parametrize("is_flooring", parameters_is_flooring) def test_invsqrtmh( n_sources: int, n_channels: int, n_frames: int, is_complex: bool, is_flooring: bool ): rng = np.random.default_rng(0) shape = (n_sources, n_channels, n_frames) if is_complex: x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape) X = np.mean(x[:, :, np.newaxis, :] * x[:, np.newaxis, :, :].conj(), axis=-1) else: x = rng.standard_normal(shape) X = np.mean(x[:, :, np.newaxis, :] * x[:, np.newaxis, :, :], axis=-1) if is_flooring: X_invsqrt = invsqrtmh(X, flooring_fn=lambda x: np.maximum(x, 1e-10)) else: X_invsqrt = invsqrtmh(X) X_sqrt = np.linalg.inv(X_invsqrt) assert np.allclose(X, X_sqrt @ X_sqrt) ================================================ FILE: tests/package/special/test_logsumexp.py ================================================ from typing import Optional import numpy as np import pytest import scipy.special from ssspy.special import logsumexp parameters_axis = [0, 1, (0, 2), None] parameters_keepdims = [True, False] @pytest.mark.parametrize("axis", parameters_axis) @pytest.mark.parametrize("keepdims", parameters_keepdims) def test_logsumexp(axis: Optional[int], keepdims: bool): rng = np.random.default_rng(0) n_sources, n_channels = 4, 3 n_frames = 8 shape = (n_sources, n_frames, n_channels, n_channels) X = rng.random(shape) Y = logsumexp(X, axis=axis, keepdims=keepdims) Y_scipy = scipy.special.logsumexp(X, axis=axis, keepdims=keepdims) assert np.allclose(Y, Y_scipy) ================================================ FILE: tests/package/special/test_psd.py ================================================ from typing import Tuple import numpy as np import pytest from ssspy.special import 
add_flooring, to_psd rng = np.random.default_rng(42) parameters_shape = [(5, 2, 2), (3, 3)] parameters_kwargs = [{}, {"flooring_fn": None}, {"flooring_fn": add_flooring}] @pytest.mark.parametrize("shape", parameters_shape) @pytest.mark.parametrize("kwargs", parameters_kwargs) def test_to_psd_real(shape: Tuple[int], kwargs): X = rng.standard_normal(shape) X = X @ X.swapaxes(-1, -2) X = to_psd(X, **kwargs) eigvals = np.linalg.eigvalsh(X) assert np.all(X == X.swapaxes(-1, -2)) assert np.min(eigvals) > 0 @pytest.mark.parametrize("shape", parameters_shape) @pytest.mark.parametrize("kwargs", parameters_kwargs) def test_to_psd_complex(shape: Tuple[int], kwargs): X = rng.standard_normal(shape) + 1j * rng.standard_normal(shape) X = X @ X.swapaxes(-1, -2).conj() X = to_psd(X, **kwargs) eigvals = np.linalg.eigvalsh(X) assert np.all(X == X.swapaxes(-1, -2).conj()) assert np.min(eigvals) > 0 ================================================ FILE: tests/package/special/test_softmax.py ================================================ from typing import Optional import numpy as np import pytest import scipy.special from ssspy.special import softmax parameters_axis = [0, 1, (0, 2), None] @pytest.mark.parametrize("axis", parameters_axis) def test_logsumexp(axis: Optional[int]): rng = np.random.default_rng(0) n_sources, n_channels = 4, 3 n_frames = 8 shape = (n_sources, n_frames, n_channels, n_channels) X = rng.random(shape) Y = softmax(X, axis=axis) Y_scipy = scipy.special.softmax(X, axis=axis) assert np.allclose(Y, Y_scipy) ================================================ FILE: tests/package/transform/test_pca.py ================================================ import numpy as np import pytest from ssspy.transform import pca parameters_ascend = [True, False] parameters_batch_size = [1, 4] parameters_n_channels = [2, 3] parameters_pca_real = [10, 20] parameters_pca_complex = [(257, 8), (65, 12)] @pytest.mark.parametrize("ascend", parameters_ascend) 
@pytest.mark.parametrize("n_channels", parameters_n_channels) @pytest.mark.parametrize("n_samples", parameters_pca_real) def test_pca_real_2d(ascend: bool, n_channels: int, n_samples: int): np.random.seed(111) input = np.random.randn(n_channels, n_samples) output = pca(input, ascend=ascend) assert input.shape == output.shape covariance = output[:, np.newaxis, :] * output[np.newaxis, :, :] covariance = np.mean(covariance, axis=-1) mask = 1 - np.eye(n_channels) zero = np.zeros((n_channels, n_channels)) assert np.allclose(mask * covariance, zero) @pytest.mark.parametrize("ascend", parameters_ascend) @pytest.mark.parametrize("batch_size", parameters_batch_size) @pytest.mark.parametrize("n_channels", parameters_n_channels) @pytest.mark.parametrize("n_samples", parameters_pca_real) def test_pca_real_3d(ascend: bool, batch_size: int, n_channels: int, n_samples: int): np.random.seed(111) input = np.random.randn(batch_size, n_channels, n_samples) output = pca(input, ascend=ascend) assert input.shape == output.shape covariance = output[:, :, np.newaxis, :] * output[:, np.newaxis, :, :] covariance = np.mean(covariance, axis=-1) mask = 1 - np.eye(n_channels) zero = np.zeros((batch_size, n_channels, n_channels)) assert np.allclose(mask * covariance, zero) @pytest.mark.parametrize("ascend", parameters_ascend) @pytest.mark.parametrize("n_channels", parameters_n_channels) @pytest.mark.parametrize("n_bins, n_frames", parameters_pca_complex) def test_pca_complex_3d(ascend: bool, n_channels: int, n_bins: int, n_frames: int): np.random.seed(111) real = np.random.randn(n_channels, n_bins, n_frames) imag = np.random.randn(n_channels, n_bins, n_frames) input = real + 1j * imag output = pca(input, ascend=ascend) assert input.shape == output.shape covariance = output[:, np.newaxis, :, :] * output[np.newaxis, :, :, :].conj() covariance = np.mean(covariance, axis=-1) covariance = covariance.transpose(2, 0, 1) mask = 1 - np.eye(n_channels) zero = np.zeros((n_bins, n_channels, n_channels)) 
assert np.allclose(mask * covariance, zero) @pytest.mark.parametrize("ascend", parameters_ascend) @pytest.mark.parametrize("batch_size", parameters_batch_size) @pytest.mark.parametrize("n_channels", parameters_n_channels) @pytest.mark.parametrize("n_bins, n_frames", parameters_pca_complex) def test_pca_complex_4d(ascend: bool, batch_size: int, n_channels: int, n_bins: int, n_frames: int): np.random.seed(111) real = np.random.randn(batch_size, n_channels, n_bins, n_frames) imag = np.random.randn(batch_size, n_channels, n_bins, n_frames) input = real + 1j * imag output = pca(input, ascend=ascend) assert input.shape == output.shape covariance = output[:, :, np.newaxis, :, :] * output[:, np.newaxis, :, :, :].conj() covariance = np.mean(covariance, axis=-1) covariance = covariance.transpose(0, 3, 1, 2) mask = 1 - np.eye(n_channels) zero = np.zeros((batch_size, n_bins, n_channels, n_channels)) assert np.allclose(mask * covariance, zero) ================================================ FILE: tests/package/transform/test_whiten.py ================================================ import numpy as np import pytest from ssspy.transform import whiten parameters_batch_size = [1, 4] parameters_n_channels = [2, 3] parameters_whiten_real = [10, 20] parameters_whiten_complex = [(2049, 8), (513, 12)] @pytest.mark.parametrize("n_channels", parameters_n_channels) @pytest.mark.parametrize("n_samples", parameters_whiten_real) def test_whiten_real_2d(n_channels: int, n_samples: int): np.random.seed(111) input = np.random.randn(n_channels, n_samples) output = whiten(input) assert input.shape == output.shape covariance = output[:, np.newaxis, :] * output[np.newaxis, :, :] covariance = np.mean(covariance, axis=-1) eye = np.eye(n_channels) assert np.allclose(covariance, eye) @pytest.mark.parametrize("batch_size", parameters_batch_size) @pytest.mark.parametrize("n_channels", parameters_n_channels) @pytest.mark.parametrize("n_samples", parameters_whiten_real) def test_whiten_real_3d(batch_size: 
int, n_channels: int, n_samples: int): np.random.seed(111) input = np.random.randn(batch_size, n_channels, n_samples) output = whiten(input) assert input.shape == output.shape covariance = output[:, :, np.newaxis, :] * output[:, np.newaxis, :, :] covariance = np.mean(covariance, axis=-1) eye = np.eye(n_channels) assert np.allclose(covariance, eye) @pytest.mark.parametrize("n_channels", parameters_n_channels) @pytest.mark.parametrize("n_bins, n_frames", parameters_whiten_complex) def test_whiten_complex_3d(n_channels: int, n_bins: int, n_frames: int): np.random.seed(111) real = np.random.randn(n_channels, n_bins, n_frames) imag = np.random.randn(n_channels, n_bins, n_frames) input = real + 1j * imag output = whiten(input) assert input.shape == output.shape covariance = output[:, np.newaxis, :, :] * output[np.newaxis, :, :, :].conj() covariance = np.mean(covariance, axis=-1) covariance = covariance.transpose(2, 0, 1) eye = np.eye(n_channels) eye = np.tile(eye, reps=(n_bins, 1, 1)) assert np.allclose(covariance, eye) @pytest.mark.parametrize("batch_size", parameters_batch_size) @pytest.mark.parametrize("n_channels", parameters_n_channels) @pytest.mark.parametrize("n_bins, n_frames", parameters_whiten_complex) def test_whiten_complex_4d(batch_size: int, n_channels: int, n_bins: int, n_frames: int): np.random.seed(111) real = np.random.randn(batch_size, n_channels, n_bins, n_frames) imag = np.random.randn(batch_size, n_channels, n_bins, n_frames) input = real + 1j * imag output = whiten(input) assert input.shape == output.shape covariance = output[:, :, np.newaxis, :, :] * output[:, np.newaxis, :, :, :].conj() covariance = np.mean(covariance, axis=-1) covariance = covariance.transpose(0, 3, 1, 2) eye = np.eye(n_channels) eye = np.tile(eye, reps=(batch_size, n_bins, 1, 1)) assert np.allclose(covariance, eye) ================================================ FILE: tests/package/utils/test_dataset.py ================================================ import pytest from 
ssspy.utils.dataset import download_sample_speech_data parameters_dataset = [ (2, "dev1_female3"), (3, "dev1_female3"), (4, "dev1_female4"), ] parameters_max_duration = [1.2] parameters_conv = [True, False] @pytest.mark.parametrize("n_sources, sisec2010_tag", parameters_dataset) @pytest.mark.parametrize("max_duration", parameters_max_duration) @pytest.mark.parametrize("conv", parameters_conv) def test_conv_dataset(n_sources: int, sisec2010_tag: str, max_duration: int, conv: bool): waveform_src_img, sample_rate = download_sample_speech_data( sisec2010_root="./tests/.data/SiSEC2010", mird_root="./tests/.data/MIRD", n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_duration=max_duration, conv=conv, ) n_channels = n_sources assert waveform_src_img.shape == (n_channels, n_sources, int(sample_rate * max_duration)) ================================================ FILE: tests/package/utils/test_select_pair.py ================================================ import pytest from ssspy.utils.select_pair import combination_pair_selector, sequential_pair_selector parameters_n_sources = [2, 3, 4] parameters_step = [1, 2] parameters_ascend = [True, False] @pytest.mark.parametrize("n_sources", parameters_n_sources) @pytest.mark.parametrize("step", parameters_step) @pytest.mark.parametrize("ascend", parameters_ascend) def test_sequential_pair_selector(n_sources: int, step: int, ascend: bool): for m, n in sequential_pair_selector(n_sources, step=step, sort=ascend): if ascend: assert m < n @pytest.mark.parametrize("n_sources", parameters_n_sources) @pytest.mark.parametrize("ascend", parameters_ascend) def test_combination_pair_selector(n_sources: int, ascend: bool): for m, n in combination_pair_selector(n_sources, sort=ascend): if ascend: assert m < n ================================================ FILE: tests/regression/bss/test_cacgmm.py ================================================ import sys from os import makedirs from os.path import dirname, join, realpath import numpy 
as np from ssspy.bss.cacgmm import CACGMM ssspy_tests_dir = dirname(dirname(dirname(realpath(__file__)))) sys.path.append(ssspy_tests_dir) from dummy.utils.dataset import load_regression_data # noqa: E402 cacgmm_root = join(ssspy_tests_dir, "mock", "regression", "bss", "cacgmm") def test_cacgmm(save_feature: bool = False): rng = np.random.default_rng(0) if save_feature: (npz_input,) = load_regression_data(root=cacgmm_root, filenames=["input.npz"]) spectrogram_tgt = None n_iter = 10 else: npz_input, npz_target = load_regression_data( root=cacgmm_root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] cacgmm = CACGMM(rng=rng) spectrogram_est = cacgmm(spectrogram_mix, n_iter=n_iter) if save_feature: makedirs(cacgmm_root, exist_ok=True) np.savez( join(cacgmm_root, "target.npz"), spectrogram=spectrogram_est, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) def save_all_features() -> None: test_cacgmm(save_feature=True) if __name__ == "__main__": save_all_features() ================================================ FILE: tests/regression/bss/test_fdica.py ================================================ import sys from os import makedirs from os.path import dirname, join, realpath import numpy as np import pytest from ssspy.bss.fdica import AuxLaplaceFDICA, GradLaplaceFDICA, NaturalGradLaplaceFDICA ssspy_tests_dir = dirname(dirname(dirname(realpath(__file__)))) sys.path.append(ssspy_tests_dir) from dummy.utils.dataset import load_regression_data # noqa: E402 fdica_root = join(ssspy_tests_dir, "mock", "regression", "bss", "fdica") n_sources = 2 parameters_is_holonomic = [True, False] parameters_spatial_algorithm = ["IP1", "IP2"] @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_laplace_fdica(is_holonomic: bool, save_feature: bool = False): if 
is_holonomic: root = join(fdica_root, "grad_laplace_fdica", "holonomic") else: root = join(fdica_root, "grad_laplace_fdica", "nonholonomic") if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] fdica = GradLaplaceFDICA(is_holonomic=is_holonomic) spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter) if save_feature: np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_natural_grad_laplace_fdica(is_holonomic: bool, save_feature: bool = False): if is_holonomic: root = join(fdica_root, "natural_grad_laplace_fdica", "holonomic") else: root = join(fdica_root, "natural_grad_laplace_fdica", "nonholonomic") if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] fdica = NaturalGradLaplaceFDICA(is_holonomic=is_holonomic) spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) def test_aux_laplace_fdica(spatial_algorithm: str, save_feature: bool = 
False): root = join(fdica_root, "aux_laplace_fdica", spatial_algorithm) if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] fdica = AuxLaplaceFDICA(spatial_algorithm=spatial_algorithm) spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) def save_all_features() -> None: for is_holonomic in parameters_is_holonomic: test_grad_laplace_fdica(is_holonomic=is_holonomic, save_feature=True) for is_holonomic in parameters_is_holonomic: test_natural_grad_laplace_fdica(is_holonomic=is_holonomic, save_feature=True) for spatial_algorithm in parameters_spatial_algorithm: test_aux_laplace_fdica(spatial_algorithm=spatial_algorithm, save_feature=True) if __name__ == "__main__": save_all_features() ================================================ FILE: tests/regression/bss/test_ilrma.py ================================================ import sys from os import makedirs from os.path import dirname, join, realpath import numpy as np import pytest from ssspy.bss.ilrma import GGDILRMA, TILRMA, GaussILRMA ssspy_tests_dir = dirname(dirname(dirname(realpath(__file__)))) sys.path.append(ssspy_tests_dir) from dummy.utils.dataset import load_regression_data # noqa: E402 ilrma_root = join(ssspy_tests_dir, "mock", "regression", "bss", "ilrma") parameters_spatial_algorithm = ["IP1", "IP2", "ISS1", "ISS2", "IPA"] parameters_source_algorithm = ["MM", "ME"] @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) 
@pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) def test_gauss_ilrma(spatial_algorithm: str, source_algorithm: str, save_feature: bool = False): rng = np.random.default_rng(0) root = join(ilrma_root, "gauss_ilrma", spatial_algorithm, source_algorithm) if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_basis = 2 n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_basis = npz_target["n_basis"].item() n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] if save_feature: n_sources, n_bins, n_frames = spectrogram_mix.shape basis = rng.random((n_sources, n_bins, n_basis)) activation = rng.random((n_sources, n_basis, n_frames)) else: basis = npz_target["basis"] activation = npz_target["activation"] ilrma = GaussILRMA( n_basis=n_basis, spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, rng=rng, ) spectrogram_est = ilrma( spectrogram_mix, n_iter=n_iter, basis=basis, activation=activation, ) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, basis=basis, activation=activation, n_basis=n_basis, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) def test_t_ilrma(spatial_algorithm: str, source_algorithm: str, save_feature: bool = False): if spatial_algorithm == "IPA": pytest.skip(reason="IPA is not supported for TILRMA.") rng = np.random.default_rng(0) root = join(ilrma_root, "t_ilrma", spatial_algorithm, source_algorithm) if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_basis = 
2 dof = 1000 n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_basis = npz_target["n_basis"].item() dof = npz_target["dof"].item() n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] if save_feature: n_sources, n_bins, n_frames = spectrogram_mix.shape basis = rng.random((n_sources, n_bins, n_basis)) activation = rng.random((n_sources, n_basis, n_frames)) else: basis = npz_target["basis"] activation = npz_target["activation"] ilrma = TILRMA( n_basis=n_basis, dof=dof, spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, rng=rng, ) spectrogram_est = ilrma( spectrogram_mix, n_iter=n_iter, basis=basis, activation=activation, ) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, basis=basis, activation=activation, n_basis=n_basis, dof=dof, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) def test_ggd_ilrma(spatial_algorithm: str, source_algorithm: str, save_feature: bool = False): if spatial_algorithm == "IPA": pytest.skip(reason="IPA is not supported for GGDILRMA.") if source_algorithm == "ME": pytest.skip(reason="ME is not supported for GGDILRMA.") rng = np.random.default_rng(0) root = join(ilrma_root, "ggd_ilrma", spatial_algorithm, source_algorithm) if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_basis = 2 beta = 1.5 n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_basis = npz_target["n_basis"].item() beta = npz_target["beta"].item() n_iter = 
npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] if save_feature: n_sources, n_bins, n_frames = spectrogram_mix.shape basis = rng.random((n_sources, n_bins, n_basis)) activation = rng.random((n_sources, n_basis, n_frames)) else: basis = npz_target["basis"] activation = npz_target["activation"] ilrma = GGDILRMA( n_basis=n_basis, beta=beta, spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, rng=rng, ) spectrogram_est = ilrma( spectrogram_mix, n_iter=n_iter, basis=basis, activation=activation, ) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, basis=basis, activation=activation, n_basis=n_basis, beta=beta, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) def save_all_features() -> None: for spatial_algorithm in parameters_spatial_algorithm: for source_algorithm in parameters_source_algorithm: test_gauss_ilrma( spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, save_feature=True, ) for spatial_algorithm in parameters_spatial_algorithm: if spatial_algorithm == "IPA": continue for source_algorithm in parameters_source_algorithm: test_t_ilrma( spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, save_feature=True, ) for spatial_algorithm in parameters_spatial_algorithm: if spatial_algorithm == "IPA": continue for source_algorithm in parameters_source_algorithm: if source_algorithm == "ME": continue test_ggd_ilrma( spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, save_feature=True, ) if __name__ == "__main__": save_all_features() ================================================ FILE: tests/regression/bss/test_ipsdta.py ================================================ import sys from os import makedirs from os.path import dirname, join, realpath import numpy as np import pytest from ssspy.bss.ipsdta import TIPSDTA, 
GaussIPSDTA ssspy_tests_dir = dirname(dirname(dirname(realpath(__file__)))) sys.path.append(ssspy_tests_dir) from dummy.utils.dataset import load_regression_data # noqa: E402 ipsdta_root = join(ssspy_tests_dir, "mock", "regression", "bss", "ipsdta") parameters_spatial_algorithm = ["VCD"] parameters_source_algorithm = ["EM", "MM"] @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) def test_gauss_ipsdta(spatial_algorithm: str, source_algorithm: str, save_feature: bool = False): if source_algorithm == "EM": pytest.skip(reason="EM is not supported for GaussIPSDTA.") rng = np.random.default_rng(0) root = join(ipsdta_root, "gauss_ipsdta", spatial_algorithm, source_algorithm) if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_basis = 2 n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_basis = npz_target["n_basis"].item() n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] if save_feature: n_blocks = spectrogram_mix.shape[1] // 2 n_sources, n_bins, n_frames = spectrogram_mix.shape n_neighbors = n_bins // n_blocks n_remains = n_bins % n_blocks eye = np.eye(n_neighbors, dtype=np.complex128) rand = rng.random((n_sources, n_basis, n_blocks - n_remains, n_neighbors)) T = rand[..., np.newaxis] * eye if n_remains > 0: eye = np.eye(n_neighbors + 1, dtype=np.complex128) rand = rng.random((n_sources, n_basis, n_remains, n_neighbors + 1)) T_high = rand[..., np.newaxis] * eye T = T, T_high V = rng.random((n_sources, n_basis, n_frames)) basis = T activation = V else: n_blocks = npz_target["n_blocks"].item() if "basis" in npz_target.keys(): basis = npz_target["basis"] else: basis_low = npz_target["basis_low"] basis_high = npz_target["basis_high"] basis = basis_low, basis_high activation = 
npz_target["activation"] ipsdta = GaussIPSDTA( n_basis=n_basis, n_blocks=n_blocks, spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, rng=rng, ) spectrogram_est = ipsdta( spectrogram_mix, n_iter=n_iter, basis=basis, activation=activation, ) if isinstance(basis, tuple): basis_low, basis_high = basis basis = { "basis_low": basis_low, "basis_high": basis_high, } else: basis = { "basis": basis, } if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, **basis, activation=activation, n_basis=n_basis, n_blocks=n_blocks, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) @pytest.mark.parametrize("source_algorithm", parameters_source_algorithm) def test_t_ipsdta(spatial_algorithm: str, source_algorithm: str, save_feature: bool = False): """Regression test of TIPSDTA for each pair of spatial and source algorithms. The EM source algorithm is skipped. If ``save_feature`` is ``True``, ``n_blocks`` is set to half the number of frequency bins, the basis (random diagonal matrices for each frequency block, one bin larger for the remaining blocks when the bins do not divide evenly) and the activation are drawn from a random generator seeded with 0, TIPSDTA is run for 10 iterations, and the estimated spectrogram is saved to ``target.npz`` together with the basis, activation, and settings passed to it. Otherwise, the basis, activation, and settings stored in ``target.npz`` are reused and the estimated spectrogram is checked against the stored one with ``np.allclose`` (``atol=1e-7``). """ if source_algorithm == "EM": pytest.skip(reason="EM is not supported for TIPSDTA.") rng = np.random.default_rng(0) root = join(ipsdta_root, "t_ipsdta", spatial_algorithm, source_algorithm) if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_basis = 2 dof = 1000 n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_basis = npz_target["n_basis"].item() dof = npz_target["dof"].item() n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] if save_feature: n_blocks = spectrogram_mix.shape[1] // 2 n_sources, n_bins, n_frames = spectrogram_mix.shape n_neighbors = n_bins // n_blocks n_remains = n_bins % n_blocks eye = np.eye(n_neighbors, dtype=np.complex128) rand = rng.random((n_sources, n_basis, n_blocks - n_remains, n_neighbors)) T = rand[..., np.newaxis] * eye if n_remains > 0: eye = np.eye(n_neighbors + 1,
dtype=np.complex128) rand = rng.random((n_sources, n_basis, n_remains, n_neighbors + 1)) T_high = rand[..., np.newaxis] * eye T = T, T_high V = rng.random((n_sources, n_basis, n_frames)) basis = T activation = V else: n_blocks = npz_target["n_blocks"].item() if "basis" in npz_target.keys(): basis = npz_target["basis"] else: basis_low = npz_target["basis_low"] basis_high = npz_target["basis_high"] basis = basis_low, basis_high activation = npz_target["activation"] ipsdta = TIPSDTA( n_basis=n_basis, n_blocks=n_blocks, dof=dof, spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, rng=rng, ) spectrogram_est = ipsdta( spectrogram_mix, n_iter=n_iter, basis=basis, activation=activation, ) if isinstance(basis, tuple): basis_low, basis_high = basis basis = { "basis_low": basis_low, "basis_high": basis_high, } else: basis = { "basis": basis, } if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, **basis, activation=activation, n_basis=n_basis, n_blocks=n_blocks, dof=dof, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) def save_all_features() -> None: """Regenerate ``target.npz`` of GaussIPSDTA and TIPSDTA for every pair of spatial and source algorithms, skipping the EM source algorithm.""" for spatial_algorithm in parameters_spatial_algorithm: for source_algorithm in parameters_source_algorithm: if source_algorithm == "EM": continue test_gauss_ipsdta( spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, save_feature=True, ) for spatial_algorithm in parameters_spatial_algorithm: for source_algorithm in parameters_source_algorithm: if source_algorithm == "EM": continue test_t_ipsdta( spatial_algorithm=spatial_algorithm, source_algorithm=source_algorithm, save_feature=True, ) if __name__ == "__main__": save_all_features() ================================================ FILE: tests/regression/bss/test_iva.py ================================================ import sys from os import makedirs from os.path import dirname, join,
realpath import numpy as np import pytest from ssspy.bss.iva import AuxIVA, FastIVA, GradIVA, NaturalGradIVA ssspy_tests_dir = dirname(dirname(dirname(realpath(__file__)))) sys.path.append(ssspy_tests_dir) from dummy.utils.dataset import load_regression_data # noqa: E402 iva_root = join(ssspy_tests_dir, "mock", "regression", "bss", "iva") parameters_is_holonomic = [True, False] parameters_spatial_algorithm = ["IP1", "IP2", "ISS1", "ISS2", "IPA"] @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_grad_iva(is_holonomic: bool, save_feature: bool = False): """Regression test of GradIVA with holonomic and nonholonomic updates. If ``save_feature`` is ``True``, GradIVA is run for 10 iterations and the estimated spectrogram is saved to ``target.npz``. Otherwise, the number of iterations is read from ``target.npz`` and the estimated spectrogram is checked against the stored one with ``np.allclose`` (``atol=1e-7``). """ if is_holonomic: root = join(iva_root, "grad_iva", "holonomic") else: root = join(iva_root, "grad_iva", "nonholonomic") if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function, i.e., twice the L2 norm along frequency bins. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def score_fn(y: np.ndarray) -> np.ndarray: r"""Score function, i.e., ``y`` divided by its L2 norm along frequency bins, where the norm is floored at 1e-10 to avoid division by zero. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_bins, n_frames).
""" norm = np.linalg.norm(y, axis=1, keepdims=True) norm = np.maximum(norm, 1e-10) return y / norm iva = GradIVA( contrast_fn=contrast_fn, score_fn=score_fn, is_holonomic=is_holonomic, ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) @pytest.mark.parametrize("is_holonomic", parameters_is_holonomic) def test_natural_grad_iva(is_holonomic: bool, save_feature: bool = False): """Regression test of NaturalGradIVA with holonomic and nonholonomic updates. If ``save_feature`` is ``True``, NaturalGradIVA is run for 10 iterations and the estimated spectrogram is saved to ``target.npz``. Otherwise, the number of iterations is read from ``target.npz`` and the estimated spectrogram is checked against the stored one with ``np.allclose`` (``atol=1e-7``). """ if is_holonomic: root = join(iva_root, "natural_grad_iva", "holonomic") else: root = join(iva_root, "natural_grad_iva", "nonholonomic") if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function, i.e., twice the L2 norm along frequency bins. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def score_fn(y: np.ndarray) -> np.ndarray: r"""Score function, i.e., ``y`` divided by its L2 norm along frequency bins, where the norm is floored at 1e-10 to avoid division by zero. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_bins, n_frames).
""" norm = np.linalg.norm(y, axis=1, keepdims=True) norm = np.maximum(norm, 1e-10) return y / norm iva = NaturalGradIVA( contrast_fn=contrast_fn, score_fn=score_fn, is_holonomic=is_holonomic, ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) @pytest.mark.parametrize("spatial_algorithm", parameters_spatial_algorithm) def test_aux_iva(spatial_algorithm: str, save_feature: bool = False): """Regression test of AuxIVA for each spatial algorithm. If ``save_feature`` is ``True``, AuxIVA is run for 10 iterations and the estimated spectrogram is saved to ``target.npz``. Otherwise, the number of iterations is read from ``target.npz`` and the estimated spectrogram is checked against the stored one with ``np.allclose`` (``atol=1e-7``). """ root = join(iva_root, "aux_iva", spatial_algorithm) if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_iter = 10 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function, i.e., twice the L2 norm along frequency bins. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y: np.ndarray) -> np.ndarray: r"""Derivative of contrast function, which is the constant 2 everywhere. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames).
""" return 2 * np.ones_like(y) iva = AuxIVA( spatial_algorithm=spatial_algorithm, contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) def test_fast_iva(save_feature: bool = False): """Regression test of FastIVA. If ``save_feature`` is ``True``, FastIVA is run for 5 iterations and the estimated spectrogram is saved to ``target.npz``. Otherwise, the number of iterations is read from ``target.npz`` and the estimated spectrogram is checked against the stored one with ``np.allclose`` (``atol=1e-7``). """ root = join(iva_root, "fast_iva") if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_iter = 5 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] def contrast_fn(y: np.ndarray) -> np.ndarray: r"""Contrast function, i.e., twice the L2 norm along frequency bins. Args: y (np.ndarray): The shape is (n_sources, n_bins, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.linalg.norm(y, axis=1) def d_contrast_fn(y: np.ndarray) -> np.ndarray: r"""Derivative of contrast function, which is the constant 2 everywhere. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames). """ return 2 * np.ones_like(y) def dd_contrast_fn(y: np.ndarray) -> np.ndarray: r"""Second order derivative of contrast function, which is zero everywhere. Args: y (np.ndarray): The shape is (n_sources, n_frames). Returns: np.ndarray: The shape is (n_sources, n_frames).
""" return 2 * np.zeros_like(y) iva = FastIVA( contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, dd_contrast_fn=dd_contrast_fn, ) spectrogram_est = iva(spectrogram_mix, n_iter=n_iter) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) def save_all_features() -> None: """Regenerate ``target.npz`` of GradIVA and NaturalGradIVA (holonomic and nonholonomic), AuxIVA (every spatial algorithm), and FastIVA.""" for is_holonomic in parameters_is_holonomic: test_grad_iva(is_holonomic=is_holonomic, save_feature=True) for is_holonomic in parameters_is_holonomic: test_natural_grad_iva(is_holonomic=is_holonomic, save_feature=True) for spatial_algorithm in parameters_spatial_algorithm: test_aux_iva(spatial_algorithm=spatial_algorithm, save_feature=True) test_fast_iva(save_feature=True) if __name__ == "__main__": save_all_features() ================================================ FILE: tests/regression/bss/test_mnmf.py ================================================ import sys from os import makedirs from os.path import dirname, join, realpath import numpy as np import pytest from ssspy.bss.mnmf import FastGaussMNMF, GaussMNMF ssspy_tests_dir = dirname(dirname(dirname(realpath(__file__)))) sys.path.append(ssspy_tests_dir) from dummy.utils.dataset import load_regression_data # noqa: E402 mnmf_root = join(ssspy_tests_dir, "mock", "regression", "bss", "mnmf") parameters_diagonalizer_algorithm = ["IP1", "IP2"] def test_gauss_mnmf(save_feature: bool = False): """Regression test of GaussMNMF. If ``save_feature`` is ``True``, the basis and activation are drawn from a random generator seeded with 0, the spatial matrix of every source and frequency bin is initialized to the identity matrix divided by its trace, GaussMNMF is run for 5 iterations, and the estimated spectrogram is saved to ``target.npz`` together with the initial values and settings passed to it. Otherwise, the initial values and settings stored in ``target.npz`` are reused and the estimated spectrogram is checked against the stored one with ``np.allclose`` (``atol=1e-7``). """ rng = np.random.default_rng(0) root = join(mnmf_root, "gauss_mnmf") if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_basis = 2 n_iter = 5 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_basis = npz_target["n_basis"].item() n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"]
n_channels, n_bins, n_frames = spectrogram_mix.shape n_sources = n_channels if save_feature: basis = rng.random((n_sources, n_bins, n_basis)) activation = rng.random((n_sources, n_basis, n_frames)) spatial = np.eye(n_channels, dtype=spectrogram_mix.dtype) trace = np.trace(spatial, axis1=-2, axis2=-1) spatial = spatial / np.real(trace) spatial = np.tile(spatial, reps=(n_sources, n_bins, 1, 1)) else: basis = npz_target["basis"] activation = npz_target["activation"] spatial = npz_target["spatial"] mnmf = GaussMNMF( n_basis=n_basis, n_sources=n_sources, rng=rng, ) spectrogram_est = mnmf( spectrogram_mix, n_iter=n_iter, basis=basis, activation=activation, spatial=spatial, ) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, basis=basis, activation=activation, n_basis=n_basis, spatial=spatial, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) @pytest.mark.parametrize("diagonalizer_algorithm", parameters_diagonalizer_algorithm) def test_fast_gauss_mnmf(diagonalizer_algorithm: str, save_feature: bool = False): """Regression test of FastGaussMNMF for each diagonalizer algorithm. If ``save_feature`` is ``True``, the basis, activation, and spatial parameters are drawn from a random generator seeded with 0, the diagonalizer of every frequency bin is initialized to the identity matrix, FastGaussMNMF is run for 5 iterations, and the estimated spectrogram is saved to ``target.npz`` together with the initial values and settings passed to it. Otherwise, the initial values and settings stored in ``target.npz`` are reused and the estimated spectrogram is checked against the stored one with ``np.allclose`` (``atol=1e-7``). """ rng = np.random.default_rng(0) root = join(mnmf_root, "fast_gauss_mnmf", diagonalizer_algorithm) if save_feature: (npz_input,) = load_regression_data(root=root, filenames=["input.npz"]) spectrogram_tgt = None n_basis = 2 n_iter = 5 else: npz_input, npz_target = load_regression_data( root=root, filenames=["input.npz", "target.npz"] ) spectrogram_tgt = npz_target["spectrogram"] n_basis = npz_target["n_basis"].item() n_iter = npz_target["n_iter"].item() spectrogram_mix = npz_input["spectrogram"] n_channels, n_bins, n_frames = spectrogram_mix.shape n_sources = n_channels if save_feature: basis = rng.random((n_sources, n_bins, n_basis)) activation = rng.random((n_sources, n_basis, n_frames)) spatial = rng.random((n_bins, n_sources, n_channels)) diagonalizer = np.eye(n_channels, dtype=np.complex128) diagonalizer =
np.tile(diagonalizer, reps=(n_bins, 1, 1)) else: basis = npz_target["basis"] activation = npz_target["activation"] spatial = npz_target["spatial"] diagonalizer = npz_target["diagonalizer"] mnmf = FastGaussMNMF( n_basis=n_basis, n_sources=n_sources, diagonalizer_algorithm=diagonalizer_algorithm, rng=rng, ) spectrogram_est = mnmf( spectrogram_mix, n_iter=n_iter, basis=basis, activation=activation, spatial=spatial, diagonalizer=diagonalizer, ) if save_feature: makedirs(root, exist_ok=True) np.savez( join(root, "target.npz"), spectrogram=spectrogram_est, basis=basis, activation=activation, spatial=spatial, diagonalizer=diagonalizer, n_basis=n_basis, n_iter=n_iter, ) else: assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max( np.abs(spectrogram_est - spectrogram_tgt) ) def save_all_features() -> None: """Regenerate ``target.npz`` of GaussMNMF and of FastGaussMNMF for every diagonalizer algorithm.""" test_gauss_mnmf(save_feature=True) for diagonalizer_algorithm in parameters_diagonalizer_algorithm: test_fast_gauss_mnmf(diagonalizer_algorithm=diagonalizer_algorithm, save_feature=True) if __name__ == "__main__": save_all_features() ================================================ FILE: tests/scripts/download_all.py ================================================ # This script is expected to be run from the root ssspy directory import sys from os.path import dirname, realpath tests_dir = dirname(dirname(realpath(__file__))) sys.path.append(tests_dir) from dummy.utils.dataset import download_sample_speech_data # noqa: E402 from dummy.utils.dataset import download_ssspy_data # noqa: E402 def download_all() -> None: """Download the sample speech data for every condition and maximum duration, and download the sample audio files used by the IO tests to ``./tests/mock/audio``.""" # Download sample speech data for every (condition, max_duration) pair conditions = [ {"n_sources": 2, "sisec2010_tag": "dev1_female3"}, {"n_sources": 3, "sisec2010_tag": "dev1_female3"}, {"n_sources": 4, "sisec2010_tag": "dev1_female4"}, ] max_durations = [0.1, 0.5] for kwargs in conditions: for max_duration in max_durations: download_sample_speech_data(max_duration=max_duration, **kwargs) # Download sample audio files used by the IO tests paths = [ "audio/monoral_16k_5sec.wav", "audio/stereo_16k_5sec.wav", ]
template_filename = "./tests/mock/{}" for path in paths: filename = template_filename.format(path) download_ssspy_data(path, filename=filename) if __name__ == "__main__": download_all()