[
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "## Summary\n<!-- Briefly summarize this PR. -->\n\n<!--\n\n## TODO\nDescribe what to do next.\n\n-->"
  },
  {
    "path": ".github/release.yaml",
    "content": "changelog:\n  categories:\n    - title: Breaking Changes 🛠\n      labels:\n        - breaking changes\n    - title: New Features 🎉\n      labels:\n        - new feature\n    - title: Bug Fixes 🐛\n      labels:\n        - bug\n        - bug fix\n    - title: Notebooks\n      labels:\n        - notebooks\n    - title: Other Changes\n      labels:\n        - \"*\"\n"
  },
  {
    "path": ".github/workflows/lint.yaml",
    "content": "name: lint\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  lint:\n    name: Run linters\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Set up Python 3.12\n        uses: actions/setup-python@v4\n        with:\n          python-version: \"3.12\"\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install \".[dev]\"\n      - name: Run linters\n        run: |\n          # See pyproject.toml\n          isort --check-only --diff --line-length 100 ssspy tests\n          flake8 --max-line-length=100 --ignore=E203,W503,W504 --exclude ssspy/_version.py ssspy tests\n      - name: Run formatters\n        run: |\n          python -m black --config pyproject.toml --check ssspy tests\n"
  },
  {
    "path": ".github/workflows/test_docs.yaml",
    "content": "name: tests for docs\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  build:\n    name: Build docs\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Set up Python 3.12\n        uses: actions/setup-python@v4\n        with:\n          python-version: \"3.12\"\n      - name: Install dependencies\n        run: |\n          sudo apt-get update\n          sudo apt-get install pandoc\n          python -m pip install --upgrade pip\n          pip install \".[docs,notebooks]\"\n      - name: Build docs\n        run: |\n          . ./docs/pre_build.sh\n          cd docs/\n          sphinx-build -W ./ ./_build/html/\n"
  },
  {
    "path": ".github/workflows/test_package_macos-13.yaml",
    "content": "name: macos-13\non:\n  workflow_call:\n    inputs:\n      python-version:\n        required: true\n        type: string\n    secrets:\n      CODECOV_TOKEN:\n        required: true\n      TEST_PYPI_API_TOKEN:\n        required: true\njobs:\n  package:\n    uses: ./.github/workflows/test_package_main.yaml\n    with:\n      # macos-13: x86_64, macos-latest: arm\n      # See https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners\n      os: macos-13\n      python-version: ${{ inputs.python-version }}\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-13_python-3.10.yaml",
    "content": "name: macos-13/3.10\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-13.yaml\n    with:\n      python-version: \"3.10\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-13_python-3.11.yaml",
    "content": "name: macos-13/3.11\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-13.yaml\n    with:\n      python-version: \"3.11\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-13_python-3.12.yaml",
    "content": "name: macos-13/3.12\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-13.yaml\n    with:\n      python-version: \"3.12\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-13_python-3.8.yaml",
    "content": "name: macos-13/3.8\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-13.yaml\n    with:\n      python-version: \"3.8\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-13_python-3.9.yaml",
    "content": "name: macos-13/3.9\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-13.yaml\n    with:\n      python-version: \"3.9\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-latest.yaml",
    "content": "name: macos-latest\non:\n  workflow_call:\n    inputs:\n      python-version:\n        required: true\n        type: string\n    secrets:\n      CODECOV_TOKEN:\n        required: true\n      TEST_PYPI_API_TOKEN:\n        required: true\njobs:\n  package:\n    uses: ./.github/workflows/test_package_main.yaml\n    with:\n      # macos-13: x86_64, macos-latest: arm\n      # See https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners\n      os: macos-latest\n      python-version: ${{ inputs.python-version }}\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-latest_python-3.10.yaml",
    "content": "name: macos-latest/3.10\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-latest.yaml\n    with:\n      python-version: \"3.10\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-latest_python-3.11.yaml",
    "content": "name: macos-latest/3.11\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-latest.yaml\n    with:\n      python-version: \"3.11\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-latest_python-3.12.yaml",
    "content": "name: macos-latest/3.12\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-latest.yaml\n    with:\n      python-version: \"3.12\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-latest_python-3.8.yaml",
    "content": "name: macos-latest/3.8\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-latest.yaml\n    with:\n      python-version: \"3.8\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_macos-latest_python-3.9.yaml",
    "content": "name: macos-latest/3.9\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_macos-latest.yaml\n    with:\n      python-version: \"3.9\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_main.yaml",
    "content": "name: test package\non:\n  workflow_call:\n    inputs:\n      os:\n        required: true\n        type: string\n      python-version:\n        required: true\n        type: string\n    secrets:\n      CODECOV_TOKEN:\n        required: true\n      TEST_PYPI_API_TOKEN:\n        required: true\njobs:\n  build:\n    name: Run tests with pytest\n    runs-on: ${{ inputs.os }}\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Set up Python ${{ inputs.python-version }}\n        uses: actions/setup-python@v4\n        with:\n          python-version: ${{ inputs.python-version }}\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install -e \".[dev,tests]\"\n      - name: Preparation of pytest\n        run: |\n          python tests/scripts/download_all.py\n      - name: Pytest (run all tests including redundant ones)\n        id: run_redundant\n        if: startsWith(github.head_ref, 'release/')\n        run: |\n          pytest --run-redundant -vvv -n 16 --cov=./ssspy --cov-report=xml tests/package/\n      - name: Pytest (skip redundant tests)\n        if: steps.run_redundant.conclusion == 'skipped'\n        run: |\n          pytest -vvv -n 16 --cov=./ssspy --cov-report=xml tests/package/\n      - name: Pytest (regression tests)\n        run: |\n          pytest -vvv -n 16 tests/regression/\n      - name: Upload coverage reports to Codecov\n        if: inputs.python-version == '3.12' && inputs.os == 'ubuntu-latest'\n        uses: codecov/codecov-action@v3\n        env:\n          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n        with:\n          fail_ci_if_error: true\n  upload_package:\n    needs:\n      - build\n    permissions:\n      id-token: write\n    secrets:\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    if: github.event_name == 'pull_request' && inputs.python-version == '3.12' && inputs.os == 'ubuntu-latest'\n    uses: 
./.github/workflows/upload_package.yaml\n"
  },
  {
    "path": ".github/workflows/test_package_ubuntu-latest.yaml",
    "content": "name: ubuntu-latest\non:\n  workflow_call:\n    inputs:\n      python-version:\n        required: true\n        type: string\n    secrets:\n      CODECOV_TOKEN:\n        required: true\n      TEST_PYPI_API_TOKEN:\n        required: true\njobs:\n  package:\n    uses: ./.github/workflows/test_package_main.yaml\n    with:\n      os: ubuntu-latest\n      python-version: ${{ inputs.python-version }}\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_ubuntu-latest_python-3.10.yaml",
    "content": "name: ubuntu-latest/3.10\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_ubuntu-latest.yaml\n    with:\n      python-version: \"3.10\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_ubuntu-latest_python-3.11.yaml",
    "content": "name: ubuntu-latest/3.11\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_ubuntu-latest.yaml\n    with:\n      python-version: \"3.11\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_ubuntu-latest_python-3.12.yaml",
    "content": "name: ubuntu-latest/3.12\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_ubuntu-latest.yaml\n    with:\n      python-version: \"3.12\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_ubuntu-latest_python-3.8.yaml",
    "content": "name: ubuntu-latest/3.8\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_ubuntu-latest.yaml\n    with:\n      python-version: \"3.8\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_ubuntu-latest_python-3.9.yaml",
    "content": "name: ubuntu-latest/3.9\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_ubuntu-latest.yaml\n    with:\n      python-version: \"3.9\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_windows-latest.yaml",
    "content": "name: windows-latest\non:\n  workflow_call:\n    inputs:\n      python-version:\n        required: true\n        type: string\n    secrets:\n      CODECOV_TOKEN:\n        required: true\n      TEST_PYPI_API_TOKEN:\n        required: true\njobs:\n  package:\n    uses: ./.github/workflows/test_package_main.yaml\n    with:\n      os: windows-latest\n      python-version: ${{ inputs.python-version }}\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_windows-latest_python-3.10.yaml",
    "content": "name: windows-latest/3.10\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_windows-latest.yaml\n    with:\n      python-version: \"3.10\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_windows-latest_python-3.11.yaml",
    "content": "name: windows-latest/3.11\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_windows-latest.yaml\n    with:\n      python-version: \"3.11\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_windows-latest_python-3.12.yaml",
    "content": "name: windows-latest/3.12\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_windows-latest.yaml\n    with:\n      python-version: \"3.12\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_windows-latest_python-3.8.yaml",
    "content": "name: windows-latest/3.8\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_windows-latest.yaml\n    with:\n      python-version: \"3.8\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/test_package_windows-latest_python-3.9.yaml",
    "content": "name: windows-latest/3.9\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\njobs:\n  package:\n    uses: ./.github/workflows/test_package_windows-latest.yaml\n    with:\n      python-version: \"3.9\"\n    secrets:\n      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n      TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n    permissions:\n      id-token: write\n"
  },
  {
    "path": ".github/workflows/upload_package.yaml",
    "content": "# based on\n# https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python\n# https://github.com/pypa/gh-action-pypi-publish\n\n# TODO: update this config for practical use\n\nname: Upload package to PyPI\non:\n  workflow_call:\n    secrets:\n      TEST_PYPI_API_TOKEN:\n        required: true\njobs:\n  build:\n    name: Build and upload package\n    runs-on: ubuntu-latest\n    permissions:\n      id-token: write\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n        with:\n          # to retrieve tags\n          fetch-depth: 0\n      - name: Set up Python 3.x\n        uses: actions/setup-python@v4\n        with:\n          python-version: '3.x'\n      - name: Show git tags\n        run: |\n          git tag\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install build wheel twine\n      - name: Build\n        run: |\n          python -m build\n      - name: Publish distribution to TestPyPI\n        env:\n          TEST_PYPI_API_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }}\n        run: |\n          python -m twine upload --repository testpypi --username __token__ --password ${TEST_PYPI_API_TOKEN} dist/*\n"
  },
  {
    "path": ".gitignore",
    "content": "# For building docs\ndocs/_notebooks/\n\n# For local\n.data/\n_version.py\n\n# For Mac\n.DS_Store\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "# .readthedocs.yaml\n# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Required\nversion: 2\n\n# Set the version of Python and other tools you might need\nbuild:\n  os: ubuntu-20.04\n  tools:\n    python: \"3.11\"\n    # You can also specify other tool versions:\n    # nodejs: \"16\"\n    # rust: \"1.55\"\n    # golang: \"1.17\"\n  jobs:\n    pre_build:\n      - . ./docs/pre_build.sh\n\n# Build documentation in the docs/ directory with Sphinx\nsphinx:\n  configuration: docs/conf.py\n\n# If using Sphinx, optionally build your docs in additional formats such as PDF\n# formats:\n#    - pdf\n\n# Optionally declare the Python requirements required to build your docs\n# python:\n#   install:\n#   - method: pip\n#     path: .\n#     extra_requirements:\n#     - docs\n#     - notebooks\n"
  },
  {
    "path": "CHANGELOG.rst",
    "content": "Changelog\n#########\n\nv0.2.0\n******\n\nWhat's Changed\n==============\n\nBreaking Changes 🛠\n-------------------\n* Rename ``aux`` to ``auxiliary`` by @tky823 in https://github.com/tky823/ssspy/pull/268\n* Detailed build status by @tky823 in https://github.com/tky823/ssspy/pull/288\n\nNew Features 🎉\n---------------\n* Implementation of harmonic vector analysis by @tky823 in https://github.com/tky823/ssspy/pull/271\n* Implementation of ADMM-HVA by @tky823 in https://github.com/tky823/ssspy/pull/281\n\nBug Fixes 🐛\n------------\n* Fix test coverage by @tky823 in https://github.com/tky823/ssspy/pull/269\n* Fix timing of uploading package by @tky823 in https://github.com/tky823/ssspy/pull/273\n* Remove status badge of lint by @tky823 in https://github.com/tky823/ssspy/pull/274\n\nOther Changes\n-------------\n* Upload package to TestPyPI by @tky823 in https://github.com/tky823/ssspy/pull/267\n* Remove duplicate uploads to TestPyPI by @tky823 in https://github.com/tky823/ssspy/pull/270\n* Use flooring function to compute norm. by @tky823 in https://github.com/tky823/ssspy/pull/276\n* Regression tests by @tky823 in https://github.com/tky823/ssspy/pull/238\n* Add ``needs`` to upload_package job in GHA. 
by @tky823 in https://github.com/tky823/ssspy/pull/277\n* Update actions/checkout in GitHub actions by @tky823 in https://github.com/tky823/ssspy/pull/279\n* Hugging Face demo by @tky823 in https://github.com/tky823/ssspy/pull/282\n* Set permissions in workflows by @tky823 in https://github.com/tky823/ssspy/pull/289\n* Bump up version to 0.2.0 by @tky823 in https://github.com/tky823/ssspy/pull/290\n\n\n**Full Changelog**: `v0.1.7...v0.2.0 <https://github.com/tky823/ssspy/compare/v0.1.7...v0.2.0>`_\n\nv0.1.7\n******\n\nSummary\n=======\nIn this version, we improve the management of the package.\nIn addition, ADMM-BSS is added as a new BSS method.\n\nWhat's Changed\n==============\n\nBreaking Changes 🛠\n-------------------\n* Include ssspy only as package by @tky823 in https://github.com/tky823/ssspy/pull/253\n* Add ``MANIFEST.in`` by @tky823 in https://github.com/tky823/ssspy/pull/257\n\nNew Features 🎉\n---------------\n* Implementation of ADMM-IVA by @tky823 in https://github.com/tky823/ssspy/pull/263\n* Support ADMM-BSS_multi-penalty by @tky823 in https://github.com/tky823/ssspy/pull/265\n\nBug Fixes 🐛\n------------\n* Fix document deployment by @tky823 in https://github.com/tky823/ssspy/pull/255\n* Update some variables depending on ``demix_filter`` instead of ``self.algorithm``. 
by @tky823 in https://github.com/tky823/ssspy/pull/260\n\nOther Changes\n-------------\n* Release notes by @tky823 in https://github.com/tky823/ssspy/pull/246\n* Add label for breaking changes by @tky823 in https://github.com/tky823/ssspy/pull/247\n* Notebooks/getting started by @tky823 in https://github.com/tky823/ssspy/pull/248\n* Update docs and notebooks to install ``ssspy`` from pypi by @tky823 in https://github.com/tky823/ssspy/pull/251\n* Detect reformatting by @tky823 in https://github.com/tky823/ssspy/pull/258\n* Make PDSBSSBase inherit IterativeMethodBase by @tky823 in https://github.com/tky823/ssspy/pull/262\n\n\n**Full Changelog**: `v0.1.6...v0.1.7 <https://github.com/tky823/ssspy/compare/v0.1.6...v0.1.7>`_\n\nv0.1.6\n******\n\nSummary\n=======\nIn this version, the following BSS methods are newly added 🚀\n\n- Fast MNMF\n- IVA-IPA\n- ILRMA-IPA\n\nWhat's Changed\n==============\n* Bump up version to v0.1.5 by @tky823 in https://github.com/tky823/ssspy/pull/222\n* Rename \"XXXbase\" to \"XXXBase\" by @tky823 in https://github.com/tky823/ssspy/pull/224\n* Move default pair_selector by @tky823 in https://github.com/tky823/ssspy/pull/225\n* Implement Fast MNMF by @tky823 in https://github.com/tky823/ssspy/pull/226\n* Score-based permutation solver by @tky823 in https://github.com/tky823/ssspy/pull/221\n* Specify flooring function in each method by @tky823 in https://github.com/tky823/ssspy/pull/228\n* Solver for cubic equations. 
by @tky823 in https://github.com/tky823/ssspy/pull/230\n* Consider corner case of cubic polynomial by @tky823 in https://github.com/tky823/ssspy/pull/233\n* Use pytest-xdist by @tky823 in https://github.com/tky823/ssspy/pull/235\n* Implement IVA-IPA by @tky823 in https://github.com/tky823/ssspy/pull/234\n* Update links to reference by @tky823 in https://github.com/tky823/ssspy/pull/237\n* Fix shape of varphi in tests of IVA by @tky823 in https://github.com/tky823/ssspy/pull/240\n* End support of python=3.7 by @tky823 in https://github.com/tky823/ssspy/pull/243\n* Stabilize IVA-IPA related algorithms by @tky823 in https://github.com/tky823/ssspy/pull/241\n* Implementation of ILRMA-IPA by @tky823 in https://github.com/tky823/ssspy/pull/244\n\n\n**Full Changelog**: `v0.1.5...v0.1.6 <https://github.com/tky823/ssspy/compare/v0.1.5...v0.1.6>`_\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2022 Takuya Hasumi\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "exclude .gitignore\nexclude *.yaml\nrecursive-include ssspy *.py\nprune .github\nprune tests\nprune docs\nprune notebooks\n"
  },
  {
    "path": "README.md",
    "content": "# ssspy\n[![Documentation Status](https://readthedocs.org/projects/sound-source-separation-python/badge/?version=latest)](https://sound-source-separation-python.readthedocs.io/en/latest/?badge=latest)\n[![codecov](https://codecov.io/gh/tky823/ssspy/branch/main/graph/badge.svg)](https://codecov.io/gh/tky823/ssspy)\n[![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://tky823-ssspy-demo.hf.space/)\n\nA Python toolkit for sound source separation.\n\n## Build Status\n\n| Python | Ubuntu | MacOS (x86_64) | MacOS (arm64) | Windows |\n|:-:|:-:|:-:|:-:|:-:|\n| 3.9 | [![ubuntu-latest/3.9](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.9.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.9.yaml) | [![macos-13/3.9](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.9.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.9.yaml) | [![macos-latest/3.9](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.9.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.9.yaml) | [![windows-latest/3.9](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.9.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.9.yaml) |\n| 3.10 | [![ubuntu-latest/3.10](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.10.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.10.yaml) | 
[![macos-13/3.10](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.10.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.10.yaml) | [![macos-latest/3.10](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.10.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.10.yaml) | [![windows-latest/3.10](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.10.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.10.yaml) |\n| 3.11 | [![ubuntu-latest/3.11](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.11.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.11.yaml) | [![macos-13/3.11](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.11.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.11.yaml) | [![macos-latest/3.11](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.11.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.11.yaml) | [![windows-latest/3.11](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.11.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.11.yaml) |\n| 3.12 | [![ubuntu-latest/3.12](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.12.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.12.yaml) | 
[![macos-13/3.12](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.12.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.12.yaml) | [![macos-latest/3.12](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.12.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.12.yaml) | [![windows-latest/3.12](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.12.yaml/badge.svg?branch=main)](https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.12.yaml) |\n\n\n## Installation\nYou can install `ssspy` with pip.\n```shell\npip install ssspy\n```\n\nTo install the latest version,\n```shell\npip install git+https://github.com/tky823/ssspy.git\n```\n\nAlternatively, you can build the package from source.\n```shell\ngit clone https://github.com/tky823/ssspy.git\ncd ssspy\npip install .\n```\n\nIf the installation of `ssspy` fails while building the wheel for numpy, please install numpy in advance.\n\n## Build Documentation Locally (optional)\nTo build the documentation locally, you have to include the `docs` and `notebooks` extras when installing `ssspy`.\n```shell\npip install -e \".[docs,notebooks]\"\n```\n\nYou need to convert some notebooks with the following command:\n```shell\n# in ssspy/\n. 
./docs/pre_build.sh\n```\n\nWhen you build the documentation, run the following command.\n```shell\ncd docs/\nmake html\n```\n\nOr, you can build the documentation automatically using `sphinx-autobuild`.\n```shell\n# in ssspy/\nsphinx-autobuild docs docs/_build/html\n```\n\n## Blind Source Separation Methods\n\n| Method | Notebooks |\n|:-:|:-:|\n| Independent Component Analysis (ICA) [1-3] | Gradient-descent-based ICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ICA/GradICA.ipynb) <br> Natural-gradient-descent-based ICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ICA/NaturalGradICA.ipynb) <br> Fast ICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ICA/FastICA.ipynb) |\n| Frequency-Domain Independent Component Analysis (FDICA) [4-6] | Gradient-descent-based FDICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/GradFDICA.ipynb) <br> Natural-gradient-descent-based FDICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/NaturalGradFDICA.ipynb) <br> Auxiliary-function-based FDICA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/AuxFDICA-IP1.ipynb) <br> Auxiliary-function-based FDICA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/AuxFDICA-IP2.ipynb) <br> Gradient-descent-based 
Laplace-FDICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/GradLaplaceFDICA.ipynb) <br> Natural-gradient-descent-based Laplace-FDICA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/NaturalGradLaplaceFDICA.ipynb) <br> Auxiliary-function-based Laplace-FDICA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/AuxLaplaceFDICA-IP1.ipynb) <br> Auxiliary-function-based Laplace-FDICA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/FDICA/AuxLaplaceFDICA-IP2.ipynb) |\n| Independent Vector Analysis (IVA) [7-14] | Gradient-descent-based IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/GradIVA.ipynb) <br> Natural-gradient-descent-based IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/NaturalGradIVA.ipynb) <br> Fast IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/FastIVA.ipynb) <br> Faster IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/FasterIVA.ipynb) <br> Auxiliary-function-based IVA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-IP1.ipynb) <br> 
Auxiliary-function-based IVA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-IP2.ipynb) <br> Auxiliary-function-based IVA (ISS1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-ISS1.ipynb) <br> Auxiliary-function-based IVA (ISS2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-ISS2.ipynb) <br> Auxiliary-function-based IVA (IPA): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxIVA-IPA.ipynb) <br> Gradient-descent-based Laplace-IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/GradLaplaceIVA.ipynb) <br> Natural-gradient-descent-based Laplace-IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/NaturalGradLaplaceIVA.ipynb) <br> Auxiliary-function-based Laplace-IVA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-IP1.ipynb) <br> Auxiliary-function-based Laplace-IVA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-IP2.ipynb) <br> Auxiliary-function-based Laplace-IVA (ISS1): [![Open in 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-ISS1.ipynb) <br> Auxiliary-function-based Laplace-IVA (ISS2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-ISS2.ipynb) <br> Auxiliary-function-based Laplace-IVA (IPA): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxLaplaceIVA-IPA.ipynb) <br> Gradient-descent-based Gauss-IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/GradGaussIVA.ipynb) <br> Natural-gradient-descent-based Gauss-IVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/NaturalGradGaussIVA.ipynb) <br> Auxiliary-function-based Gauss-IVA (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-IP1.ipynb) <br> Auxiliary-function-based Gauss-IVA (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-IP2.ipynb) <br> Auxiliary-function-based Gauss-IVA (ISS1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-ISS1.ipynb) <br> Auxiliary-function-based Gauss-IVA (ISS2): [![Open in 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-ISS2.ipynb) <br> Auxiliary-function-based Gauss-IVA (IPA): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IVA/AuxGaussIVA-IPA.ipynb) |\n| Independent Low-Rank Matrix Analysis (ILRMA) [15-18] | Gauss-ILRMA (IP1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IP1-MM.ipynb) <br> Gauss-ILRMA (IP1/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IP1-ME.ipynb) <br> Gauss-ILRMA (IP2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IP2-MM.ipynb) <br> Gauss-ILRMA (IP2/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IP2-ME.ipynb) <br> Gauss-ILRMA (ISS1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-ISS1-MM.ipynb) <br> Gauss-ILRMA (ISS1/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-ISS1-ME.ipynb) <br> Gauss-ILRMA (ISS2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-ISS2-MM.ipynb) <br> Gauss-ILRMA (ISS2/ME): [![Open in 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-ISS2-ME.ipynb) <br> Gauss-ILRMA (IPA/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IPA-MM.ipynb) <br> Gauss-ILRMA (IPA/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GaussILRMA-IPA-ME.ipynb) <br> *t*-ILRMA (IP1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-IP1-MM.ipynb) <br> *t*-ILRMA (IP1/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-IP1-ME.ipynb) <br> *t*-ILRMA (IP2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-IP2-MM.ipynb) <br> *t*-ILRMA (IP2/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-IP2-ME.ipynb) <br> *t*-ILRMA (ISS1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-ISS1-MM.ipynb) <br> *t*-ILRMA (ISS1/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-ISS1-ME.ipynb) <br> *t*-ILRMA (ISS2/MM): [![Open in 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-ISS2-MM.ipynb) <br> *t*-ILRMA (ISS2/ME): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/TILRMA-ISS2-ME.ipynb) <br> GGD-ILRMA (IP1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GGDILRMA-IP1-MM.ipynb) <br> GGD-ILRMA (IP2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GGDILRMA-IP2-MM.ipynb) <br> GGD-ILRMA (ISS1/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GGDILRMA-ISS1-MM.ipynb) <br> GGD-ILRMA (ISS2/MM): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ILRMA/GGDILRMA-ISS2-MM.ipynb) |\n| Independent Positive Semidefinite Tensor Analysis (IPSDTA) [19, 20] | Gauss-IPSDTA (VCD): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IPSDTA/GaussIPSDTA-VCD.ipynb) <br> *t*-IPSDTA (VCD): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/IPSDTA/TIPSDTA-VCD.ipynb) |\n| Multichannel Nonnegative Matrix Factorization (MNMF) [21-24] | Gauss-MNMF: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/MNMF/GaussMNMF.ipynb) <br> *t*-MNMF: soon <br> Fast 
Gauss-MNMF (IP1): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/MNMF/FastGaussMNMF-IP1.ipynb) <br> Fast Gauss-MNMF (IP2): [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/MNMF/FastGaussMNMF-IP2.ipynb) |\n| Blind Source Separation via Primal-Dual Splitting Algorithm (PDS-BSS) [25,26] | PDS-BSS: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/PDSBSS/PDSBSS.ipynb) <br> PDS-BSS-multiPenalty: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/PDSBSS/PDSBSS_multi-penalty.ipynb) <br> PDS-BSS-masking: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/PDSBSS/PDSBSS_masking.ipynb) |\n| Blind Source Separation via Alternating Direction Method of Multipliers (ADMM-BSS) | ADMM-BSS: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/ADMMBSS/ADMMBSS.ipynb) |\n| Harmonic Vector Analysis (HVA) [27] | HVA: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/HVA/HVA.ipynb) |\n| Complex Angular Central Gaussian Mixture Model (cACGMM) [28] | cACGMM: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tky823/ssspy/blob/main/notebooks/BSS/CACGMM/CACGMM.ipynb) |\n\n- [1] [P. Comon, \"Independent component analysis, a new concept?\" Signal Processing, vol. 36, no. 3, pp. 
287-314, 1994.](https://www.sciencedirect.com/science/article/pii/0165168494900299)\n- [2] [S. Amari, A. Cichocki, and H. H. Yang, \"A new learning algorithm for blind signal separation,\" in *Proc. NIPS*, 1996, pp. 757-763.](https://proceedings.neurips.cc/paper/1995/hash/e19347e1c3ca0c0b97de5fb3b690855a-Abstract.html)\n- [3] [A. Hyvärinen, \"Fast and robust fixed-point algorithms for independent component analysis,\" *IEEE Trans. on Neural Netw.*, vol. 10, no. 3, pp. 626-634, 1999.](https://www.cs.helsinki.fi/u/ahyvarin/papers/TNN99new.pdf)\n- [4] [N. Murata, S. Ikeda, and A. Ziehe, \"An approach to blind source separation based on temporal structure of speech signals,\" *Neurocomputing*, vol. 41, no. 1, pp. 1-24, 2001.](https://www.sciencedirect.com/science/article/pii/S0925231200003453)\n- [5] [H. Sawada, S. Araki, and S. Makino, \"Underdetermined convolutive blind source separation via frequency bin-wise clustering and permutation alignment,\" *IEEE Trans. ASLP*, vol. 19, no. 3, pp. 516-527, 2011.](https://ieeexplore.ieee.org/document/5473129)\n- [6] [N. Ono and S. Miyabe, \"Auxiliary-function-based independent component analysis for super-Gaussian sources,\" in *Proc. LVA/ICA*, 2010, pp. 165-172.](https://link.springer.com/chapter/10.1007/978-3-642-15995-4_21)\n- [7] [T. Kim, T. Attias, S.-Y. Lee, and T.-W. Lee, \"Blind source separation exploiting higher-order frequency dependencies,\" *IEEE Trans. ASLP*, vol. 15, no. 1, pp. 70-79, 2006.](https://link.springer.com/chapter/10.1007/11679363_21)\n- [8] I. Lee, T. Kim, and T.-W. Lee, \"Fast fixed-point independent vector analysis algorithms for convolutive blind source separation,\" *Signal Processing*, vol. 87, no. 8, pp. 1859-1871, 2007.\n- [9] [N. Ono, \"Stable and fast update rules for independent vector analysis based on auxiliary function technique,\" in *Proc. WASPAA*, 2011, pp. 189-192.](https://ieeexplore.ieee.org/document/6082320)\n- [10] [N. 
Ono, \"Auxiliary-function-based independent vector analysis with power of vector-norm type weighting functions,\" in *Proc. APSIPA ASC*, 2012, pp. 1-4.](https://ieeexplore.ieee.org/document/6411886)\n- [11] [R. Scheibler and N. Ono, \"Fast and stable blind source separation with rank-1 updates,\" in *Proc. ICASSP*, 2020, pp. 236-240.](https://ieeexplore.ieee.org/document/9053556)\n- [12] [A. Brendel and W. Kellermann, \"Faster IVA: Update rules for independent vector analysis based on negentropy and the majorize-minimize principle,\" in *Proc. WASPAA*, 2021, pp. 131–135.](https://arxiv.org/abs/2003.09531)\n- [13] [R. Scheibler, \"Independent vector analysis via log-quadratically penalized quadratic minimization,\" *IEEE Trans. Signal Processing*, vol. 69, pp. 2509-2524, 2021.](https://ieeexplore.ieee.org/document/9399809)\n- [14] [R. Ikeshita and T. Nakatani, \"ISS2: An extension of iterative source steering algorithm for majorization-minimization-based independent vector analysis,\" in *Proc. EUSIPCO*, 2022, pp. 65-69.](https://arxiv.org/abs/2202.00875)\n- [15] [D. Kitamura, N. Ono, H. Sawada, H. Kameoka, and H. Saruwatari, \"Determined blind source separation unifying independent vector analysis and nonnegative matrix factorization,\" *IEEE/ACM Trans. ASLP*, vol. 24, no. 9, pp. 1626-1641, 2016.](https://ieeexplore.ieee.org/document/7486081)\n- [16] [D. Kitamura, S. Mogami, Y. Mitsui, N. Takamune, H. Saruwatari, N. Ono, Y. Takahashi, and K. Kondo, \"Generalized independent low-rank matrix analysis using heavy-tailed distributions for blind source separation,\" *EURASIP J. Adv. in Signal Processing*, vol. 2018, no. 28, 25 pages, 2018.](https://link.springer.com/article/10.1186/s13634-018-0549-5)\n- [17] [T. Nakashima, R. Scheibler, Y. Wakabayashi, and N. Ono, \"Faster independent low-rank matrix analysis with pairwise updates of demixing vectors,\" in *Proc. EUSIPCO*, 2021, pp. 301-305.](https://ieeexplore.ieee.org/document/9287508)\n- [18] [Y. Mitsui, D. 
Kitamura, N. Takamune, H. Saruwatari, Y. Takahashi, and K. Kondo, \"Independent low-rank matrix analysis based on parametric majorization-equalization algorithm,\" in *Proc. CAMSAP*, 2017, pp. 98-102.](https://arxiv.org/abs/1710.01589)\n- [19] [R. Ikeshita, \"Independent positive semidefinite tensor analysis in blind source separation,\" in *Proc. EUSIPCO*, 2018, pp. 1652-1656.](https://ieeexplore.ieee.org/document/8553546)\n- [20] [T. Kondo, K. Fukushige, N. Takamune, D. Kitamura, H. Saruwatari, R. Ikeshita, and T. Nakatani, \"Convergence-guaranteed independent positive semidefinite tensor analysis based on Student's t distribution,\" in *Proc. ICASSP*, 2020, pp. 681-685.](https://ieeexplore.ieee.org/document/9054150)\n- [21] [A. Ozerov and C. Févotte, \"Multichannel nonnegative matrix factorization in convolutive mixtures for audio source separation,\" *IEEE Trans. ASLP*, vol. 18, no. 3, pp. 550-563, 2010.](https://ieeexplore.ieee.org/document/5229304)\n- [22] [H. Sawada, H. Kameoka, S. Araki, and N. Ueda, \"Multichannel extensions of non-negative matrix factorization with complex-valued data,\" *IEEE Trans. ASLP*, vol. 21, no. 5, pp. 971-982, 2013.](https://ieeexplore.ieee.org/document/6410389)\n- [23] [K. Yoshii, K. Itoyama, and M. Goto, \"Student's T nonnegative matrix factorization and positive semidefinite tensor factorization for single-channel audio source separation,\" in *Proc. ICASSP*, 2016, pp. 51-55.](https://ieeexplore.ieee.org/document/7471635)\n- [24] [K. Sekiguchi, A. A. Nugraha, Y. Bando, and K. Yoshii, \"Fast multichannel source separation based on jointly diagonalizable spatial covariance matrices,\" in *Proc. EUSIPCO*, 2019, pp. 1-5.](https://arxiv.org/abs/1903.03237)\n- [25] [K. Yatabe and D. Kitamura, \"Determined blind source separation via proximal splitting algorithm,\" in *Proc. ICASSP*, 2018, pp. 776-780.](https://ieeexplore.ieee.org/document/8462338)\n- [26] [K. Yatabe and D. 
Kitamura, \"Time-frequency-masking-based determined BSS with application to sparse IVA,\" in *Proc. ICASSP*, 2019, pp. 715-719.](https://ieeexplore.ieee.org/document/8682217)\n- [27] [K. Yatabe and D. Kitamura, \"Determined BSS based on time-frequency masking and its application to harmonic vector analysis,\" *IEEE/ACM Trans. ASLP*, vol. 29, pp. 1609–1625, 2021.](https://dl.acm.org/doi/abs/10.1109/TASLP.2021.3073863)\n- [28] [N. Ito, S. Araki, and T. Nakatani, \"Complex angular central Gaussian mixture model for directional statistics in mask-based microphone array signal processing,\" in *Proc. EUSIPCO*, 2016, pp. 1153-1157.](https://ieeexplore.ieee.org/document/7760429)\n\n## LICENSE\nApache License 2.0\n"
  },
  {
    "path": "codecov.yaml",
    "content": "coverage:\n  status:\n    project:\n      default:\n        target: auto\n        threshold: 1%\n    patch:\n      default:\n        target: auto\n        threshold: 5%\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/api.rst",
    "content": "APIs\n====\n\nIntroduction\n------------\n\nThe following example separates a mixture of three sample speech signals with ``AuxFDICA``.\n\n.. code-block:: python\n\n   import numpy as np\n   import scipy.signal as ss\n   import IPython.display as ipd\n   import matplotlib.pyplot as plt\n\n   from ssspy.utils.dataset import download_sample_speech_data\n   from ssspy.transform import whiten\n   from ssspy.algorithm import projection_back\n   from ssspy.bss.fdica import AuxFDICA\n\n   n_fft, hop_length = 4096, 2048\n   window = \"hann\"\n\n   waveform_src_img = download_sample_speech_data(n_sources=3)\n   waveform_mix = np.sum(waveform_src_img, axis=1)\n   _, _, spectrogram_mix = ss.stft(\n      waveform_mix,\n      window=window,\n      nperseg=n_fft,\n      noverlap=n_fft-hop_length\n   )\n\n   def contrast_fn(y):\n      return 2 * np.abs(y)\n\n   def d_contrast_fn(y):\n      return 2 * np.ones_like(y)\n\n   fdica = AuxFDICA(\n      contrast_fn=contrast_fn,\n      d_contrast_fn=d_contrast_fn,\n   )\n   spectrogram_mix_whitened = whiten(spectrogram_mix)\n   spectrogram_est = fdica(spectrogram_mix_whitened)\n   spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\n\n   _, waveform_est = ss.istft(\n      spectrogram_est,\n      window=window,\n      nperseg=n_fft,\n      noverlap=n_fft-hop_length\n   )\n\n   for idx, waveform in enumerate(waveform_est):\n      print(\"Estimated source: {}\".format(idx + 1))\n      ipd.display(ipd.Audio(waveform, rate=16000))\n      print()\n\n   plt.figure()\n   plt.plot(fdica.loss)\n   plt.show()\n\nSubmodules\n----------\n\n.. toctree::\n   :maxdepth: 1\n\n   ssspy.bss\n   ssspy.algorithm\n   ssspy.transform\n   ssspy.linalg\n   ssspy.special\n"
  },
  {
    "path": "docs/changelog.rst",
    "content": ".. include:: ../CHANGELOG.rst\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\n# import os\n# import sys\n# sys.path.insert(0, os.path.abspath('.'))\n\n\n# -- Project information -----------------------------------------------------\n\nproject = \"ssspy\"\ncopyright = \"2022, Takuya Hasumi\"\nauthor = \"Takuya Hasumi\"\n\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx_autodoc_typehints\",\n    \"nbsphinx\",\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\"]\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"furo\"\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\n# html_static_path = [\"_static\"]\n"
  },
  {
    "path": "docs/index.rst",
    "content": ".. ssspy documentation master file, created by\n   sphinx-quickstart on Fri Apr 29 20:59:12 2022.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nWelcome to ssspy's documentation!\n=================================\n\n.. image:: https://readthedocs.org/projects/sound-source-separation-python/badge/?version=latest\n   :target: https://sound-source-separation-python.readthedocs.io/en/latest/?badge=latest\n\n.. image:: https://github.com/tky823/ssspy/actions/workflows/lint.yaml/badge.svg\n   :target: https://github.com/tky823/ssspy/actions/workflows/lint.yaml\n\n.. image:: https://codecov.io/gh/tky823/ssspy/branch/main/graph/badge.svg?token=IZ89MTV64G\n   :target: https://codecov.io/gh/tky823/ssspy\n\n``ssspy`` is a Python toolkit for sound source separation.\n\nBuild status\n------------\n\n+--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+\n| Python | Ubuntu                                                                                                                         | MacOS (x86_64)                                                                                                            | MacOS (arm64)                                                                                                                 | Windows                                                                                                                         
|\n+--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------+\n| 3.9    | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.9.yaml/badge.svg?branch=main  | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.9.yaml/badge.svg?branch=main  | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.9.yaml/badge.svg?branch=main  | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.9.yaml/badge.svg?branch=main  |\n|        |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.9.yaml                       |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.9.yaml                       |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.9.yaml                       |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.9.yaml                       
|\n+--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+---------------------------------+\n| 3.10   | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.10.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.10.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.10.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.10.yaml/badge.svg?branch=main |\n|        |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.10.yaml                      |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.10.yaml                      |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.10.yaml                      |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.10.yaml                      
|\n+--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+---------------------------------+\n| 3.11   | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.11.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.11.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.11.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.11.yaml/badge.svg?branch=main |\n|        |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.11.yaml                      |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.11.yaml                      |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.11.yaml                      |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.11.yaml                      
|\n+--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+---------------------------------+\n| 3.12   | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.12.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.12.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.12.yaml/badge.svg?branch=main | .. image:: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.12.yaml/badge.svg?branch=main |\n|        |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_ubuntu-latest_python-3.12.yaml                      |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-13_python-3.12.yaml                      |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_macos-latest_python-3.12.yaml                      |    :target: https://github.com/tky823/ssspy/actions/workflows/test_package_windows-latest_python-3.12.yaml                      
|\n+--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+---------------------------------+\n\nInstallation\n------------\n\nYou can install ``ssspy`` with pip.\n\n.. code-block:: shell\n\n   pip install ssspy\n\nTo install the latest version from GitHub,\n\n.. code-block:: shell\n\n   pip install git+https://github.com/tky823/ssspy.git\n\nAlternatively, you can build the package from source.\n\n.. code-block:: shell\n\n   git clone https://github.com/tky823/ssspy.git\n   cd ssspy\n   pip install -e .\n\n.. note::\n\n   If you fail to install ``ssspy``, please update ``setuptools`` as follows:\n\n   .. code-block:: shell\n\n      python -m pip install --upgrade setuptools\n\n.. note::\n\n   If you cannot install ``ssspy`` because building the wheel for numpy fails, please install numpy in advance.\n\nBuild Documentation Locally (optional)\n--------------------------------------\n\nTo build the documentation locally, you have to include the ``docs`` and ``notebooks`` extras when installing ``ssspy``.\n\n.. code-block:: shell\n\n   pip install -e \".[docs,notebooks]\"\n\nYou also need to convert some notebooks with the following command:\n\n.. code-block:: shell\n\n   . ./docs/pre_build.sh\n\nTo build the documentation, run the following commands.\n\n.. code-block:: shell\n\n   cd docs/\n   make html\n\nOr, you can build the documentation automatically using ``sphinx-autobuild``.\n\n.. code-block:: shell\n\n   # in ssspy/\n   sphinx-autobuild docs docs/_build/html\n\n.. 
toctree::\n   :maxdepth: 1\n   :caption: Contents:\n\n   _notebooks/Getting-Started.rst\n   changelog\n   api\n\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=.\r\nset BUILDDIR=_build\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.https://www.sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/pre_build.sh",
    "content": "#!/bin/bash\n\n# TODO: unify .readthedocs.yaml\npip install -e \".[docs,notebooks]\"\n\n# execute .ipynb and save the executed notebook (with outputs) to docs/_notebooks/.\njupyter nbconvert --execute notebooks/Examples/Getting-Started.ipynb --to notebook --output-dir docs/_notebooks/\n"
  },
  {
    "path": "docs/ssspy.algorithm.rst",
    "content": "ssspy.algorithm\n===============\n\n``ssspy.algorithm`` provides algorithms related to source separation.\n\nAlgorithms\n~~~~~~~~~~\n.. autofunction:: ssspy.algorithm.projection_back\n"
  },
  {
    "path": "docs/ssspy.bss.admmbss.rst",
    "content": "ssspy.bss.admmbss\n=================\n\nAlgorithms\n~~~~~~~~~~\n.. autoclass:: ssspy.bss.admmbss.ADMMBSSBase\n\n.. autoclass:: ssspy.bss.admmbss.ADMMBSS\n   :special-members: __call__\n   :members: update_once\n"
  },
  {
    "path": "docs/ssspy.bss.base.rst",
    "content": "ssspy.bss.base\n==============\n\nIn this module, we provide the base class for iterative blind source separation methods.\n\nAlgorithms\n~~~~~~~~~~\n\n.. autoclass:: ssspy.bss.base.IterativeMethodBase\n   :special-members: __call__\n   :members:\n   :undoc-members:\n"
  },
  {
    "path": "docs/ssspy.bss.cacgmm.rst",
    "content": "ssspy.bss.cacgmm\n================\n\nAlgorithms\n~~~~~~~~~~\n.. autoclass:: ssspy.bss.cacgmm.CACGMM\n    :special-members: __call__\n    :members:\n        separate, normalize_covariance,\n        update_once, update_posterior, update_parameters,\n        compute_loss, solve_permutation\n"
  },
  {
    "path": "docs/ssspy.bss.fdica.rst",
    "content": "ssspy.bss.fdica\n===============\n\nIn this module, we separate multichannel signals\nusing frequency-domain independent component analysis (FDICA).\nWe denote the number of sources and microphones as :math:`N` and :math:`M`, respectively.\nWe also denote short-time Fourier transforms of source, observed, and separated signals\nas :math:`\\boldsymbol{s}_{ij}`, :math:`\\boldsymbol{x}_{ij}`, and :math:`\\boldsymbol{y}_{ij}`,\nrespectively.\n\n.. math::\n   \\boldsymbol{s}_{ij}\n   &= (s_{ij1},\\ldots,s_{ijn},\\ldots,s_{ijN})^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n   \\boldsymbol{x}_{ij}\n   &= (x_{ij1},\\ldots,x_{ijm},\\ldots,x_{ijM})^{\\mathsf{T}}\\in\\mathbb{C}^{M}, \\\\\n   \\boldsymbol{y}_{ij}\n   &= (y_{ij1},\\ldots,y_{ijn},\\ldots,y_{ijN})^{\\mathsf{T}}\\in\\mathbb{C}^{N},\n\nwhere :math:`i=1,\\ldots,I` and :math:`j=1,\\ldots,J` are indices of frequency bins and time frames, respectively.\nWhen a mixing system is time-invariant, :math:`\\boldsymbol{x}_{ij}` is represented as follows:\n\n.. math::\n   \\boldsymbol{x}_{ij}\n   = \\boldsymbol{A}_{i}\\boldsymbol{s}_{ij},\n\nwhere :math:`\\boldsymbol{A}_{i}=(\\boldsymbol{a}_{i1},\\ldots,\\boldsymbol{a}_{in},\\ldots,\\boldsymbol{a}_{iN})\\in\\mathbb{C}^{M\\times N}` is\na mixing matrix.\nIf :math:`M=N` and :math:`\\boldsymbol{A}_{i}` is non-singular, a demixing system is represented as\n\n.. math::\n   \\boldsymbol{y}_{ij}\n   = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij},\n\nwhere :math:`\\boldsymbol{W}_{i}=(\\boldsymbol{w}_{i1},\\ldots,\\boldsymbol{w}_{in},\\ldots,\\boldsymbol{w}_{iN})^{\\mathsf{H}}\\in\\mathbb{C}^{N\\times M}` is\na demixing matrix.\nThe negative log-likelihood of observed signals (divided by :math:`J`) is computed as follows:\n\n.. 
math::\n   \\mathcal{L}\n   &= -\\frac{1}{J}\\log p(\\mathcal{X}) \\\\\n   &= -\\frac{1}{J}\\left(\\log p(\\mathcal{Y}) \\\n   + \\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|^{2J} \\right) \\\\\n   &= -\\frac{1}{J}\\sum_{i,j,n}\\log p(y_{ijn})\n   - 2\\sum_{i}\\log|\\det\\boldsymbol{W}_{i}| \\\\\n   &= \\sum_{i}\\mathcal{L}^{[i]}, \\\\\n   \\mathcal{L}^{[i]} \\\n   &= \\frac{1}{J}\\sum_{j,n}G(y_{ijn})\n   - 2\\log|\\det\\boldsymbol{W}_{i}|, \\\\\n   G(y_{ijn})\n   &= -\\log p(y_{ijn}),\n\nwhere :math:`G(y_{ijn})` is a contrast function.\nThe derivative of :math:`G(y_{ijn})` is called a score function.\n\n.. math::\n   \\phi(y_{ijn})\n   = \\frac{\\partial G(y_{ijn})}{\\partial y_{ijn}^{*}}.\n\nAlgorithms\n~~~~~~~~~~\n\n.. autoclass:: ssspy.bss.fdica.FDICABase\n   :special-members: __call__\n   :members: separate,\n      compute_loss, compute_logdet,\n      restore_scale, apply_projection_back,\n      solve_permutation\n\n.. autoclass:: ssspy.bss.fdica.GradFDICABase\n   :special-members: __call__\n\n.. autoclass:: ssspy.bss.fdica.GradFDICA\n   :members: update_once\n\n.. autoclass:: ssspy.bss.fdica.NaturalGradFDICA\n   :members: update_once\n\n.. autoclass:: ssspy.bss.fdica.AuxFDICA\n   :special-members: __call__\n   :members: update_once, update_once_ip1, update_once_ip2\n\n.. autoclass:: ssspy.bss.fdica.GradLaplaceFDICA\n\n.. autoclass:: ssspy.bss.fdica.NaturalGradLaplaceFDICA\n\n.. autoclass:: ssspy.bss.fdica.AuxLaplaceFDICA\n"
  },
  {
    "path": "docs/ssspy.bss.hva.rst",
    "content": "ssspy.bss.hva\n=============\n\nAlgorithms\n~~~~~~~~~~\n.. autoclass:: ssspy.bss.hva.MaskingPDSHVA\n\n.. autoclass:: ssspy.bss.hva.MaskingADMMHVA\n\n.. autoclass:: ssspy.bss.hva.HVA\n"
  },
  {
    "path": "docs/ssspy.bss.ica.rst",
    "content": "ssspy.bss.ica\n=============\n\nIn this module, we separate time-domain multichannel signals\nusing independent component analysis (ICA) [#comon1994independent]_.\nWe denote the number of sources and microphones as :math:`N` and :math:`M`, respectively.\nWe also denote source, observed, and separated signals (in time-domain)\nas :math:`\\boldsymbol{s}_{t}`, :math:`\\boldsymbol{x}_{t}`, and :math:`\\boldsymbol{y}_{t}`,\nrespectively.\n\n.. math::\n   \\boldsymbol{s}_{t}\n   &= (s_{t1},\\ldots,s_{tn},\\ldots,s_{tN})^{\\mathsf{T}}\\in\\mathbb{R}^{N}, \\\\\n   \\boldsymbol{x}_{t}\n   &= (x_{t1},\\ldots,x_{tm},\\ldots,x_{tM})^{\\mathsf{T}}\\in\\mathbb{R}^{M}, \\\\\n   \\boldsymbol{y}_{t}\n   &= (y_{t1},\\ldots,y_{tn},\\ldots,y_{tN})^{\\mathsf{T}}\\in\\mathbb{R}^{N},\n\nwhere :math:`t=1,\\ldots,T` is an index of time samples.\nWhen a mixing system is time-invariant, :math:`\\boldsymbol{x}_{t}` is represented as follows:\n\n.. math::\n   \\boldsymbol{x}_{t}\n   = \\boldsymbol{A}\\boldsymbol{s}_{t},\n\nwhere :math:`\\boldsymbol{A}=(\\boldsymbol{a}_{1},\\ldots,\\boldsymbol{a}_{n},\\ldots,\\boldsymbol{a}_{N})\\in\\mathbb{R}^{M\\times N}` is\na mixing matrix.\nIf :math:`M=N` and :math:`\\boldsymbol{A}` is non-singular, a demixing system is represented as\n\n.. math::\n   \\boldsymbol{y}_{t}\n   = \\boldsymbol{W}\\boldsymbol{x}_{t},\n\nwhere :math:`\\boldsymbol{W}=(\\boldsymbol{w}_{1},\\ldots,\\boldsymbol{w}_{n},\\ldots,\\boldsymbol{w}_{N})^{\\mathsf{T}}\\in\\mathbb{R}^{N\\times M}` is\na demixing matrix.\nThe negative log-likelihood of observed signals (divided by :math:`T`) is computed as follows:\n\n.. 
math::\n   \\mathcal{L}\n   &= -\\frac{1}{T}\\log p(\\mathcal{X}) \\\\\n   &= -\\frac{1}{T}\\left(\\log p(\\mathcal{Y}) \\\n   + \\log|\\det\\boldsymbol{W}|^{T} \\right) \\\\\n   &= -\\frac{1}{T}\\sum_{t,n}\\log p(y_{tn})\n   - \\log|\\det\\boldsymbol{W}| \\\\\n   &= \\frac{1}{T}\\sum_{t,n}G(y_{tn})\n   - \\log|\\det\\boldsymbol{W}|, \\\\\n   G(y_{tn})\n   &= -\\log p(y_{tn}),\n\nwhere :math:`G(y_{tn})` is a contrast function.\nThe derivative of :math:`G(y_{tn})` is called a score function.\n\n.. math::\n   \\phi(y_{tn})\n   = \\frac{\\partial G(y_{tn})}{\\partial y_{tn}}.\n\n.. [#comon1994independent] P. Comon,\n   \"Independent component analysis, a new concept?\"\n   *Signal Processing*, vol. 36, no. 3, pp. 287-314, 1994.\n\nAlgorithms\n~~~~~~~~~~\n\n.. autoclass:: ssspy.bss.ica.GradICABase\n   :special-members: __call__\n   :members: separate, compute_loss, compute_logdet\n\n.. autoclass:: ssspy.bss.ica.FastICABase\n   :special-members: __call__\n   :members: separate, compute_loss\n\n.. autoclass:: ssspy.bss.ica.GradICA\n   :members: update_once\n\n.. autoclass:: ssspy.bss.ica.NaturalGradICA\n   :members: update_once\n\n.. autoclass:: ssspy.bss.ica.FastICA\n   :members: update_once\n\n.. autoclass:: ssspy.bss.ica.GradLaplaceICA\n   :members: update_once, compute_loss\n\n.. autoclass:: ssspy.bss.ica.NaturalGradLaplaceICA\n   :members: update_once, compute_loss\n"
  },
  {
    "path": "docs/ssspy.bss.ilrma.rst",
    "content": "ssspy.bss.ilrma\n===============\n\nIn this module, we separate multichannel signals\nusing independent low-rank matrix analysis (ILRMA).\nWe denote the number of sources and microphones as :math:`N` and :math:`M`, respectively.\nWe also denote short-time Fourier transforms of source, observed, and separated signals\nas :math:`\\boldsymbol{s}_{ij}`, :math:`\\boldsymbol{x}_{ij}`, and :math:`\\boldsymbol{y}_{ij}`,\nrespectively.\n\n.. math::\n   \\boldsymbol{s}_{ij}\n   &= (s_{ij1},\\ldots,s_{ijn},\\ldots,s_{ijN})^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n   \\boldsymbol{x}_{ij}\n   &= (x_{ij1},\\ldots,x_{ijm},\\ldots,x_{ijM})^{\\mathsf{T}}\\in\\mathbb{C}^{M}, \\\\\n   \\boldsymbol{y}_{ij}\n   &= (y_{ij1},\\ldots,y_{ijn},\\ldots,y_{ijN})^{\\mathsf{T}}\\in\\mathbb{C}^{N},\n\nwhere :math:`i=1,\\ldots,I` and :math:`j=1,\\ldots,J` are indices of frequency bins and time frames, respectively.\nWhen a mixing system is time-invariant, :math:`\\boldsymbol{x}_{ij}` is represented as follows:\n\n.. math::\n   \\boldsymbol{x}_{ij}\n   = \\boldsymbol{A}_{i}\\boldsymbol{s}_{ij},\n\nwhere :math:`\\boldsymbol{A}_{i}=(\\boldsymbol{a}_{i1},\\ldots,\\boldsymbol{a}_{in},\\ldots,\\boldsymbol{a}_{iN})\\in\\mathbb{C}^{M\\times N}` is\na mixing matrix.\nIf :math:`M=N` and :math:`\\boldsymbol{A}_{i}` is non-singular, a demixing system is represented as\n\n.. math::\n   \\boldsymbol{y}_{ij}\n   = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij},\n\nwhere :math:`\\boldsymbol{W}_{i}=(\\boldsymbol{w}_{i1},\\ldots,\\boldsymbol{w}_{in},\\ldots,\\boldsymbol{w}_{iN})^{\\mathsf{H}}\\in\\mathbb{C}^{N\\times M}` is\na demixing matrix.\nThe negative log-likelihood of observed signals (divided by :math:`J`) is computed as follows:\n\n.. 
math::\n   \\mathcal{L}\n   &= -\\frac{1}{J}\\log p(\\mathcal{X}) \\\\\n   &= -\\frac{1}{J}\\left(\\log p(\\mathcal{Y}) \\\n   + \\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|^{2J} \\right) \\\\\n   &= -\\frac{1}{J}\\sum_{i,j,n}\\log p(y_{ijn})\n   - 2\\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|.\n\nAlgorithms\n~~~~~~~~~~\n.. autoclass:: ssspy.bss.ilrma.ILRMABase\n   :special-members: __call__\n   :members:\n      _init_nmf, separate, reconstruct_nmf, update_once,\n      normalize, normalize_by_power, normalize_by_projection_back,\n      compute_loss, compute_logdet, restore_scale, apply_projection_back\n\n.. autoclass:: ssspy.bss.ilrma.GaussILRMA\n   :special-members: __call__\n   :members:\n      update_once,\n      update_source_model, update_source_model_mm, update_source_model_me,\n      update_latent_mm, update_basis_mm, update_activation_mm,\n      update_latent_me, update_basis_me, update_activation_me,\n      update_spatial_model, update_spatial_model_ip1, update_spatial_model_ip2, update_spatial_model_iss1, update_spatial_model_iss2, update_spatial_model_ipa,\n      compute_loss, apply_projection_back\n\n.. autoclass:: ssspy.bss.ilrma.TILRMA\n   :special-members: __call__\n   :members:\n      update_once,\n      update_source_model, update_source_model_mm, update_source_model_me,\n      update_latent_mm, update_basis_mm, update_activation_mm,\n      update_latent_me, update_basis_me, update_activation_me,\n      update_spatial_model, update_spatial_model_ip1, update_spatial_model_ip2, update_spatial_model_iss1, update_spatial_model_iss2,\n      compute_loss, apply_projection_back\n\n.. 
autoclass:: ssspy.bss.ilrma.GGDILRMA\n   :special-members: __call__\n   :members:\n      update_once,\n      update_source_model, update_source_model_mm,\n      update_latent_mm, update_basis_mm, update_activation_mm,\n      update_spatial_model, update_spatial_model_ip1, update_spatial_model_ip2, update_spatial_model_iss1, update_spatial_model_iss2,\n      compute_loss, apply_projection_back\n"
  },
  {
    "path": "docs/ssspy.bss.iva.rst",
    "content": "ssspy.bss.iva\n=============\n\nIn this module, we separate multichannel signals\nusing independent vector analysis (IVA).\nWe denote the number of sources and microphones as :math:`N` and :math:`M`, respectively.\nWe also denote short-time Fourier transforms of source, observed, and separated signals\nas :math:`\\boldsymbol{s}_{ij}`, :math:`\\boldsymbol{x}_{ij}`, and :math:`\\boldsymbol{y}_{ij}`,\nrespectively.\n\n.. math::\n   \\boldsymbol{s}_{ij}\n   &= (s_{ij1},\\ldots,s_{ijn},\\ldots,s_{ijN})^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n   \\boldsymbol{x}_{ij}\n   &= (x_{ij1},\\ldots,x_{ijm},\\ldots,x_{ijM})^{\\mathsf{T}}\\in\\mathbb{C}^{M}, \\\\\n   \\boldsymbol{y}_{ij}\n   &= (y_{ij1},\\ldots,y_{ijn},\\ldots,y_{ijN})^{\\mathsf{T}}\\in\\mathbb{C}^{N},\n\nwhere :math:`i=1,\\ldots,I` and :math:`j=1,\\ldots,J` are indices of frequency bins and time frames, respectively.\nWe also define the following vector:\n\n.. math::\n   \\vec{\\boldsymbol{y}}_{jn}\n   = (y_{1jn},\\ldots,y_{ijn},\\ldots,y_{Ijn})^{\\mathsf{T}}\\in\\mathbb{C}^{I}.\n\nWhen a mixing system is time-invariant, :math:`\\boldsymbol{x}_{ij}` is represented as follows:\n\n.. math::\n   \\boldsymbol{x}_{ij}\n   = \\boldsymbol{A}_{i}\\boldsymbol{s}_{ij},\n\nwhere :math:`\\boldsymbol{A}_{i}=(\\boldsymbol{a}_{i1},\\ldots,\\boldsymbol{a}_{in},\\ldots,\\boldsymbol{a}_{iN})\\in\\mathbb{C}^{M\\times N}` is\na mixing matrix.\nIf :math:`M=N` and :math:`\\boldsymbol{A}_{i}` is non-singular, a demixing system is represented as\n\n.. math::\n   \\boldsymbol{y}_{ij}\n   = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij},\n\nwhere :math:`\\boldsymbol{W}_{i}=(\\boldsymbol{w}_{i1},\\ldots,\\boldsymbol{w}_{in},\\ldots,\\boldsymbol{w}_{iN})^{\\mathsf{H}}\\in\\mathbb{C}^{N\\times M}` is\na demixing matrix.\nThe negative log-likelihood of observed signals (divided by :math:`J`) is computed as follows:\n\n.. 
math::\n   \\mathcal{L}\n   &= -\\frac{1}{J}\\log p(\\mathcal{X}) \\\\\n   &= -\\frac{1}{J}\\left(\\log p(\\mathcal{Y}) \\\n   + \\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|^{2J} \\right) \\\\\n   &= -\\frac{1}{J}\\sum_{j,n}\\log p(\\vec{\\boldsymbol{y}}_{jn})\n   - 2\\sum_{i}\\log|\\det\\boldsymbol{W}_{i}| \\\\\n   &= \\frac{1}{J}\\sum_{j,n}G(\\vec{\\boldsymbol{y}}_{jn})\n   - 2\\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|, \\\\\n   G(\\vec{\\boldsymbol{y}}_{jn})\n   &= -\\log p(\\vec{\\boldsymbol{y}}_{jn}),\n\nwhere :math:`G(\\vec{\\boldsymbol{y}}_{jn})` is a contrast function.\nThe derivative of :math:`G(\\vec{\\boldsymbol{y}}_{jn})` is called a score function.\n\n.. math::\n   \\phi_{i}(\\vec{\\boldsymbol{y}}_{jn})\n   = \\frac{\\partial G(\\vec{\\boldsymbol{y}}_{jn})}{\\partial y_{ijn}^{*}}.\n\nAlgorithms\n~~~~~~~~~~\n.. autoclass:: ssspy.bss.iva.IVABase\n   :special-members: __call__\n   :members: separate, update_once, compute_loss, compute_logdet, restore_scale, apply_projection_back\n\n.. autoclass:: ssspy.bss.iva.GradIVABase\n\n.. autoclass:: ssspy.bss.iva.FastIVABase\n   :members: separate, compute_loss, apply_projection_back\n\n.. autoclass:: ssspy.bss.iva.AuxIVABase\n   :special-members: __call__\n   :members: separate, compute_loss, apply_projection_back\n\n.. autoclass:: ssspy.bss.iva.GradIVA\n   :members: update_once\n\n.. autoclass:: ssspy.bss.iva.NaturalGradIVA\n   :members: update_once\n\n.. autoclass:: ssspy.bss.iva.FastIVA\n   :special-members: __call__\n   :members: update_once\n\n.. autoclass:: ssspy.bss.iva.FasterIVA\n   :special-members: __call__\n   :members: update_once\n\n.. autoclass:: ssspy.bss.iva.AuxIVA\n   :special-members: __call__\n   :members: update_once, update_once_ip1, update_once_ip2, update_once_iss1, update_once_iss2, update_once_ipa\n\n.. autoclass:: ssspy.bss.iva.GradLaplaceIVA\n   :members: update_once, compute_loss\n\n.. autoclass:: ssspy.bss.iva.GradGaussIVA\n   :members: update_once, update_source_model\n\n.. 
autoclass:: ssspy.bss.iva.NaturalGradLaplaceIVA\n   :members: update_once, compute_loss\n\n.. autoclass:: ssspy.bss.iva.NaturalGradGaussIVA\n   :members: update_once, compute_loss\n\n.. autoclass:: ssspy.bss.iva.AuxLaplaceIVA\n\n.. autoclass:: ssspy.bss.iva.AuxGaussIVA\n   :members: update_once, update_source_model\n"
  },
  {
    "path": "docs/ssspy.bss.mnmf.rst",
    "content": "ssspy.bss.mnmf\n==============\n\n\nAlgorithms\n~~~~~~~~~~\n.. autoclass:: ssspy.bss.mnmf.FastMNMFBase\n    :special-members: __call__\n    :members: normalize, normalize_by_power\n\n.. autoclass:: ssspy.bss.mnmf.FastGaussMNMF\n    :special-members: __call__\n    :members: separate,\n        compute_loss, compute_logdet,\n        update_once, update_basis, update_activation, update_diagonalizer, update_spatial,\n        update_diagonalizer_ip1, update_diagonalizer_ip2\n"
  },
  {
    "path": "docs/ssspy.bss.pdsbss.rst",
    "content": "ssspy.bss.pdsbss\n================\n\nIn this module, we separate multichannel signals\nusing blind source separation via the primal-dual splitting algorithm.\nWe denote the number of sources and microphones as :math:`N` and :math:`M`, respectively.\nWe also denote short-time Fourier transforms of source, observed, and separated signals\nas :math:`\\boldsymbol{s}_{ij}`, :math:`\\boldsymbol{x}_{ij}`, and :math:`\\boldsymbol{y}_{ij}`,\nrespectively.\n\n.. math::\n   \\boldsymbol{s}_{ij}\n   &= (s_{ij1},\\ldots,s_{ijn},\\ldots,s_{ijN})^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n   \\boldsymbol{x}_{ij}\n   &= (x_{ij1},\\ldots,x_{ijm},\\ldots,x_{ijM})^{\\mathsf{T}}\\in\\mathbb{C}^{M}, \\\\\n   \\boldsymbol{y}_{ij}\n   &= (y_{ij1},\\ldots,y_{ijn},\\ldots,y_{ijN})^{\\mathsf{T}}\\in\\mathbb{C}^{N},\n\nwhere :math:`i=1,\\ldots,I` and :math:`j=1,\\ldots,J` are indices of frequency bins and time frames, respectively.\nWhen a mixing system is time-invariant, :math:`\\boldsymbol{x}_{ij}` is represented as follows:\n\n.. math::\n   \\boldsymbol{x}_{ij}\n   = \\boldsymbol{A}_{i}\\boldsymbol{s}_{ij},\n\nwhere :math:`\\boldsymbol{A}_{i}=(\\boldsymbol{a}_{i1},\\ldots,\\boldsymbol{a}_{in},\\ldots,\\boldsymbol{a}_{iN})\\in\\mathbb{C}^{M\\times N}` is\na mixing matrix.\nIf :math:`M=N` and :math:`\\boldsymbol{A}_{i}` is non-singular, a demixing system is represented as\n\n.. math::\n   \\boldsymbol{y}_{ij}\n   = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij},\n\nwhere :math:`\\boldsymbol{W}_{i}=(\\boldsymbol{w}_{i1},\\ldots,\\boldsymbol{w}_{in},\\ldots,\\boldsymbol{w}_{iN})^{\\mathsf{H}}\\in\\mathbb{C}^{N\\times M}` is\na demixing matrix.\nThe negative log-likelihood of observed signals (divided by :math:`2J`) is computed as follows:\n\n.. 
math::\n   \\mathcal{L}\n   &= \\mathcal{P}(\\mathcal{V}(\\mathcal{Y}))\n   + \\sum_{i}\\mathcal{I}(\\boldsymbol{W}_{i}), \\\\\n   \\mathcal{V}(\\mathcal{Y})\n   &:= (y_{111},\\ldots,y_{11N},\\ldots,y_{1JN},\\ldots,y_{IJN})^{\\mathsf{T}}\n   \\in\\mathbb{C}^{IJN}, \\\\\n   \\mathcal{I}(\\boldsymbol{W}_{i})\n   &= - \\log|\\det\\boldsymbol{W}_{i}|,\n\nwhere :math:`\\mathcal{P}` is a penalty function that is determined by the source model.\n\nLet us consider independent vector analysis.\nIn this case, :math:`\\mathcal{P}` can be written as\n\n.. math::\n   \\mathcal{P}(\\mathcal{V}(\\mathcal{Y}))\n   = C\\sum_{j,n}\\left(\n   \\sum_{i}\\left|\\boldsymbol{w}_{in}^{\\mathsf{H}}\\boldsymbol{x}_{ij}\\right|^{2}\n   \\right)^{\\frac{1}{2}},\n\nwhere :math:`C` is a positive constant.\n\nTo the above formulation, we can apply the primal-dual splitting algorithm.\nOn the basis of this algorithm, the demixing filter is updated as follows:\n\n.. math::\n   \\tilde{\\boldsymbol{W}}_{i}\n   &\\leftarrow\\mathrm{prox}_{\\mu_{1}\\mathcal{I}}\n   \\left[\\boldsymbol{W}_{i} - \\mu_{1}\\mu_{2}\\sum_{j}\\boldsymbol{u}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}\\right] \\\\\n   \\boldsymbol{z}_{ij}\n   &\\leftarrow\\boldsymbol{u}_{ij} + \\left(2\\tilde{\\boldsymbol{W}}_{i} - \\boldsymbol{W}_{i}\\right)\\boldsymbol{x}_{ij} \\\\\n   \\mathcal{V}(\\tilde{\\mathcal{U}})\n   &\\leftarrow\\mathcal{V}(\\mathcal{Z})\n   - \\mathrm{prox}_{\\mathcal{P}/\\mu_{2}}\\left[\\mathcal{V}(\\mathcal{Z})\\right] \\\\\n   \\boldsymbol{u}_{ij}\n   &\\leftarrow\\alpha\\tilde{\\boldsymbol{u}}_{ij} + (1 - \\alpha)\\boldsymbol{u}_{ij}, \\\\\n   \\boldsymbol{W}_{i}\n   &\\leftarrow\\alpha\\tilde{\\boldsymbol{W}}_{i} + (1 - \\alpha)\\boldsymbol{W}_{i}.\n\n:math:`\\boldsymbol{u}_{ij}` is a dual variable, which should be initialized beforehand.\n:math:`\\mathrm{prox}_{g}` is a proximal operator defined as\n\n.. 
math::\n   \\mathrm{prox}_{g}[\\boldsymbol{z}]\n   = \\mathrm{argmin}_{\\boldsymbol{y}}\n   ~~g(\\boldsymbol{y}) + \\frac{1}{2}\\|\\boldsymbol{z} - \\boldsymbol{y}\\|_{2}^{2}.\n\nFor :math:`\\mathcal{I}`, we can obtain the following proximal operator:\n\n.. math::\n   \\mathrm{prox}_{\\mu\\mathcal{I}}[\\boldsymbol{W}_{i}]\n   &= \\boldsymbol{U}_{i}\\tilde{\\boldsymbol{\\Sigma}}_{i}\\boldsymbol{V}_{i}^{\\mathsf{H}}, \\\\\n   \\tilde{\\boldsymbol{\\Sigma}}_{i}\n   &= \\mathrm{diag}(\\tilde{\\sigma}_{i1},\\ldots,\\tilde{\\sigma}_{iN}), \\\\\n   \\tilde{\\sigma}_{in}\n   &= \\frac{\\sigma_{in} + \\sqrt{\\sigma_{in}^{2} + 4\\mu}}{2},\n\nwhere :math:`\\boldsymbol{U}_{i}`, :math:`\\boldsymbol{V}_{i}`,\nand :math:`\\boldsymbol{\\Sigma}_{i}=\\mathrm{diag}(\\sigma_{i1},\\ldots,\\sigma_{iN})` are obtained from the singular value decomposition of :math:`\\boldsymbol{W}_{i}`:\n\n.. math::\n   \\boldsymbol{W}_{i}\n   = \\boldsymbol{U}_{i}\\boldsymbol{\\Sigma}_{i}\\boldsymbol{V}_{i}^{\\mathsf{H}}.\n\nWhen :math:`\\mathcal{P}` is defined as\n\n.. math::\n   \\mathcal{P}(\\mathcal{V}(\\mathcal{Y}))\n   = C\\sum_{j,n}\\left(\n   \\sum_{i}\\left|\\boldsymbol{w}_{in}^{\\mathsf{H}}\\boldsymbol{x}_{ij}\\right|^{2}\n   \\right)^{\\frac{1}{2}},\n\nthe update by the proximal operator can be written as\n\n.. math::\n   y_{ijn}\n   \\leftarrow\\left(1 - \\frac{\\mu}{\\sqrt{\\sum_{i}|y_{ijn}|^{2}}}\\right)_{+}y_{ijn}.\n\nAlgorithms\n~~~~~~~~~~\n.. autoclass:: ssspy.bss.pdsbss.PDSBSSBase\n\n.. autoclass:: ssspy.bss.pdsbss.PDSBSS\n   :special-members: __call__\n   :members: update_once\n"
  },
  {
    "path": "docs/ssspy.bss.proxbss.rst",
    "content": "ssspy.bss.proxbss\n=================\n\nAlgorithms\n~~~~~~~~~~\n.. autoclass:: ssspy.bss.proxbss.ProxBSSBase\n    :special-members: __call__\n    :members: separate, compute_loss, compute_logdet, normalize_by_spectral_norm, restore_scale, apply_projection_back, apply_minimal_distortion_principle\n"
  },
  {
    "path": "docs/ssspy.bss.rst",
    "content": "ssspy.bss\n=========\n\n``ssspy.bss`` provides various blind source separation methods.\n\nSubmodules\n~~~~~~~~~~\n\n.. toctree::\n   :maxdepth: 1\n\n   ssspy.bss.base\n   ssspy.bss.ica\n   ssspy.bss.fdica\n   ssspy.bss.iva\n   ssspy.bss.ilrma\n   ssspy.bss.mnmf\n   ssspy.bss.proxbss\n   ssspy.bss.pdsbss\n   ssspy.bss.admmbss\n   ssspy.bss.hva\n   ssspy.bss.cacgmm\n"
  },
  {
    "path": "docs/ssspy.linalg.rst",
    "content": "ssspy.linalg\n============\n\n``ssspy.linalg`` is a linear algebra module related to source separation.\n\nAlgorithms\n~~~~~~~~~~\n.. autofunction:: ssspy.linalg.inv2\n\n.. autofunction:: ssspy.linalg.eigh\n\n.. autofunction:: ssspy.linalg.eigh2\n\n.. autofunction:: ssspy.linalg.gmeanmh\n\n.. autofunction:: ssspy.linalg.lqpqm2\n"
  },
  {
    "path": "docs/ssspy.special.rst",
    "content": "ssspy.special\n=============\n\n``ssspy.special`` is a module related to special functions.\n\nAlgorithms\n~~~~~~~~~~\n.. autofunction:: ssspy.special.logsumexp\n\n.. autofunction:: ssspy.special.softmax\n"
  },
  {
    "path": "docs/ssspy.transform.rst",
    "content": "ssspy.transform\n===============\n\n``ssspy.transform`` provides transforms related to source separation.\n\nAlgorithms\n~~~~~~~~~~\n.. autofunction:: ssspy.transform.pca\n\n.. autofunction:: ssspy.transform.whiten\n"
  },
  {
    "path": "notebooks/BSS/ADMMBSS/ADMMBSS.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[],\"authorship_tag\":\"ABX9TyPQC07HvYHC5TQxnCCmzCwE\"},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"IuOxOjnRWK-4\"},\"outputs\":[],\"source\":[\"!pip install git+https://github.com/tky823/ssspy.git\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"CdZN4hyHWOy8\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"LcS_A4hyWRqL\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"O7PTMc5qWTWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"K4qwWkHcWvzv\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"x3gWs4LgWxSD\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from 
ssspy.bss.admmbss import ADMMBSS\\n\",\"from ssspy.linalg import prox\"],\"metadata\":{\"id\":\"j7Gd1k-8Wy0L\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def l21_fn(y: np.ndarray) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Mixed L21 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"\\n\",\"    Returns:\\n\",\"        Sum of mixed L21 norm.\\n\",\"    \\\"\\\"\\\"\\n\",\"    G = np.linalg.norm(y, axis=1)\\n\",\"    loss = np.sum(G, axis=(0, 1))\\n\",\"\\n\",\"    return loss\\n\",\"\\n\",\"def prox_l21(y, step_size: float = 1) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Apply proximal operator of mixed L21 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"        step_size (float):\\n\",\"            Step size parameter.\\n\",\"\\n\",\"    Returns:\\n\",\"        Output value computed by proximal operator of mixed L21 norm.\\n\",\"        The shape of (n_sources, n_bins, n_frames).\\n\",\"    \\\"\\\"\\\"\\n\",\"    norm = np.linalg.norm(y, axis=1, keepdims=True)\\n\",\"\\n\",\"    # to suppress warning RuntimeWarning\\n\",\"    norm = np.where(norm < step_size, step_size, norm)\\n\",\"\\n\",\"    return np.maximum(1 - step_size / norm, 0) * y\"],\"metadata\":{\"id\":\"w87aPEguW1jU\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"admm_bss = ADMMBSS(\\n\",\"    rho=0.5,\\n\",\"    relaxation=1.75,\\n\",\"    penalty_fn=l21_fn,\\n\",\"    prox_penalty=prox_l21,\\n\",\"    scale_restoration=False,\\n\",\")\\n\",\"print(admm_bss)\"],\"metadata\":{\"id\":\"YrwbKut7W9AE\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"xoI6OCVvXAOm\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_mix_normalized = admm_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\\n\",\"spectrogram_est = admm_bss(spectrogram_mix_normalized, n_iter=1000)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"L6B1UOU0XCJf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EoeJuHniXDiG\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"7c0RlLEFXGgM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(admm_bss.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"XnPQAUxGXH2i\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"y9ydVrlzXaQH\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ADMMBSS/ADMMBSS_multi-penalty.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[],\"authorship_tag\":\"ABX9TyOoGT5woLRACpymG/HuTm0y\"},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"2A1FpkQ-h3ks\"},\"outputs\":[],\"source\":[\"!pip install git+https://github.com/tky823/ssspy.git\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"T-AFlnULiPGj\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"SLUkWHVGiQN4\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"UHG42oI4iRPU\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"1jdtdFu9iSRX\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"vCzucZlBiTps\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"import 
functools\"],\"metadata\":{\"id\":\"wqekYV1aiU0n\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.admmbss import ADMMBSS\"],\"metadata\":{\"id\":\"-78PZk8WiV2z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def l21_fn(y: np.ndarray) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Compute sum of mixed L21 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"\\n\",\"    Returns:\\n\",\"        Sum of mixed L21 norm.\\n\",\"    \\\"\\\"\\\"\\n\",\"    G = np.linalg.norm(y, axis=1)\\n\",\"    loss = np.sum(G, axis=(0, 1))\\n\",\"\\n\",\"    return loss\\n\",\"\\n\",\"def lamb_l1_fn(y: np.ndarray, lamb: float = 1) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Compute sum of L1 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"\\n\",\"    Returns:\\n\",\"        Sum of L1 norm.\\n\",\"    \\\"\\\"\\\"\\n\",\"    G = np.abs(y)\\n\",\"    loss = lamb * np.sum(G, axis=(0, 1, 2))\\n\",\"\\n\",\"    return loss\\n\",\"\\n\",\"def prox_l21(y: np.ndarray, step_size: float = 1) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Apply proximal operator of mixed L21 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"        step_size (float):\\n\",\"            Step size parameter.\\n\",\"\\n\",\"    Returns:\\n\",\"        Output value computed by proximal operator of mixed L21 norm.\\n\",\"        The shape of (n_sources, n_bins, n_frames).\\n\",\"    \\\"\\\"\\\"\\n\",\"    norm = np.linalg.norm(y, axis=1, keepdims=True)\\n\",\"\\n\",\"    # to suppress warning RuntimeWarning\\n\",\"    norm = np.where(norm < step_size, step_size, norm)\\n\",\"\\n\",\"    
return np.maximum(1 - step_size / norm, 0) * y\\n\",\"\\n\",\"def prox_lamb_l1(y: np.ndarray, step_size: float=1, lamb: float = 1) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Apply proximal operator of L1 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"        step_size (float):\\n\",\"            Step size parameter.\\n\",\"\\n\",\"    Returns:\\n\",\"        Output value computed by proximal operator of L1 norm.\\n\",\"        The shape of (n_sources, n_bins, n_frames).\\n\",\"    \\\"\\\"\\\"\\n\",\"    norm = np.abs(y)\\n\",\"\\n\",\"    # to suppress warning RuntimeWarning\\n\",\"    norm = np.where(norm < step_size * lamb, step_size * lamb, norm)\\n\",\"\\n\",\"    return np.maximum(1 - (step_size * lamb) / norm, 0) * y\"],\"metadata\":{\"id\":\"FrDLnrzOiYYM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"# SparseIVA without masking\\n\",\"penalty_fn = [\\n\",\"    l21_fn,\\n\",\"    functools.partial(lamb_l1_fn, lamb=1e-4),\\n\",\"]\\n\",\"prox_penalty = [\\n\",\"    prox_l21,\\n\",\"    functools.partial(prox_lamb_l1, lamb=1e-4),\\n\",\"]\"],\"metadata\":{\"id\":\"wk0H7Lj9iaOQ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"admm_bss = ADMMBSS(\\n\",\"    rho=0.5,\\n\",\"    relaxation=1.75,\\n\",\"    penalty_fn=penalty_fn,\\n\",\"    prox_penalty=prox_penalty,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(admm_bss)\"],\"metadata\":{\"id\":\"lguDznJaihz_\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"T4aYnKTXirf4\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_mix_normalized = 
admm_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\\n\",\"spectrogram_est = admm_bss(spectrogram_mix_normalized, n_iter=1000)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"w9DwgYhOitYv\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"TuUenwuBivge\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"8HjLaHqnixOp\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(admm_bss.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"ENGBhkjHiyTi\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"dssD7pzZkZ0H\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/CACGMM/CACGMM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"mPK5sQpmunbL\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"gF-CVqpZuq7y\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"LJto4YUHusBK\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"RypuaK8GutWj\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"F8UVIE67uwNv\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"M6rHkLzcuzZL\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.cacgmm import CACGMM as 
CACGMMBase\"],\"metadata\":{\"id\":\"mBxB5tm2u0X7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class CACGMM(CACGMMBase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(\\n\",\"        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\\n\",\"    ) -> np.ndarray:\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(input, n_iter=n_iter, initial_call=initial_call, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"12peq3Xyu2WM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"cacgmm = CACGMM(\\n\",\"    permutation_alignment=\\\"posterior_score\\\",\\n\",\"    global_iter=100,\\n\",\"    local_iter=100,\\n\",\"    rng=np.random.default_rng(42)\\n\",\")\\n\",\"print(cacgmm)\"],\"metadata\":{\"id\":\"P7QNo371u3YP\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"g-vqiYCUu6id\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = cacgmm(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"QQ3MmBwpu6f3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"tMz1cFCKu9Nk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    
print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"La9fImOkvAE5\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(cacgmm.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"ySSBSypivAAN\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"AyyuPUvlyTxN\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/FDICA/AuxFDICA-IP1.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"tLYLOhzujeH1\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"3WPpY4HOjmvT\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"AdivKHWiDs5L\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"Xve1xr0Ijn-O\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"ist3AxTsjtLi\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"LFWwc0y9kJDq\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.fdica import 
AuxFDICA\"],\"metadata\":{\"id\":\"ArUmiPttkKYH\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 2 * np.abs(y)\\n\",\"\\n\",\"def d_contrast_fn(y):\\n\",\"    return 2 * np.ones_like(y)\"],\"metadata\":{\"id\":\"6lTfgEW9kNN0\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"fdica = AuxFDICA(\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    d_contrast_fn=d_contrast_fn,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"y0T8jywSkRyb\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jHa8QSFtkaEM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=20)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"tHNkZgPakcAc\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"ZHugIyaRkdID\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"NYCtIWSbke2a\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"O7vW3o-ykhCN\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"DSAFgg91kjsb\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/FDICA/AuxFDICA-IP2.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"LbrobIqgDyVY\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.fdica import 
AuxFDICA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 2 * np.abs(y)\\n\",\"\\n\",\"def d_contrast_fn(y):\\n\",\"    return 2 * np.ones_like(y)\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"fdica = AuxFDICA(\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    d_contrast_fn=d_contrast_fn,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=50)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/FDICA/AuxLaplaceFDICA-IP1.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"tLYLOhzujeH1\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"3WPpY4HOjmvT\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"AdivKHWiDs5L\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"Xve1xr0Ijn-O\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"ist3AxTsjtLi\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"LFWwc0y9kJDq\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.fdica import 
AuxLaplaceFDICA\"],\"metadata\":{\"id\":\"ArUmiPttkKYH\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"fdica = AuxLaplaceFDICA(\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"y0T8jywSkRyb\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jHa8QSFtkaEM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=20)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"tHNkZgPakcAc\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"ZHugIyaRkdID\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"NYCtIWSbke2a\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"O7vW3o-ykhCN\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"DSAFgg91kjsb\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/FDICA/AuxLaplaceFDICA-IP2.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"LbrobIqgDyVY\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.fdica import 
AuxLaplaceFDICA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"fdica = AuxLaplaceFDICA(\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=50)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/FDICA/GradFDICA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"9mP-wlwN_imY\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"LvF_rAusCHdf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.fdica import 
GradFDICA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 2 * np.abs(y)\\n\",\"\\n\",\"def score_fn(y):\\n\",\"    denom = np.maximum(np.abs(y), 1e-10)\\n\",\"    return y / denom\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"Gaz1FrqQfSY_\"}},{\"cell_type\":\"code\",\"source\":[\"fdica = GradFDICA(\\n\",\"    step_size=1e-1,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=True,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"vBVuX7V_frpo\"}},{\"cell_type\":\"code\",\"source\":[\"fdica = GradFDICA(\\n\",\"    step_size=1e-1,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=False,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"ONLA5XqSfooH\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"PkLxGIrcfw5W\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"1ZK8uv99fzU-\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"5TmNEew4f00R\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"qEmFDyaLf23E\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"Jvh7YJVdf7QR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/FDICA/GradLaplaceFDICA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"9mP-wlwN_imY\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"LvF_rAusCHdf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.fdica import 
GradLaplaceFDICA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"Gaz1FrqQfSY_\"}},{\"cell_type\":\"code\",\"source\":[\"fdica = GradLaplaceFDICA(\\n\",\"    step_size=1e-1,\\n\",\"    is_holonomic=True,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"vBVuX7V_frpo\"}},{\"cell_type\":\"code\",\"source\":[\"fdica = GradLaplaceFDICA(\\n\",\"    step_size=1e-1,\\n\",\"    is_holonomic=False,\\n\",\"   
 scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"ONLA5XqSfooH\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"PkLxGIrcfw5W\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"1ZK8uv99fzU-\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"5TmNEew4f00R\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"qEmFDyaLf23E\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"Jvh7YJVdf7QR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/FDICA/NaturalGradFDICA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"2z5vwPlhedKU\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"8We0lvX8gHYO\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"otDru1KgDYlM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"hdwn-vr6gK7N\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"dsjZqOE3gMJt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"z4VECyzxgNzM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.fdica import 
NaturalGradFDICA\"],\"metadata\":{\"id\":\"sjX2hNDEgPVk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 2 * np.abs(y)\\n\",\"\\n\",\"def score_fn(y):\\n\",\"    denom = np.maximum(np.abs(y), 1e-10)\\n\",\"    return y / denom\"],\"metadata\":{\"id\":\"ATe35A28gRQp\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"r7fJ65ougw56\"}},{\"cell_type\":\"code\",\"source\":[\"fdica = NaturalGradFDICA(\\n\",\"    step_size=1e-1,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=True,\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"wL3nVUeagSsv\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"CpKeyzimghGg\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = fdica(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"uk0geR7pgkY5\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"Y5SHhlEQgmac\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"XxQYDDkegolW\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"EEI1EfyVgsrF\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic 
type\"],\"metadata\":{\"id\":\"JkE7CwScg6He\"}},{\"cell_type\":\"code\",\"source\":[\"fdica = NaturalGradFDICA(\\n\",\"    step_size=1e-1,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=False,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"qjAwSFQJg2mc\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"QS_zwah4g8Cf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"cPMDNHZtg-Cu\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jsgG54MGhBAi\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"9c4OLGoWhCkg\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"_5PVRsCdhEFr\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"Z7YaF_48hGcB\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/FDICA/NaturalGradLaplaceFDICA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"2z5vwPlhedKU\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"8We0lvX8gHYO\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"otDru1KgDYlM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"hdwn-vr6gK7N\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"dsjZqOE3gMJt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"z4VECyzxgNzM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.fdica import 
NaturalGradLaplaceFDICA\"],\"metadata\":{\"id\":\"sjX2hNDEgPVk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"r7fJ65ougw56\"}},{\"cell_type\":\"code\",\"source\":[\"fdica = NaturalGradLaplaceFDICA(\\n\",\"    step_size=1e-1,\\n\",\"    is_holonomic=True,\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"wL3nVUeagSsv\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"CpKeyzimghGg\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = fdica(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"uk0geR7pgkY5\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"Y5SHhlEQgmac\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"XxQYDDkegolW\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"EEI1EfyVgsrF\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"JkE7CwScg6He\"}},{\"cell_type\":\"code\",\"source\":[\"fdica = NaturalGradLaplaceFDICA(\\n\",\"    step_size=1e-1,\\n\",\"    is_holonomic=False,\\n\",\"    
scale_restoration=False\\n\",\")\\n\",\"print(fdica)\"],\"metadata\":{\"id\":\"qjAwSFQJg2mc\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"QS_zwah4g8Cf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = fdica(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"cPMDNHZtg-Cu\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jsgG54MGhBAi\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"9c4OLGoWhCkg\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(fdica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"_5PVRsCdhEFr\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"Z7YaF_48hGcB\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/HVA/ADMM-HVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[],\"authorship_tag\":\"ABX9TyOmTvOlLp1H2ygKo0T05BU4\"},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"eADaj8ajw-Hk\"},\"outputs\":[],\"source\":[\"!pip install git+https://github.com/tky823/ssspy.git\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"-vUkn20rxHOV\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"niAQu6AxxRmF\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"Rf3G-ViRxaRx\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"ivZ5VHBixbnN\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"SzYu9qPTxgfc\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.hva import 
MaskingADMMHVA\"],\"metadata\":{\"id\":\"niQwCmEkxhtQ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"hva = MaskingADMMHVA(\\n\",\"    rho=0.5,\\n\",\"    relaxation=1.75,\\n\",\"    attenuation=0.2,\\n\",\"    scale_restoration=False,\\n\",\")\\n\",\"print(hva)\"],\"metadata\":{\"id\":\"HIZcUll6xvXY\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"Q8bfISfwxwfd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_mix_normalized = hva.normalize_by_spectral_norm(spectrogram_mix_whitened)\\n\",\"spectrogram_est = hva(spectrogram_mix_normalized, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"B7jVoSBCxyPu\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"etxpTfiTx1nz\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"I2ZhRdS1x3b6\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"EC6aKaZpx5Zd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/HVA/HVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[],\"authorship_tag\":\"ABX9TyMJFxbBzvXESFspRTOKnHrj\"},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"N9ubIAwVRTVT\"},\"outputs\":[],\"source\":[\"!pip install git+https://github.com/tky823/ssspy.git\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"O9gHwzJNRmbD\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"qdeqalLKRoQS\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"zODACQ83RpUy\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"o-csyI_2RqeI\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"-ORaaTs-RrgL\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.hva import 
HVA\"],\"metadata\":{\"id\":\"wtsPoH2CR2fl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"hva = HVA(\\n\",\"    mu1=1,\\n\",\"    mu2=1,\\n\",\"    relaxation=1.75,\\n\",\"    attenuation=0.2,\\n\",\"    scale_restoration=False,\\n\",\")\\n\",\"print(hva)\"],\"metadata\":{\"id\":\"JJphe0rSR6PG\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"u-SuTb4NR974\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_mix_normalized = hva.normalize_by_spectral_norm(spectrogram_mix_whitened)\\n\",\"spectrogram_est = hva(spectrogram_mix_normalized, n_iter=200)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"BJY7JxF9R_F_\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"kxZrXVL7SBb7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"vs8RFvEuSCqI\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"7CwmQlMpSDwn\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ICA/FastICA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=False,\\n\",\") # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ica import FastICA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(x):\\n\",\"    return np.log(1 + 
np.exp(x))\\n\",\"\\n\",\"def score_fn(x):\\n\",\"    return 1 / (1 + np.exp(-x))\\n\",\"\\n\",\"def d_score_fn(x):\\n\",\"    sigma = 1 / (1 + np.exp(-x))\\n\",\"    return sigma * (1 - sigma)\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"ica = FastICA(\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    d_score_fn=d_score_fn,\\n\",\")\\n\",\"print(ica)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_est = ica(waveform_mix, n_iter=10)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"PAIdooVaylih\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ICA/GradICA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=False,\\n\",\") # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.bss.ica import GradICA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(x):\\n\",\"    
return np.log(1 + np.exp(x))\\n\",\"\\n\",\"def score_fn(x):\\n\",\"    return 1 / (1 + np.exp(-x))\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"XifuX1ebyENl\"}},{\"cell_type\":\"code\",\"source\":[\"ica = GradICA(\\n\",\"    contrast_fn=contrast_fn, score_fn=score_fn, is_holonomic=True\\n\",\")\\n\",\"print(ica)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_mix_whitened = whiten(waveform_mix)\\n\",\"waveform_est = ica(waveform_mix_whitened, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"L3-HTRTWyhAl\"}},{\"cell_type\":\"code\",\"source\":[\"ica = GradICA(\\n\",\"    step_size=1e+0,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=False\\n\",\")\\n\",\"print(ica)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_mix_whitened = whiten(waveform_mix)\\n\",\"waveform_est = ica(waveform_mix_whitened, n_iter=100)\"],\"metadata\":{\"id\":\"Dfp7ncInygoA\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: 
{}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Ch8wdSFGyjrl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"EeTia-LOykrm\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"PAIdooVaylih\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ICA/NaturalGradICA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=False,\\n\",\") # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.bss.ica import NaturalGradICA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def 
contrast_fn(x):\\n\",\"    return np.log(1 + np.exp(x))\\n\",\"\\n\",\"def score_fn(x):\\n\",\"    return 1 / (1 + np.exp(-x))\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"XifuX1ebyENl\"}},{\"cell_type\":\"code\",\"source\":[\"ica = NaturalGradICA(\\n\",\"    contrast_fn=contrast_fn, score_fn=score_fn, is_holonomic=True\\n\",\")\\n\",\"print(ica)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_est = ica(waveform_mix, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"L3-HTRTWyhAl\"}},{\"cell_type\":\"code\",\"source\":[\"ica = NaturalGradICA(\\n\",\"    step_size=1e+0,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=False\\n\",\")\\n\",\"print(ica)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_mix_whitened = whiten(waveform_mix)\\n\",\"waveform_est = ica(waveform_mix_whitened, n_iter=100)\"],\"metadata\":{\"id\":\"Dfp7ncInygoA\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 
1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Ch8wdSFGyjrl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ica.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"EeTia-LOykrm\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"PAIdooVaylih\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GGDILRMA-IP1-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GGDILRMA as 
GGDILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GGDILRMA(GGDILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"kyZFMq8BQJ-H\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\\n\",\"For small $\\\\beta$, the combination of `normalization` and `flooring_fn` can sometimes cause numerical instability, so we set `normalization=False` in this notebook.\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GGDILRMA(\\n\",\"    n_basis=8,\\n\",\"    beta=1.95,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    normalization=False,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, 
n_iter=200)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GGDILRMA(\\n\",\"    n_basis=2,\\n\",\"    beta=1.8,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GGDILRMA-IP2-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GGDILRMA as 
GGDILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GGDILRMA(GGDILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"hxhAhSwcQNuw\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GGDILRMA(\\n\",\"    n_basis=8,\\n\",\"    beta=1.95,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    normalization=False,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GGDILRMA(\\n\",\"    n_basis=2,\\n\",\"    beta=1.9,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GGDILRMA-ISS1-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GGDILRMA as 
GGDILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GGDILRMA(GGDILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"Y602oZNqQPBS\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GGDILRMA(\\n\",\"    n_basis=8,\\n\",\"    beta=1.95,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    normalization=False,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GGDILRMA(\\n\",\"    n_basis=2,\\n\",\"    beta=1.95,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, 
rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GGDILRMA-ISS2-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GGDILRMA as 
GGDILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GGDILRMA(GGDILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"tXuJbQQ1QPvH\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GGDILRMA(\\n\",\"    n_basis=8,\\n\",\"    beta=1.95,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    normalization=False,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GGDILRMA(\\n\",\"    n_basis=3,\\n\",\"    beta=1.95,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-IP1-ME.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"B4E4ILrjIlcT\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=8,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, 
rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-IP1-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"B4E4ILrjIlcT\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=8,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, 
rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-IP2-ME.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"nrIom3XyJlB9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=8,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-IP2-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"nrIom3XyJlB9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=8,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-IPA-ME.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"Q2dt5YXZJl2T\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=8,\\n\",\"    spatial_algorithm=\\\"IPA\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"IPA\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-IPA-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"Q2dt5YXZJl2T\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=8,\\n\",\"    spatial_algorithm=\\\"IPA\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"IPA\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-ISS1-ME.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"Q2dt5YXZJl2T\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=8,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    
display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-ISS1-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"Q2dt5YXZJl2T\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=8,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    
display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-ISS2-ME.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"82kfbjHkJml-\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=5,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/GaussILRMA-ISS2-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA as 
GaussILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussILRMA(GaussILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"82kfbjHkJml-\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=5,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(\\n\",\"    n_basis=2,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/TILRMA-IP1-ME.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import TILRMA as 
TILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class TILRMA(TILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"AYaZ5WZFa7Mw\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=8,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=2,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    
display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/TILRMA-IP1-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import TILRMA as 
TILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class TILRMA(TILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"AYaZ5WZFa7Mw\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=8,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=2,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    
display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/TILRMA-IP2-ME.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import TILRMA as 
TILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class TILRMA(TILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"F5nHP_t0a9zc\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=8,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=2,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/TILRMA-IP2-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import TILRMA as 
TILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class TILRMA(TILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"F5nHP_t0a9zc\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=8,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=2,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/TILRMA-ISS1-ME.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import TILRMA as 
TILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class TILRMA(TILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"mctxmyYha-Wl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=8,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=2,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    
display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/TILRMA-ISS1-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import TILRMA as 
TILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class TILRMA(TILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"mctxmyYha-Wl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=8,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=2,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"ISS1\\\",  # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    
display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/TILRMA-ISS2-ME.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import TILRMA as 
TILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class TILRMA(TILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"F_vG0afTa_F9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=8,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=3,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    source_algorithm=\\\"ME\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"HEf1-l6w4zon\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/ILRMA/TILRMA-ISS2-MM.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import TILRMA as 
TILRMABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class TILRMA(TILRMABase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"F_vG0afTa_F9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"8ErS0NZ12Gyq\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=8,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=True,  # w/ partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"3iVGgX3F2M8Z\"}},{\"cell_type\":\"code\",\"source\":[\"ilrma = TILRMA(\\n\",\"    n_basis=3,\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    source_algorithm=\\\"MM\\\",\\n\",\"    domain=2,\\n\",\"    partitioning=False,  # w/o partitioning function\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ilrma.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"HEf1-l6w4zon\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IPSDTA/GaussIPSDTA-VCD.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ipsdta import GaussIPSDTA as 
GaussIPSDTABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussIPSDTA(GaussIPSDTABase):\\n\",\"    def __init__(self, *args, source_steps=1, spatial_steps=1, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"        self.source_steps = source_steps\\n\",\"        self.spatial_steps = spatial_steps\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        for _ in range(self.source_steps):\\n\",\"            self.update_source_model()\\n\",\"\\n\",\"        for _ in range(self.spatial_steps):\\n\",\"            self.update_spatial_model()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"KChk_QhJ8Xhg\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"ipsdta = GaussIPSDTA(\\n\",\"    n_basis=2,\\n\",\"    n_blocks=1024,  # block 1: {1, 2}, ..., block 1023: {2045, 2046}, block 1024: {2047, 2048, 2049}\\n\",\"    spatial_algorithm=\\\"VCD\\\",\\n\",\"    source_steps=1,\\n\",\"    spatial_steps=10,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ipsdta)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ipsdta(spectrogram_mix, 
n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ipsdta.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IPSDTA/TIPSDTA-VCD.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ipsdta import TIPSDTA as 
TIPSDTABase\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class TIPSDTA(TIPSDTABase):\\n\",\"    def __init__(self, *args, source_steps=1, spatial_steps=1, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"        self.source_steps = source_steps\\n\",\"        self.spatial_steps = spatial_steps\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        for _ in range(self.source_steps):\\n\",\"            self.update_source_model()\\n\",\"\\n\",\"        for _ in range(self.spatial_steps):\\n\",\"            self.update_spatial_model()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"KChk_QhJ8Xhg\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"ipsdta = TIPSDTA(\\n\",\"    n_basis=2,\\n\",\"    n_blocks=1024,  # block 1: {1, 2}, ..., block 1023: {2045, 2046}, block 1024: {2047, 2048, 2049}\\n\",\"    dof=1000,\\n\",\"    spatial_algorithm=\\\"VCD\\\",\\n\",\"    source_steps=1,\\n\",\"    spatial_steps=10,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(ipsdta)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"X6OsBzSZ2Q8Z\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ipsdta(spectrogram_mix, 
n_iter=100)\"],\"metadata\":{\"id\":\"Ccmek_Ek2Q6D\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"7VASi3lZ2Q39\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"ct-qKgs42ZK9\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(ipsdta.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"SIpKCQTo2ZHl\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"r6q7VzoOzt-T\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxGaussIVA-IP1.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxGaussIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxGaussIVA(\\n\",\"    
spatial_algorithm=\\\"IP1\\\", # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxGaussIVA-IP2.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxGaussIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxGaussIVA(\\n\",\"    
spatial_algorithm=\\\"IP2\\\",\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxGaussIVA-IPA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxGaussIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxGaussIVA(\\n\",\"    
spatial_algorithm=\\\"IPA\\\",\\n\",\"    newton_iter=10,\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxGaussIVA-ISS1.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxGaussIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxGaussIVA(\\n\",\"    
spatial_algorithm=\\\"ISS1\\\", # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxGaussIVA-ISS2.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxGaussIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxGaussIVA(\\n\",\"    
spatial_algorithm=\\\"ISS2\\\",\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxIVA-IP1.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 
2 * np.linalg.norm(y, axis=1)\\n\",\"\\n\",\"def d_contrast_fn(y):\\n\",\"    return 2 * np.ones_like(y)\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxIVA(\\n\",\"    spatial_algorithm=\\\"IP1\\\", # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    d_contrast_fn=d_contrast_fn\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxIVA-IP2.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 
2 * np.linalg.norm(y, axis=1)\\n\",\"\\n\",\"def d_contrast_fn(y):\\n\",\"    return 2 * np.ones_like(y)\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxIVA(\\n\",\"    spatial_algorithm=\\\"IP2\\\",\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    d_contrast_fn=d_contrast_fn\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxIVA-IPA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 
2 * np.linalg.norm(y, axis=1)\\n\",\"\\n\",\"def d_contrast_fn(y):\\n\",\"    return 2 * np.ones_like(y)\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxIVA(\\n\",\"    spatial_algorithm=\\\"IPA\\\",\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    d_contrast_fn=d_contrast_fn,\\n\",\"    newton_iter=10,\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"OoJoi32YK5PQ\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxIVA-ISS1.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 
2 * np.linalg.norm(y, axis=1)\\n\",\"\\n\",\"def d_contrast_fn(y):\\n\",\"    return 2 * np.ones_like(y)\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxIVA(\\n\",\"    spatial_algorithm=\\\"ISS1\\\", # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    d_contrast_fn=d_contrast_fn\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxIVA-ISS2.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 
2 * np.linalg.norm(y, axis=1)\\n\",\"\\n\",\"def d_contrast_fn(y):\\n\",\"    return 2 * np.ones_like(y)\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxIVA(\\n\",\"    spatial_algorithm=\\\"ISS2\\\",\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    d_contrast_fn=d_contrast_fn\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxLaplaceIVA-IP1.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxLaplaceIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxLaplaceIVA(\\n\",\"   
 spatial_algorithm=\\\"IP1\\\", # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxLaplaceIVA-IP2.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxLaplaceIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxLaplaceIVA(\\n\",\"   
 spatial_algorithm=\\\"IP2\\\",\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxLaplaceIVA-IPA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxLaplaceIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxLaplaceIVA(\\n\",\"   
 spatial_algorithm=\\\"IPA\\\",\\n\",\"    newton_iter=10,\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxLaplaceIVA-ISS1.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxLaplaceIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxLaplaceIVA(\\n\",\"   
 spatial_algorithm=\\\"ISS1\\\", # You can set \\\"ISS\\\" instead of \\\"ISS1\\\".\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/AuxLaplaceIVA-ISS2.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 4\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female4\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import AuxLaplaceIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = AuxLaplaceIVA(\\n\",\"   
 spatial_algorithm=\\\"ISS2\\\",\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/FastIVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import FastIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    
return 2 * np.linalg.norm(y, axis=1)\\n\",\"\\n\",\"def d_contrast_fn(y):\\n\",\"    return 2 * np.ones_like(y)\\n\",\"\\n\",\"def dd_contrast_fn(y):\\n\",\"    return 2 * np.zeros_like(y)\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = FastIVA(\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    d_contrast_fn=d_contrast_fn,\\n\",\"    dd_contrast_fn=dd_contrast_fn\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=100)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"YQdFA8YAv7NK\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/FasterIVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.iva import FasterIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    
return 2 * np.linalg.norm(y, axis=1)\\n\",\"\\n\",\"def d_contrast_fn(y):\\n\",\"    return 2 * np.ones_like(y)\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"iva = FasterIVA(\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    d_contrast_fn=d_contrast_fn\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=50)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"YQdFA8YAv7NK\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/GradGaussIVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"6qelLYKGEye2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.iva import 
GradGaussIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"9XmmBfq1p32-\"}},{\"cell_type\":\"code\",\"source\":[\"iva = GradGaussIVA(\\n\",\"    step_size=1e-1,\\n\",\"    is_holonomic=True,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=100)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"YNtBangeqF3m\"}},{\"cell_type\":\"code\",\"source\":[\"iva = GradGaussIVA(\\n\",\"    step_size=1e-1,\\n\",\"    is_holonomic=False,\\n\",\"    
scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"wFA0WatXqG3C\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=100)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"HEbtWD5nqJpD\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"9LYvRj20qLW2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"_pvjv8PDqNBh\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"Dfue8hDuqOVf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"UsEuDyYjqPx0\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/GradIVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"6qelLYKGEye2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.iva import 
GradIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 2 * np.linalg.norm(y, axis=1)\\n\",\"\\n\",\"def score_fn(y):\\n\",\"    norm = np.linalg.norm(y, axis=1, keepdims=True)\\n\",\"    norm = np.maximum(norm, 1e-10)\\n\",\"    return y / norm\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"9XmmBfq1p32-\"}},{\"cell_type\":\"code\",\"source\":[\"iva = GradIVA(\\n\",\"    step_size=1e+2,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=True,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"YNtBangeqF3m\"}},{\"cell_type\":\"code\",\"source\":[\"iva = GradIVA(\\n\",\"    step_size=1e+1,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=False,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"wFA0WatXqG3C\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"HEbtWD5nqJpD\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"9LYvRj20qLW2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"_pvjv8PDqNBh\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"Dfue8hDuqOVf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"UsEuDyYjqPx0\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/GradLaplaceIVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"6qelLYKGEye2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.iva import 
GradLaplaceIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"9XmmBfq1p32-\"}},{\"cell_type\":\"code\",\"source\":[\"iva = GradLaplaceIVA(\\n\",\"    step_size=1e+2,\\n\",\"    is_holonomic=True,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"YNtBangeqF3m\"}},{\"cell_type\":\"code\",\"source\":[\"iva = GradLaplaceIVA(\\n\",\"    step_size=1e+1,\\n\",\"    is_holonomic=False,\\n\",\"    
scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"wFA0WatXqG3C\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"HEbtWD5nqJpD\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"9LYvRj20qLW2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"_pvjv8PDqNBh\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"Dfue8hDuqOVf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"UsEuDyYjqPx0\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/NaturalGradGaussIVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"4VTo4zewE2uu\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.iva import 
NaturalGradGaussIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"9XmmBfq1p32-\"}},{\"cell_type\":\"code\",\"source\":[\"iva = NaturalGradGaussIVA(\\n\",\"    step_size=1e-2,\\n\",\"    is_holonomic=True,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"YNtBangeqF3m\"}},{\"cell_type\":\"code\",\"source\":[\"iva = NaturalGradGaussIVA(\\n\",\"    step_size=1e-1,\\n\",\"    is_holonomic=False,\\n\",\"    
scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"wFA0WatXqG3C\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"HEbtWD5nqJpD\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"9LYvRj20qLW2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"_pvjv8PDqNBh\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"Dfue8hDuqOVf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"UsEuDyYjqPx0\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/NaturalGradIVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"4VTo4zewE2uu\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.iva import 
NaturalGradIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def contrast_fn(y):\\n\",\"    return 2 * np.linalg.norm(y, axis=1)\\n\",\"\\n\",\"def score_fn(y):\\n\",\"    norm = np.linalg.norm(y, axis=1, keepdims=True)\\n\",\"    norm = np.maximum(norm, 1e-10)\\n\",\"    return y / norm\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"9XmmBfq1p32-\"}},{\"cell_type\":\"code\",\"source\":[\"iva = NaturalGradIVA(\\n\",\"    step_size=1e-1,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=True\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"YNtBangeqF3m\"}},{\"cell_type\":\"code\",\"source\":[\"iva = NaturalGradIVA(\\n\",\"    step_size=1e+1,\\n\",\"    contrast_fn=contrast_fn,\\n\",\"    score_fn=score_fn,\\n\",\"    is_holonomic=False,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"wFA0WatXqG3C\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"HEbtWD5nqJpD\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"9LYvRj20qLW2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"_pvjv8PDqNBh\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"Dfue8hDuqOVf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"UsEuDyYjqPx0\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/IVA/NaturalGradLaplaceIVA.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"4VTo4zewE2uu\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 3\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.iva import 
NaturalGradLaplaceIVA\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Holonomic type\"],\"metadata\":{\"id\":\"9XmmBfq1p32-\"}},{\"cell_type\":\"code\",\"source\":[\"iva = NaturalGradLaplaceIVA(\\n\",\"    step_size=1e-1,\\n\",\"    is_holonomic=True,\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = iva(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## Non-holonomic type\"],\"metadata\":{\"id\":\"YNtBangeqF3m\"}},{\"cell_type\":\"code\",\"source\":[\"iva = NaturalGradLaplaceIVA(\\n\",\"    step_size=1e+1,\\n\",\"    is_holonomic=False,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(iva)\"],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, 
spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"wFA0WatXqG3C\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_est = iva(spectrogram_mix_whitened, n_iter=500)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"HEbtWD5nqJpD\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"9LYvRj20qLW2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"_pvjv8PDqNBh\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(iva.loss)\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"Dfue8hDuqOVf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"UsEuDyYjqPx0\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/MNMF/FastGaussMNMF-IP1.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"8WeTZarWPKFL\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"EJnsybEqPOuX\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"z9eNi9hHPQjM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"reverb_duration = 0.36\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"eLlZRFo8PRrL\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    reverb_duration=reverb_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"KxmdFO-mPUPr\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"HqsfDVe5PV3j\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.mnmf import FastGaussMNMF as 
FastGaussMNMFBase\"],\"metadata\":{\"id\":\"pD9uMM6ZPaqZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class FastGaussMNMF(FastGaussMNMFBase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"zqa964qWKKhN\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"mnmf = FastGaussMNMF(\\n\",\"    n_basis=16,\\n\",\"    n_sources=2,\\n\",\"    diagonalizer_algorithm=\\\"IP1\\\",  # You can set \\\"IP\\\" instead of \\\"IP1\\\".\\n\",\"    partitioning=False,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(mnmf)\"],\"metadata\":{\"id\":\"vF2lhWRfPXDs\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"JD5ndId6Pe6l\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = mnmf(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"G3LxT-USPgB6\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"luvhp8QRPg_1\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: 
{}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"YddaALjBPiA_\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(mnmf.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"zIFWvEW0PjDI\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"UqwiQjodPkiH\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/MNMF/FastGaussMNMF-IP2.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"8WeTZarWPKFL\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"EJnsybEqPOuX\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"z9eNi9hHPQjM\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"reverb_duration = 0.36\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"eLlZRFo8PRrL\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    reverb_duration=reverb_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"KxmdFO-mPUPr\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"HqsfDVe5PV3j\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.mnmf import FastGaussMNMF as 
FastGaussMNMFBase\"],\"metadata\":{\"id\":\"pD9uMM6ZPaqZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class FastGaussMNMF(FastGaussMNMFBase):\\n\",\"    def __init__(self, *args, **kwargs):\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs):\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"TrIDa8g-KQwx\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"mnmf = FastGaussMNMF(\\n\",\"    n_basis=8,\\n\",\"    n_sources=2,\\n\",\"    diagonalizer_algorithm=\\\"IP2\\\",\\n\",\"    partitioning=False,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(mnmf)\"],\"metadata\":{\"id\":\"vF2lhWRfPXDs\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"JD5ndId6Pe6l\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = mnmf(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"G3LxT-USPgB6\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"luvhp8QRPg_1\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    
display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"YddaALjBPiA_\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(mnmf.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"zIFWvEW0PjDI\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"UqwiQjodPkiH\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/MNMF/GaussMNMF.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"ab7xaF2Brwdn\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\\n\",\"from tqdm.notebook import tqdm\"],\"metadata\":{\"id\":\"DXzG4Q9pr0bb\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"-Dv4Pr4lr1h5\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"reverb_duration = 0.36\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"zIcbmVidr2da\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    reverb_duration=reverb_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7qf18ULsr3Ra\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"eqGUE2Lxr4kE\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.mnmf import GaussMNMF as 
GaussMNMFBase\"],\"metadata\":{\"id\":\"AE2qyrBar5mu\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"class GaussMNMF(GaussMNMFBase):\\n\",\"    def __init__(self, *args, **kwargs) -> None:\\n\",\"        super().__init__(*args, **kwargs)\\n\",\"\\n\",\"        self.progress_bar = None\\n\",\"\\n\",\"    def __call__(self, *args, n_iter: int = 100, **kwargs) -> np.ndarray:\\n\",\"        self.n_iter = n_iter\\n\",\"\\n\",\"        return super().__call__(*args, n_iter=n_iter, **kwargs)\\n\",\"\\n\",\"    def update_once(self) -> None:\\n\",\"        if self.progress_bar is None:\\n\",\"            self.progress_bar = tqdm(total=self.n_iter)\\n\",\"\\n\",\"        super().update_once()\\n\",\"\\n\",\"        self.progress_bar.update(1)\"],\"metadata\":{\"id\":\"7WV07guJr6ql\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/ partitioning function\"],\"metadata\":{\"id\":\"e-LL4HxC1B0o\"}},{\"cell_type\":\"code\",\"source\":[\"mnmf = GaussMNMF(\\n\",\"    n_basis=30,\\n\",\"    n_sources=2,\\n\",\"    partitioning=True,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(mnmf)\"],\"metadata\":{\"id\":\"AhCv-7wsr7qW\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"ZrzY92tor-KR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = mnmf(spectrogram_mix, n_iter=200)\"],\"metadata\":{\"id\":\"1q6-YqiRr_bR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"3RrMJzK7sAjB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    
print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"x5NrR6GYsBhp\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(mnmf.loss[10:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"gai3R1cUsCpz\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"## w/o partitioning function\"],\"metadata\":{\"id\":\"QRlsyJNh1G3S\"}},{\"cell_type\":\"code\",\"source\":[\"mnmf = GaussMNMF(\\n\",\"    n_basis=10,\\n\",\"    n_sources=2,\\n\",\"    partitioning=False,\\n\",\"    rng=np.random.default_rng(42),\\n\",\")\\n\",\"print(mnmf)\"],\"metadata\":{\"id\":\"ulNwZ-TusD3F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"CJzGmTKY1N2p\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = mnmf(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"E5zfaLRl1Ou5\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"YMgxFM0-1Pzb\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    
print()\"],\"metadata\":{\"id\":\"n4bek3rj1Q34\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(mnmf.loss[10:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"ymg1yCBn1Ruj\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"UzRVEMZo1Uxv\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/PDSBSS/PDSBSS.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.pdsbss import 
PDSBSS\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def l21_fn(y: np.ndarray) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Mixed L21 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"\\n\",\"    Returns:\\n\",\"        Sum of mixed L21 norm.\\n\",\"    \\\"\\\"\\\"\\n\",\"    G = np.linalg.norm(y, axis=1)\\n\",\"    loss = np.sum(G, axis=(0, 1))\\n\",\"\\n\",\"    return loss\\n\",\"\\n\",\"def prox_l21(y, step_size: float = 1) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Apply proximal operator of mixed L21 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"        step_size (float):\\n\",\"            Step size parameter.\\n\",\"\\n\",\"    Returns:\\n\",\"        Output value computed by proximal operator of mixed L21 norm.\\n\",\"        The shape of (n_sources, n_bins, n_frames).\\n\",\"    \\\"\\\"\\\"\\n\",\"    norm = np.linalg.norm(y, axis=1, keepdims=True)\\n\",\"\\n\",\"    return np.maximum(1 - step_size / norm, 0) * y\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"pds_bss = PDSBSS(\\n\",\"    mu1=1,\\n\",\"    mu2=1,\\n\",\"    alpha=1.75,\\n\",\"    penalty_fn=l21_fn,\\n\",\"    prox_penalty=prox_l21,\\n\",\"    scale_restoration=False,\\n\",\")\\n\",\"print(pds_bss)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_mix_normalized = 
pds_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\\n\",\"spectrogram_est = pds_bss(spectrogram_mix_normalized, n_iter=1000)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(pds_bss.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/PDSBSS/PDSBSS_masking.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[],\"authorship_tag\":\"ABX9TyM/ed6diEsnD5EoCV6u0+FD\"},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"nr1INZyXeXuR\"},\"outputs\":[],\"source\":[\"!pip install git+https://github.com/tky823/ssspy.git\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"r9-0QuFfeaut\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"bYCpxjc6ecsj\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"psnOiF2Tedv_\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"9DGQmwafeeoR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"FVoIVFKPefZB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"import functools\"],\"metadata\":{\"id\":\"NFL0jjzSevDf\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from 
ssspy.transform import whiten\\n\",\"from ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.pdsbss import MaskingPDSBSS\"],\"metadata\":{\"id\":\"YFJ_fx8gegRT\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def l21_mask(y, step_size: float = 1) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Apply proximal operator of mixed L21 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"        step_size (float):\\n\",\"            Step size parameter.\\n\",\"\\n\",\"    Returns:\\n\",\"        Mask based on mixed L21 norm.\\n\",\"        The shape of (n_sources, n_bins, n_frames).\\n\",\"\\n\",\"    \\\"\\\"\\\"\\n\",\"    norm = np.linalg.norm(y, axis=1, keepdims=True)\\n\",\"    mask = np.maximum(1 - step_size / norm, 0)\\n\",\"    mask = np.tile(mask, (1, y.shape[1], 1))\\n\",\"\\n\",\"    return mask\"],\"metadata\":{\"id\":\"JyLNkfqKepOH\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"pds_bss = MaskingPDSBSS(\\n\",\"    mu1=1,\\n\",\"    mu2=1,\\n\",\"    relaxation=1.75,\\n\",\"    mask_fn=functools.partial(l21_mask, step_size=1),\\n\",\"    scale_restoration=False,\\n\",\")\\n\",\"print(pds_bss)\"],\"metadata\":{\"id\":\"3pvUgw5vesw1\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"ii11j5XrexNu\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_mix_normalized = pds_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\\n\",\"spectrogram_est = pds_bss(spectrogram_mix_normalized, n_iter=1000)\\n\",\"spectrogram_est = projection_back(spectrogram_est, 
reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"XpScpxUQe1cT\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"HY8VE2ARe3mW\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"v9vi88kKe48R\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"008c-U04e6Dw\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/BSS/PDSBSS/PDSBSS_multi-penalty.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[]},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"iLMRz5h_I_U_\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import matplotlib.pyplot as plt\\n\",\"import IPython.display as ipd\"],\"metadata\":{\"id\":\"k4vmcNb5em_x\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.utils.dataset import download_sample_speech_data\"],\"metadata\":{\"id\":\"rOvxrG-sfp02\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"max_duration = 10\\n\",\"sisec2010_tag = \\\"dev1_female3\\\"\\n\",\"n_fft, hop_length = 4096, 2048\"],\"metadata\":{\"id\":\"faUwi6X9esWR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_speech_data(\\n\",\"    n_sources=n_sources,\\n\",\"    sisec2010_tag=sisec2010_tag,\\n\",\"    max_duration=max_duration,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"7jwyi2wReuRR\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"Pa28gsTce8yt\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"import functools\"],\"metadata\":{\"id\":\"ga-3Noc9sr5p\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"from ssspy.transform import whiten\\n\",\"from 
ssspy.algorithm import projection_back\\n\",\"from ssspy.bss.pdsbss import PDSBSS\"],\"metadata\":{\"id\":\"tixTvLybe-w7\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"def l21_fn(y: np.ndarray) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Compute sum of mixed L21 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"\\n\",\"    Returns:\\n\",\"        Sum of mixed L21 norm.\\n\",\"    \\\"\\\"\\\"\\n\",\"    G = np.linalg.norm(y, axis=1)\\n\",\"    loss = np.sum(G, axis=(0, 1))\\n\",\"\\n\",\"    return loss\\n\",\"\\n\",\"def lamb_l1_fn(y: np.ndarray, lamb: float = 1) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Compute sum of L1 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"\\n\",\"    Returns:\\n\",\"        Sum of L1 norm.\\n\",\"    \\\"\\\"\\\"\\n\",\"    G = np.abs(y)\\n\",\"    loss = lamb * np.sum(G, axis=(0, 1, 2))\\n\",\"\\n\",\"    return loss\\n\",\"\\n\",\"def prox_l21(y: np.ndarray, step_size: float = 1) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Apply proximal operator of mixed L21 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with shape of (n_sources, n_bins, n_frames).\\n\",\"        step_size (float):\\n\",\"            Step size parameter.\\n\",\"\\n\",\"    Returns:\\n\",\"        Output value computed by proximal operator of mixed L21 norm.\\n\",\"        The shape of (n_sources, n_bins, n_frames).\\n\",\"    \\\"\\\"\\\"\\n\",\"    norm = np.linalg.norm(y, axis=1, keepdims=True)\\n\",\"\\n\",\"    return np.maximum(1 - step_size / norm, 0) * y\\n\",\"\\n\",\"def prox_lamb_l1(y: np.ndarray, step_size: float=1, lamb: float = 1) -> np.ndarray:\\n\",\"    \\\"\\\"\\\"Apply proximal operator of L1 norm.\\n\",\"\\n\",\"    Args:\\n\",\"        y (np.ndarray):\\n\",\"            Input vector with 
shape of (n_sources, n_bins, n_frames).\\n\",\"        step_size (float):\\n\",\"            Step size parameter.\\n\",\"\\n\",\"    Returns:\\n\",\"        Output value computed by proximal operator of L1 norm.\\n\",\"        The shape of (n_sources, n_bins, n_frames).\\n\",\"    \\\"\\\"\\\"\\n\",\"    norm = np.abs(y)\\n\",\"\\n\",\"    return np.maximum(1 - (step_size * lamb) / norm, 0) * y\"],\"metadata\":{\"id\":\"6AvhtkrdfAeZ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"# SparseIVA without masking\\n\",\"penalty_fn = [\\n\",\"    l21_fn,\\n\",\"    functools.partial(lamb_l1_fn, lamb=2e-3),\\n\",\"]\\n\",\"prox_penalty = [\\n\",\"    prox_l21,\\n\",\"    functools.partial(prox_lamb_l1, lamb=2e-3),\\n\",\"]\"],\"metadata\":{\"id\":\"hdi1FlWTs0Oc\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"pds_bss = PDSBSS(\\n\",\"    mu1=1,\\n\",\"    mu2=1,\\n\",\"    alpha=1.75,\\n\",\"    penalty_fn=penalty_fn,\\n\",\"    prox_penalty=prox_penalty,\\n\",\"    scale_restoration=False\\n\",\")\\n\",\"print(pds_bss)\"],\"metadata\":{\"id\":\"h5GghnnMfP1F\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=\\\"hann\\\", nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"EecGuY-JfSBB\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_mix_whitened = whiten(spectrogram_mix)\\n\",\"spectrogram_mix_normalized = pds_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)\\n\",\"spectrogram_est = pds_bss(spectrogram_mix_normalized, n_iter=1000)\\n\",\"spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\"],\"metadata\":{\"id\":\"O3D2eA8HfVs2\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=\\\"hann\\\", nperseg=n_fft, 
noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"jdCqrAdPfXkk\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"xXTnlid-fZb3\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(pds_bss.loss[1:])\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"OHjwNcZIfe0K\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"qsBes0ysf9Sd\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "notebooks/Examples/Getting-Started.ipynb",
    "content": "{\"nbformat\":4,\"nbformat_minor\":0,\"metadata\":{\"colab\":{\"provenance\":[],\"authorship_tag\":\"ABX9TyO78AF8O5/QHlr4M5G1igtQ\"},\"kernelspec\":{\"name\":\"python3\",\"display_name\":\"Python 3\"},\"language_info\":{\"name\":\"python\"}},\"cells\":[{\"cell_type\":\"markdown\",\"source\":[\"# Quick Example of Blind Source Separation\\n\",\"In this notebook, we will show you a quick example of blind source separation by Gauss-ILRMA.\"],\"metadata\":{\"id\":\"RvVmZjNHDlDm\"}},{\"cell_type\":\"code\",\"execution_count\":null,\"metadata\":{\"id\":\"A5OAL3QxCtDz\"},\"outputs\":[],\"source\":[\"!pip install ssspy\"]},{\"cell_type\":\"code\",\"source\":[\"import numpy as np\\n\",\"import scipy.signal as ss\\n\",\"import IPython.display as ipd\\n\",\"import matplotlib.pyplot as plt\"],\"metadata\":{\"id\":\"h78m9LcCDVnJ\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"Here, we use sample music data.\\n\",\"\\n\"],\"metadata\":{\"id\":\"uIkscttUxNDy\"}},{\"cell_type\":\"code\",\"source\":[\"import os\\n\",\"import urllib.request\\n\",\"from typing import Tuple\\n\",\"\\n\",\"import ssspy\\n\",\"\\n\",\"def download_sample_music_data(\\n\",\"    n_sources: int = 2,\\n\",\"    conv: bool = True,\\n\",\"    branch: str = \\\"main\\\"\\n\",\") -> Tuple[np.ndarray, int]:\\n\",\"    instruments = [\\\"violin\\\", \\\"piano\\\"]\\n\",\"    root = \\\".data\\\"\\n\",\"    url_template = \\\"https://github.com/tky823/ssspy-data/raw/{branch}/{path}\\\"\\n\",\"    path_template = \\\"audio/{instrument}_8k_reverbed.wav\\\"\\n\",\"    waveforms = []\\n\",\"\\n\",\"    for src_idx in range(1, n_sources + 1):\\n\",\"        instrument = instruments[src_idx - 1]\\n\",\"        path = path_template.format(instrument=instrument)\\n\",\"        download_path = os.path.join(root, path)\\n\",\"        url = url_template.format(branch=branch, path=path)\\n\",\"\\n\",\"        os.makedirs(os.path.dirname(download_path), 
exist_ok=True)\\n\",\"\\n\",\"        if not os.path.exists(download_path):\\n\",\"            urllib.request.urlretrieve(url, download_path)\\n\",\"\\n\",\"        waveform, sample_rate = ssspy.wavread(download_path, channels_first=True)\\n\",\"\\n\",\"        assert sample_rate == 8000, f\\\"Sampling rate is expected 8000, but {sample_rate} is given.\\\"\\n\",\"\\n\",\"        waveforms.append(waveform)\\n\",\"\\n\",\"    waveforms = np.stack(waveforms, axis=1)\\n\",\"\\n\",\"    return waveforms[: n_sources], 8000\"],\"metadata\":{\"id\":\"FRgyrAJQJ6qy\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"n_sources = 2\\n\",\"n_fft, hop_length = 2048, 512\\n\",\"window = \\\"hann\\\"\"],\"metadata\":{\"id\":\"kQlKWdixDW7f\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"waveform_src_img, sample_rate = download_sample_music_data(\\n\",\"    n_sources=n_sources,\\n\",\"    conv=True,\\n\",\")  # (n_channels, n_sources, n_samples)\\n\",\"waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\"],\"metadata\":{\"id\":\"kmz01toPNq0b\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"Let's listen to mixtures.\"],\"metadata\":{\"id\":\"nr7q1nIEeJkS\"}},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_mix):\\n\",\"    print(\\\"Mixture: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"fbc32pgfDZzO\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"`ssspy.bss.ilrma.GaussILRMA` class provides Gauss-ILRMA.\"],\"metadata\":{\"id\":\"uGd-LoXHxmcA\"}},{\"cell_type\":\"code\",\"source\":[\"from ssspy.bss.ilrma import GaussILRMA\"],\"metadata\":{\"id\":\"dw0_3iy2DbKW\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"ilrma = GaussILRMA(n_basis=3, 
rng=np.random.default_rng(0))\\n\",\"print(ilrma)\"],\"metadata\":{\"id\":\"f_lnJNFuDdAj\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, _, spectrogram_mix = ss.stft(waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"4_URDBncDr1o\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"spectrogram_est = ilrma(spectrogram_mix, n_iter=500)\"],\"metadata\":{\"id\":\"q2ueWnTdDtcY\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[\"_, waveform_est = ss.istft(spectrogram_est, window=window, nperseg=n_fft, noverlap=n_fft-hop_length)\"],\"metadata\":{\"id\":\"YARLJx37DuwV\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"Let's listen to estimated sources.\"],\"metadata\":{\"id\":\"nxlG0Qa2eFFX\"}},{\"cell_type\":\"code\",\"source\":[\"for idx, waveform in enumerate(waveform_est):\\n\",\"    print(\\\"Estimated source: {}\\\".format(idx + 1))\\n\",\"    display(ipd.Audio(waveform, rate=sample_rate))\\n\",\"    print()\"],\"metadata\":{\"id\":\"jsom6b5YDv_0\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"Display cost at each iteration.\"],\"metadata\":{\"id\":\"wwh56-cGeAQ_\"}},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(range(len(ilrma.loss)), ilrma.loss)\\n\",\"plt.xlabel(\\\"Iteration\\\")\\n\",\"plt.ylabel(\\\"Cost\\\")\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"n-z8DgHUDxLI\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"markdown\",\"source\":[\"Since there has been a significant decrease by the first update, let's see the costs after the first iteration.\"],\"metadata\":{\"id\":\"9rtftb4vU_rw\"}},{\"cell_type\":\"code\",\"source\":[\"plt.figure()\\n\",\"plt.plot(range(1, len(ilrma.loss)), 
ilrma.loss[1:])\\n\",\"plt.xlabel(\\\"Iteration\\\")\\n\",\"plt.ylabel(\\\"Cost\\\")\\n\",\"plt.show()\\n\",\"plt.close()\"],\"metadata\":{\"id\":\"tV2nAF-xKmZw\"},\"execution_count\":null,\"outputs\":[]},{\"cell_type\":\"code\",\"source\":[],\"metadata\":{\"id\":\"v2r5VTF2VeW5\"},\"execution_count\":null,\"outputs\":[]}]}"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\n# ref: https://github.com/pypa/setuptools_scm\nrequires = [\n    \"setuptools>=45\",\n    \"setuptools_scm[toml]>=6.2\",    \n]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"ssspy\"\nauthors = [\n    {name = \"Takuya Hasumi\"},\n]\ndescription = \"A Python toolkit for sound source separation.\"\nreadme = \"README.md\"\nlicense = {file = \"LICENSE\"}\nurls = {url = \"https://github.com/tky823/ssspy\"}\nrequires-python = \">=3.8, <4\"\ndependencies = [\n    \"numpy\",\n]\ndynamic = [\n    \"version\",\n]\n\n[project.optional-dependencies]\ndev = [\n    \"flake8\",\n    \"black\",\n    \"isort\"\n]\ndocs = [\n    \"sphinx\",\n    \"sphinx-autodoc-typehints\",\n    \"sphinx-autobuild\",\n    \"nbsphinx\",\n    \"furo\",\n]\nnotebooks = [\n    \"ipykernel\",\n    \"matplotlib\",\n    \"scipy\",  # for STFT in notebooks\n]\ntests = [\n    \"pytest\",\n    \"pytest-cov\",\n    \"pytest-xdist\",\n    \"scipy\",\n]\n\n[tool.setuptools.dynamic]\nversion = {attr = \"ssspy.__version__\"}\n\n[tool.setuptools.packages.find]\n# TODO: redundancy with MANIFEST.in\n#       see https://github.com/tky823/ssspy/issues/256\ninclude = [\n    \"ssspy\",\n]\n\n[tool.setuptools_scm]\nwrite_to = \"ssspy/_version.py\"\nversion_scheme = \"guess-next-dev\"\nlocal_scheme = \"no-local-version\"\n\n[tool.black]\nline-length = 100\nexclude = \"ssspy/_version.py\"\n\n[tools.flake8]\nmax-line-length = 100\nexclude = \"ssspy/_version.py\"\n\n[tool.isort]\nprofile = \"black\"\nline_length = 100\n\n[tool.pytest.ini_options]\n# to import relative paths\npythonpath = [\n    \"tests\",\n]\n"
  },
  {
    "path": "ssspy/__init__.py",
    "content": "try:\n    from .io import wavread, wavwrite\nexcept ModuleNotFoundError:\n    # to avoid module not found error during installation\n    # e.g. numpy is not found in io.py\n    pass\n\ntry:\n    from ._version import __version__\nexcept ModuleNotFoundError:\n    __version__ = \"0.2.0\"\n\n__all__ = [\"__version__\", \"wavread\", \"wavwrite\"]\n"
  },
  {
    "path": "ssspy/algorithm/__init__.py",
    "content": "from . import permutation_alignment\nfrom .minimal_distortion_principle import minimal_distortion_principle\nfrom .projection_back import projection_back\n\n__all__ = [\"permutation_alignment\", \"minimal_distortion_principle\", \"projection_back\"]\n\nPROJECTION_BACK_KEYWORDS = [\"projection_back\", \"projection-back\", \"PB\"]\nMINIMAL_DISTORTION_PRINCIPLE_KEYWORDS = [\n    \"minimal_distortion_principle\",\n    \"minimal-distortion-principle\",\n    \"MDP\",\n]\n"
  },
  {
    "path": "ssspy/algorithm/minimal_distortion_principle.py",
    "content": "from typing import Optional\n\nimport numpy as np\n\n\ndef minimal_distortion_principle(\n    estimated: np.ndarray,\n    reference: Optional[np.ndarray] = None,\n    reference_id: Optional[int] = 0,\n) -> np.ndarray:\n    r\"\"\"Minimal distortion principle to restore scale ambiguity.\n\n    The implementation is based on [#matsuoka2002minimal]_.\n\n    Args:\n        estimated (numpy.ndarray):\n            Estimated spectrograms with shape of (n_channels, n_bins, n_frames).\n        reference (numpy.ndarray, optional):\n            Reference spectrogram with shape of (n_sources, n_bins, n_frames).\n        reference_id (int, optional):\n            Reference microphone index. Default: ``0``.\n\n    Returns:\n        numpy.ndarray of rescaled estimated spectrograms or demixing filters.\n\n    .. [#matsuoka2002minimal]\n        N. Murata, S. Ikeda, and A. Ziehe,\n        \"Minimal distortion principle for blind source separation,\"\n        in *Proc. ICA*, 2001, pp. 722-727.\n    \"\"\"\n    Y = estimated\n    X_conj = reference.conj()\n\n    if reference_id is None:\n        num = np.sum(Y * X_conj[:, np.newaxis, :, :], axis=-1, keepdims=True)\n    else:\n        num = np.sum(Y * X_conj[reference_id], axis=-1, keepdims=True)\n\n    denom = np.sum(np.abs(Y) ** 2, axis=-1, keepdims=True)\n    Z = num / denom\n    output_scaled = Z.conj() * Y\n\n    return output_scaled\n"
  },
  {
    "path": "ssspy/algorithm/permutation_alignment.py",
    "content": "import functools\nimport itertools\nfrom typing import Callable, Optional\n\nimport numpy as np\n\nfrom ..special.flooring import identity, max_flooring\n\nEPS = 1e-10\n\n\ndef correlation_based_permutation_solver(\n    sequence: np.ndarray,\n    *args,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n    overwrite: bool = True,\n) -> np.ndarray:\n    r\"\"\"Solve permutation of estimated spectrograms.\n\n    Group channels at each frequency bin according to correlations\n    between frequencies [#murata2001approach]_.\n\n    Args:\n        sequence (numpy.ndarray):\n            Array-like sequence of shape (n_bins, n_sources, n_frames).\n        args (tuple of numpy.ndarray, optional):\n            Positional arguments each of which is ``numpy.ndarray``.\n            The shapes of each item should be (n_bins, n_sources, \\*).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        overwrite (bool):\n            Overwrite ``sequence`` and ``args`` if ``overwrite=True``.\n            Default: ``True``.\n\n    Returns:\n        - If ``args`` is not given, ``numpy.ndarray`` of permutated separated spectrograms\n            with shape of (n_sources, n_bins, n_frames) are returned.\n        - If one positional argument is given, ``numpy.ndarray``s of permutated separated\n            spectrograms and the permutated positional argument are returned.\n        - If more than two positional arguments are given, ``numpy.ndarray``s of\n            permutated separated spectrograms and the permutated positional arguments are 
returned.\n\n        .. [#murata2001approach]\n            N. Murata, S. Ikeda, and A. Ziehe,\n            \"An approach to blind source separation based on temporal structure of speech signals,\"\n            in *Neurocomputing*, vol. 41, no. 1, pp. 1-24, 2001.\n\n    .. note::\n\n        In this function, the shape of ``separated`` is expected ``(n_bins, n_sources, ...)``,\n        which is different from other functions.\n    \"\"\"\n    assert sequence.ndim == 3, \"Dimension of sequence is expected to be 3.\"\n\n    for pos_idx, arg in enumerate(args):\n        if arg.shape[:2] != sequence.shape[:2]:\n            raise ValueError(\"The shape of {}th argument is invalid.\".format(pos_idx + 1))\n\n    if overwrite:\n        Y = sequence\n        permutable = args\n    else:\n        Y = sequence.copy()\n\n        permutable = []\n\n        for arg in args:\n            permutable.append(arg.copy())\n\n        permutable = tuple(permutable)\n\n    if flooring_fn is None:\n        flooring_fn = identity\n    else:\n        flooring_fn = flooring_fn\n\n    n_bins, n_sources, _ = Y.shape\n\n    permutations = list(itertools.permutations(range(n_sources)))\n\n    P = np.abs(Y)\n    norm = np.sqrt(np.sum(P**2, axis=1, keepdims=True))\n    norm = flooring_fn(norm)\n    P = P / norm\n    correlation = np.sum(P @ P.transpose(0, 2, 1), axis=(1, 2))\n    indices = np.argsort(correlation)\n\n    min_idx = indices[0]\n    P_criteria = P[min_idx]\n\n    for bin_idx in range(1, n_bins):\n        min_idx = indices[bin_idx]\n        P_max = None\n        perm_max = None\n\n        for perm in permutations:\n            P_perm = np.sum(P_criteria * P[min_idx, perm, :])\n\n            if P_max is None or P_perm > P_max:\n                P_max = P_perm\n                perm_max = perm\n\n        P_criteria = P_criteria + P[min_idx, perm_max, :]\n        Y[min_idx, :] = Y[min_idx, perm_max]\n\n        for idx in range(len(permutable)):\n            permutable[idx][min_idx, :] = 
permutable[idx][min_idx, perm_max]\n\n    if len(permutable) == 0:\n        return Y\n    elif len(permutable) == 1:\n        return Y, permutable[0]\n    else:\n        return Y, permutable\n\n\ndef score_based_permutation_solver(\n    sequence: np.ndarray,\n    *args,\n    global_iter: int = 1,\n    local_iter: int = 1,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n    multi_centroids: bool = False,\n    overwrite: bool = True,\n) -> np.ndarray:\n    r\"\"\"Align permutations between frequencies based on score value [#sawada2010underdetermined]_.\n\n    Args:\n        sequence (numpy.ndarray):\n            Array-like sequence of shape (n_bins, n_sources, n_frames).\n        args (tuple of numpy.ndarray, optional):\n            Positional arguments each of which is ``numpy.ndarray``.\n            The shapes of each item should be (n_bins, n_sources, \\*).\n        global_iter (int):\n            Number of iterations in global optimization.\n        local_iter (int):\n            Number of iterations in local optimization.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        multi_centroids (bool):\n            If ``multi_centroids=True``, multiple centroids are used in global optimization.\n            However, this is not supported now. Default: ``False``.\n        overwrite (bool):\n            Overwrite ``sequence`` and ``args`` if ``overwrite=True``.\n            Default: ``True``.\n\n    .. [#sawada2010underdetermined]\n        H. Sawada, S. Araki, and S. 
Makino,\n        \"Underdetermined convolutive blind source separation \\\n        via frequency bin-wise clustering and permutation alignment,\"\n        in *IEEE Trans. ASLP*, vol. 19, no. 3, pp. 516-527, 2010.\n    \"\"\"\n    assert sequence.ndim == 3, \"Dimension of sequence is expected to be 3.\"\n    assert not multi_centroids, \"multi_centroids version is not supported.\"\n\n    for pos_idx, arg in enumerate(args):\n        if arg.shape[:2] != sequence.shape[:2]:\n            raise ValueError(\"The shape of {}th argument is invalid.\".format(pos_idx + 1))\n\n    if overwrite:\n        permutable = args\n    else:\n        sequence = sequence.copy()\n\n        permutable = []\n\n        for arg in args:\n            permutable.append(arg.copy())\n\n        permutable = tuple(permutable)\n\n    if flooring_fn is None:\n        flooring_fn = identity\n    else:\n        flooring_fn = flooring_fn\n\n    n_bins, n_sources = sequence.shape[:2]\n    na = np.newaxis\n    eye = np.eye(n_sources)\n    permutations = np.array(list(itertools.permutations(range(n_sources))))\n\n    sequence_mean = sequence.mean(axis=-1, keepdims=True)\n    sequence_std = sequence.std(axis=-1, keepdims=True)\n    sequence_normalized = (sequence - sequence_mean) / sequence_std\n\n    for _ in range(global_iter):\n        centroid = sequence_normalized.mean(axis=0)\n        centroid_std = centroid.std(axis=-1, keepdims=True)\n        scores = []\n\n        for perm in permutations:\n            num = np.mean(sequence_normalized[:, perm, na] * centroid[na, :], axis=-1)\n            denom = flooring_fn(centroid_std)\n            corr = num / denom\n            score = np.sum(eye * corr - (1 - eye) * corr, axis=(1, 2))\n            scores.append(score)\n\n        scores = np.stack(scores, axis=1)\n        perm_max = np.argmax(scores, axis=1)\n        perm_max = permutations[perm_max]\n        sequence_normalized = _parallel_sort(sequence_normalized, perm_max)\n        sequence = 
_parallel_sort(sequence, perm_max)\n\n        for idx in range(len(permutable)):\n            permutable[idx][:] = _parallel_sort(permutable[idx], perm_max)\n\n    # local optimization\n    for _ in range(local_iter):\n        for bin_idx in range(n_bins):\n            min_idx = max(0, bin_idx - 3)\n            max_idx = min(n_bins - 1, bin_idx + 3)\n            covariant_indices = set(range(min_idx, bin_idx)) | set(range(bin_idx + 1, max_idx + 1))\n\n            min_idx = max(0, bin_idx // 2 - 1)\n            max_idx = min(n_bins - 1, bin_idx // 2 + 1)\n            covariant_indices |= set(range(min_idx, max_idx + 1))\n\n            min_idx = max(0, 2 * bin_idx - 1)\n            max_idx = min(n_bins - 1, 2 * bin_idx + 1)\n            covariant_indices |= set(range(min_idx, max_idx + 1))\n\n            # deterministic\n            covariant_indices = sorted(list(covariant_indices))\n            covariant_sequence = sequence_normalized[covariant_indices]\n\n            scores = []\n\n            for perm in permutations:\n                num = np.mean(\n                    sequence_normalized[bin_idx, perm, na] * covariant_sequence[:, na],\n                    axis=-1,\n                )\n                denom = flooring_fn(centroid_std)\n                corr = num / denom\n                score = np.sum(eye * corr - (1 - eye) * corr, axis=(1, 2))\n                score = score.sum(axis=0)\n                scores.append(score)\n\n            scores = np.stack(scores, axis=0)\n            perm_max = np.argmax(scores, axis=0)\n            perm_max = permutations[perm_max]\n            sequence_normalized[bin_idx] = sequence_normalized[bin_idx, perm_max]\n            sequence[bin_idx] = sequence[bin_idx, perm_max]\n\n            for idx in range(len(permutable)):\n                permutable[idx][bin_idx] = permutable[idx][bin_idx, perm_max]\n\n    if len(permutable) == 0:\n        return sequence\n    elif len(permutable) == 1:\n        return sequence, permutable[0]\n 
   else:\n        return sequence, permutable\n\n\ndef _parallel_sort(X: np.ndarray, indices: np.ndarray) -> np.ndarray:\n    shape = X.shape\n    idx = np.repeat(indices, repeats=np.prod(shape[2:]), axis=-1).reshape(shape)\n    X = np.take_along_axis(X, idx, axis=1)\n\n    return X\n"
  },
  {
    "path": "ssspy/algorithm/projection_back.py",
    "content": "from typing import Optional\n\nimport numpy as np\n\n\ndef projection_back(\n    data_or_filter: np.ndarray,\n    reference: Optional[np.ndarray] = None,\n    reference_id: Optional[int] = 0,\n) -> np.ndarray:\n    r\"\"\"Projection back technique to restore scale ambiguity.\n\n    The implementation is based on [#murata2001approach]_.\n\n    Args:\n        data_or_filter (numpy.ndarray):\n            Estimated spectrograms or demixing filters.\n        reference (numpy.ndarray, optional):\n            Reference spectrogram.\n        reference_id (int, optional):\n            Reference microphone index. Default: ``0``.\n\n    Returns:\n        numpy.ndarray of rescaled estimated spectrograms or demixing filters.\n\n    Examples:\n        When you give estimated spectrograms,\n\n        .. code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.algorithm import projection_back\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> n_sources = n_channels\n            >>> rng = np.random.default_rng(42)\n\n            >>> spectrogram_mix = \\\n            ...     rng.standard_normal((n_channels, n_bins, n_frames)) \\\n            ...     + 1j * rng.standard_normal((n_channels, n_bins, n_frames))\n            >>> demix_filter = \\\n            ...     rng.standard_normal((n_sources, n_channels)) \\\n            ...     + 1j * rng.standard_normal((n_sources, n_channels))\n\n            >>> spectrogram_est = demix_filter @ spectrogram_mix.transpose(1, 0, 2)\n\n            >>> # (n_bins, n_sources, n_frames) -> (n_sources, n_bins, n_frames)\n            >>> spectrogram_est = spectrogram_est.transpose(1, 0, 2)\n\n            >>> spectrogram_est_scaled = \\\n            ...     projection_back(spectrogram_est, reference=spectrogram_mix, reference_id=0)\n            >>> spectrogram_est_scaled.shape\n            (2, 2049, 128)\n\n        When you give demixing filters,\n\n        .. 
code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.algorithm import projection_back\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> n_sources = n_channels\n            >>> rng = np.random.default_rng(42)\n\n            >>> spectrogram_mix = \\\n            ...     rng.standard_normal((n_channels, n_bins, n_frames)) \\\n            ...     + 1j * rng.standard_normal((n_channels, n_bins, n_frames))\n            >>> demix_filter = \\\n            ...     rng.standard_normal((n_sources, n_channels)) \\\n            ...     + 1j * rng.standard_normal((n_sources, n_channels))\n\n            >>> demix_filter_scaled = projection_back(demix_filter, reference_id=0)\n\n            >>> spectrogram_est_scaled = demix_filter_scaled @ spectrogram_mix.transpose(1, 0, 2)\n\n            >>> # (n_bins, n_sources, n_frames) -> (n_sources, n_bins, n_frames)\n            >>> spectrogram_est_scaled = spectrogram_est_scaled.transpose(1, 0, 2)\n            >>> spectrogram_est_scaled.shape\n            (2, 2049, 128)\n\n    .. [#murata2001approach]\n        N. Murata, S. Ikeda, and A. Ziehe,\n        \"An approach to blind source separation based on temporal structure of speech signals,\"\n        *Neurocomputing*, vol. 41, no. 1-4, pp. 
1-24, 2001.\n    \"\"\"\n    if reference is None:\n        W = data_or_filter  # (*, n_sources, n_channels)\n        scale = np.linalg.inv(W)  # (*, n_channels, n_sources)\n\n        if reference_id is None:\n            scale = scale[..., np.newaxis]  # (*, n_channels, n_sources, 1)\n            scale = np.rollaxis(scale, -3, 0)  # (n_channels, *, n_sources, 1)\n            demix_filter_scaled = W * scale  # (n_channels, *, n_sources, n_channels)\n        else:\n            scale = scale[..., reference_id, :]  # (*, n_sources)\n            demix_filter_scaled = W * scale[..., np.newaxis]  # (*, n_sources, n_channels)\n\n        return demix_filter_scaled\n    else:\n        Y = data_or_filter  # (n_sources, n_bins, n_frames)\n        X = reference  # (n_channels, n_bins, n_frames)\n\n        Y = Y.transpose(1, 0, 2)  # (n_bins, n_sources, n_frames)\n        X = X.transpose(1, 0, 2)  # (n_bins, n_channels, n_frames)\n        Y_Hermite = Y.transpose(0, 2, 1).conj()  # (n_bins, n_frames, n_sources)\n        XY_Hermite = X @ Y_Hermite  # (n_bins, n_channels, n_sources)\n        YY_Hermite = Y @ Y_Hermite  # (n_bins, n_sources, n_sources)\n\n        scale = XY_Hermite @ np.linalg.inv(YY_Hermite)  # (n_bins, n_channels, n_sources)\n\n        if reference_id is None:\n            scale = scale.transpose(1, 0, 2)  # (n_channels, n_bins, n_sources)\n            Y_scaled = Y * scale[..., np.newaxis]  # (n_channels, n_bins, n_sources, n_frames)\n            output_scaled = Y_scaled.swapaxes(-3, -2)  # (n_channels, n_sources, n_bins, n_frames)\n        else:\n            scale = scale[..., reference_id, :]  # (n_bins, n_sources)\n            Y_scaled = Y * scale[..., np.newaxis]  # (n_bins, n_sources, n_frames)\n            output_scaled = Y_scaled.swapaxes(-3, -2)  # (n_sources, n_bins, n_frames)\n\n        return output_scaled\n"
  },
  {
    "path": "ssspy/bss/__init__.py",
    "content": "from . import fdica, ica, ilrma, iva, mnmf\n\n__all__ = [\"ica\", \"fdica\", \"iva\", \"ilrma\", \"mnmf\"]\n"
  },
  {
    "path": "ssspy/bss/_flooring.py",
    "content": "import warnings\n\nimport numpy as np\n\nEPS = 1e-10\n\n\ndef identity(input: np.ndarray) -> np.ndarray:\n    r\"\"\"Identity function.\"\"\"\n    warnings.warn(\"Use ssspy.special.identity instead.\", FutureWarning)\n\n    return input\n\n\ndef max_flooring(input: np.ndarray, eps: float = EPS) -> np.ndarray:\n    r\"\"\"Max flooring operation.\"\"\"\n    warnings.warn(\"Use ssspy.special.max_flooring instead.\", FutureWarning)\n\n    return np.maximum(input, eps)\n\n\ndef add_flooring(input: np.ndarray, eps: float = EPS) -> np.ndarray:\n    r\"\"\"Add flooring operation.\"\"\"\n    warnings.warn(\"Use ssspy.special.add_flooring instead.\", FutureWarning)\n\n    return input + eps\n"
  },
  {
    "path": "ssspy/bss/_psd.py",
    "content": "import functools\nimport warnings\nfrom typing import Callable, Optional\n\nimport numpy as np\n\nfrom ..special.flooring import max_flooring\nfrom ..special.psd import to_psd as _to_psd\n\nEPS = 1e-10\n\n\ndef to_psd(\n    X: np.ndarray,\n    axis1: int = -2,\n    axis2: int = -1,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n) -> np.ndarray:\n    r\"\"\"Ensure matrix to be positive semidefinite.\n\n    Args:\n        X (np.ndarray):\n            A complex Hermitian matrix.\n        axis1 (int):\n            Axis to be used as first axis of 2D sub-arrays.\n        axis2 (int):\n            Axis to be used as second axis of 2D sub-arrays.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n\n    Returns:\n        Positive semidefinite matrix.\n    \"\"\"\n    warnings.warn(\"Use ssspy.special.to_psd instead.\", FutureWarning)\n\n    return _to_psd(X, axis1=axis1, axis2=axis2, flooring_fn=flooring_fn)\n"
  },
  {
    "path": "ssspy/bss/_select_pair.py",
    "content": "import warnings\nfrom typing import Iterable, Optional, Tuple\n\nfrom ..utils.select_pair import combination_pair_selector as combination_pair_selector_base\nfrom ..utils.select_pair import sequential_pair_selector as sequential_pair_selector_base\n\n\ndef sequential_pair_selector(\n    n_sources: int, stop: Optional[int] = None, step: int = 1, sort: bool = False\n) -> Iterable[Tuple[int, int]]:\n    r\"\"\"Select pair in pairwise update.\n\n    Args:\n        n_sources (int):\n            Number of sources.\n        step (int):\n            This parameter determines step size.\n            For instance, if ``sequential_pair_selector(n_sources=6, step=2, sort=False)``,\n            this function yields ``0, 1``, ``2, 3``, ``4, 5``, ``0, 1``, ``2, 3``, ``4, 5``.\n            Default: ``1``.\n        sort (bool):\n            Sort pair to ensure :math:`m<n` if ``sort=True``.\n            Default: ``False``.\n\n    Yields:\n        Pair (tuple) of indices.\n\n    Examples:\n        .. code-block:: python\n\n            >>> for m, n in combination_pair_selector(4):\n            ...     print(m, n)\n            0 1\n            1 2\n            2 3\n            3 0\n    \"\"\"\n    warnings.warn(\"Use ssspy.utils.select_pair.sequential_pair_selector instead.\", UserWarning)\n\n    yield from sequential_pair_selector_base(n_sources, stop=stop, step=step, sort=sort)\n\n\ndef combination_pair_selector(n_sources: int, sort: bool = False) -> Iterable[Tuple[int, int]]:\n    r\"\"\"Select pair in pairwise update.\n\n    Args:\n        n_sources (int):\n            Number of sources.\n        sort (bool):\n            Sort pair to ensure :math:`m<n` if ``sort=True``.\n            Default: ``False``.\n\n    Yields:\n        Pair (tuple) of indices.\n\n    Examples:\n        .. code-block:: python\n\n            >>> for m, n in combination_pair_selector(4):\n            ...     
print(m, n)\n            0 1\n            0 2\n            0 3\n            1 2\n            1 3\n            2 3\n    \"\"\"\n    warnings.warn(\"Use ssspy.utils.select_pair.combination_pair_selector instead.\", UserWarning)\n\n    yield from combination_pair_selector_base(n_sources, sort=sort)\n"
  },
  {
    "path": "ssspy/bss/_solve_permutation.py",
    "content": "import functools\nimport warnings\nfrom typing import Callable, Optional\n\nimport numpy as np\n\nfrom ..algorithm.permutation_alignment import (\n    correlation_based_permutation_solver as correlation_based_permutation_solver_base,\n)\nfrom ..special.flooring import max_flooring\n\nEPS = 1e-10\n\n\ndef correlation_based_permutation_solver(\n    separated: np.ndarray,\n    *args,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n    overwrite: bool = True,\n) -> np.ndarray:\n    r\"\"\"Solve permutaion of estimated spectrograms.\"\"\"\n\n    warnings.warn(\n        \"Use ssspy.algorithm.permutation_alignment.correlation_based_permutation_solver instead.\",\n        UserWarning,\n    )\n\n    return correlation_based_permutation_solver_base(\n        separated, *args, flooring_fn=flooring_fn, overwrite=overwrite\n    )\n"
  },
  {
    "path": "ssspy/bss/_update_spatial_model.py",
    "content": "import functools\nfrom typing import Callable, Iterable, Optional, Tuple\n\nimport numpy as np\n\nfrom ..linalg._solve import solve\nfrom ..linalg.eigh import eigh2\nfrom ..linalg.inv import inv2\nfrom ..linalg.lqpqm import lqpqm2\nfrom ..special.flooring import identity, max_flooring\nfrom ..special.psd import to_psd\nfrom ..utils.select_pair import sequential_pair_selector\n\nEPS = 1e-10\n\n\ndef update_by_ip1(\n    demix_filter: np.ndarray,\n    weighted_covariance: np.ndarray,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n    overwrite: bool = True,\n) -> np.ndarray:\n    r\"\"\"Update demixing filters by iterative projection.\n\n    Args:\n        demix_filter (numpy.ndarray):\n            Demixing filters to be updated.\n            The shape is (n_bins, n_sources, n_channels).\n        weighted_covariance (numpy.ndarray):\n            Weighted covariance matrix.\n            The shape is (n_bins, n_sources, n_channels, n_channels).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        overwrite (bool):\n            Overwrite ``demix_filter`` if ``overwrite=True``.\n            Default: ``True``.\n\n    Returns:\n        numpy.ndarray of updated demixing filters.\n        The shape is (n_bins, n_sources, n_channels).\n    \"\"\"\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    if overwrite:\n        W = demix_filter\n    else:\n        W = demix_filter.copy()\n\n    U = weighted_covariance\n\n    n_bins, n_sources, n_channels = W.shape\n\n    E = np.eye(n_sources, n_channels)  # (n_sources, n_channels)\n    E = np.tile(E, 
reps=(n_bins, 1, 1))  # (n_bins, n_sources, n_channels)\n\n    for src_idx in range(n_sources):\n        w_n_Hermite = W[:, src_idx, :]  # (n_bins, n_channels)\n        U_n = U[:, src_idx, :, :]\n        e_n = E[:, src_idx, :]  # (n_bins, n_n_channels)\n\n        WU = W @ U_n\n        w_n = solve(WU, e_n)  # (n_bins, n_channels)\n        wUw = w_n[:, np.newaxis, :].conj() @ U_n @ w_n[:, :, np.newaxis]\n        wUw = np.real(wUw[..., 0])\n        wUw = np.maximum(wUw, 0)\n        denom = np.sqrt(wUw)\n        denom = flooring_fn(denom)\n        w_n_Hermite = w_n.conj() / denom\n        W[:, src_idx, :] = w_n_Hermite\n\n    return W\n\n\ndef update_by_ip2(\n    demix_filter: np.ndarray,\n    weighted_covariance: np.ndarray,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n    pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n    overwrite: bool = True,\n) -> np.ndarray:\n    r\"\"\"Update demixing filters by pairwise iterative projection [#ono2018fast]_.\n\n    Args:\n        demix_filter (numpy.ndarray):\n            Demixing filters to be updated.\n            The shape is (n_bins, n_sources, n_channels).\n        weighted_covariance (numpy.ndarray):\n            Weighted covariance matrix.\n            The shape is (n_bins, n_sources, n_channels, n_channels).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        pair_selector (callable, optional):\n            Selector to choose updaing pair.\n            If ``None`` is given, ``sequential_pair_selector`` is used.\n            Default: ``None``.\n        overwrite (bool):\n            
Overwrite ``demix_filter`` if ``overwrite=True``.\n            Default: ``True``.\n\n    Returns:\n        numpy.ndarray of updated demixing filters.\n        The shape is (n_bins, n_sources, n_channels).\n\n    .. [#ono2018fast] N. Ono, \\\n        \"Fast algorithm for independent component/vector/low-rank matrix analysis \\\n        with three or more sources,\" \\\n        in *Proc. ASJ Spring meeting*, 2018 (in Japanese).\n    \"\"\"\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    if pair_selector is None:\n        pair_selector = sequential_pair_selector\n\n    if overwrite:\n        W = demix_filter\n    else:\n        W = demix_filter.copy()\n\n    U = weighted_covariance\n\n    _, n_sources, _ = W.shape\n\n    for m, n in pair_selector(n_sources):\n        pair = (m, n)\n        W[:, pair, :] = update_by_ip2_one_pair(\n            W, U[:, pair, :, :], pair=pair, flooring_fn=flooring_fn\n        )\n\n    return W\n\n\ndef update_by_iss1(\n    separated: np.ndarray,\n    weight: np.ndarray,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n) -> np.ndarray:\n    r\"\"\"Update estimated spectrogram by iterative source steering.\n\n    Args:\n        separated (numpy.ndarray):\n            Estimated spectrograms to be updated.\n            The shape is (n_sources, n_bins, n_frames).\n        weight (numpy.ndarray):\n            Weights for estimated spectrogram.\n            The shape is (n_sources, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n\n    Returns:\n        numpy.ndarray of updated spectrograms.\n        The shape is 
(n_sources, n_bins, n_frames).\n    \"\"\"\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    Y = separated\n    varphi = weight\n\n    n_sources = Y.shape[0]\n\n    for src_idx in range(n_sources):\n        Y_n = Y[src_idx]  # (n_bins, n_frames)\n\n        YY_n_conj = Y * Y_n.conj()\n        YY_n = np.abs(Y_n) ** 2\n        num = np.mean(varphi * YY_n_conj, axis=-1)\n        denom = np.mean(varphi * YY_n, axis=-1)\n        denom = flooring_fn(denom)\n        v_n = num / denom\n        v_n[src_idx] = 1 - 1 / np.sqrt(denom[src_idx])\n\n        Y = Y - v_n[:, :, np.newaxis] * Y_n\n\n    return Y\n\n\ndef update_by_iss2(\n    separated: np.ndarray,\n    weight: np.ndarray,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n    pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n) -> np.ndarray:\n    r\"\"\"Update estimated spectrogram by pairwise iterative source steering.\n\n    Args:\n        separated (numpy.ndarray):\n            Estimated spectrograms to be updated.\n            The shape is (n_sources, n_bins, n_frames).\n        weight (numpy.ndarray):\n            Weights for estimated spectrogram.\n            The shape is (n_sources, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        pair_selector (callable, optional):\n            Selector to choose updaing pair.\n            If ``None`` is given, ``sequential_pair_selector`` is used.\n            Default: ``None``.\n\n    Returns:\n        numpy.ndarray of updated spectrograms.\n        The shape is (n_sources, n_bins, n_frames).\n    \"\"\"\n    Y 
= separated\n    varphi = weight\n\n    n_sources = Y.shape[0]\n\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    if pair_selector is None:\n        pair_selector = functools.partial(sequential_pair_selector, stop=n_sources, step=2)\n\n    for m, n in pair_selector(n_sources):\n        if m < 0:\n            m = n_sources + m\n        if n < 0:\n            n = n_sources + n\n\n        if m > n:\n            ascend = False\n            m, n = n, m\n        else:\n            ascend = True\n\n        # Split into main and sub\n        Y_1, Y_m, Y_2, Y_n, Y_3 = np.split(Y, [m, m + 1, n, n + 1], axis=0)\n        Y_sub = np.concatenate([Y_1, Y_2, Y_3], axis=0)  # (n_sources - 2, n_bins, n_frames)\n        varphi_1, varphi_m, varphi_2, varphi_n, varphi_3 = np.split(\n            varphi, [m, m + 1, n, n + 1], axis=0\n        )\n        varphi_sub = np.concatenate([varphi_1, varphi_2, varphi_3], axis=0)\n\n        if ascend:\n            Y_main = np.concatenate([Y_m, Y_n], axis=0)  # (2, n_bins, n_frames)\n            varphi_main = np.concatenate([varphi_m, varphi_n], axis=0)\n        else:\n            Y_main = np.concatenate([Y_n, Y_m], axis=0)  # (2, n_bins, n_frames)\n            varphi_main = np.concatenate([varphi_n, varphi_m], axis=0)\n\n        YY_main = Y_main[:, np.newaxis, :, :] * Y_main[np.newaxis, :, :, :].conj()\n        YY_sub = Y_main[:, np.newaxis, :, :] * Y_sub[np.newaxis, :, :, :].conj()\n        YY_main = YY_main.transpose(2, 0, 1, 3)\n        YY_sub = YY_sub.transpose(1, 2, 0, 3)\n\n        Y_main = Y_main.transpose(1, 0, 2)\n\n        # Sub\n        G_sub = np.mean(\n            varphi_sub[:, :, np.newaxis, np.newaxis, :] * YY_main[np.newaxis, :, :, :, :],\n            axis=-1,\n        )\n        F = np.mean(varphi_sub[:, :, np.newaxis, :] * YY_sub, axis=-1)\n        Q = -inv2(G_sub) @ F[:, :, :, np.newaxis]\n        Q = Q.squeeze(axis=-1)\n        Q = Q.transpose(1, 0, 2)\n        QY = Q.conj() @ Y_main\n        Y_sub = Y_sub + 
QY.transpose(1, 0, 2)\n\n        # Main\n        G_main = np.mean(\n            varphi_main[:, :, np.newaxis, np.newaxis, :] * YY_main[np.newaxis, :, :, :, :],\n            axis=-1,\n        )\n        G_m, G_n = G_main\n        _, H_mn = eigh2(G_m, G_n)\n        h_mn = H_mn.transpose(2, 0, 1)\n        hGh_mn = h_mn[:, :, np.newaxis, :].conj() @ G_main @ h_mn[:, :, :, np.newaxis]\n        hGh_mn = np.squeeze(hGh_mn, axis=-1)\n        hGh_mn = np.real(hGh_mn)\n        hGh_mn = np.maximum(hGh_mn, 0)\n        denom_mn = np.sqrt(hGh_mn)\n        denom_mn = flooring_fn(denom_mn)\n        P = h_mn / denom_mn\n        P = P.transpose(1, 0, 2)\n        Y_main = P.conj() @ Y_main\n        Y_main = Y_main.transpose(1, 0, 2)\n\n        # Concat\n        Y_m, Y_n = np.split(Y_main, [1], axis=0)\n        Y1, Y2, Y3 = np.split(Y_sub, [m, n - 1], axis=0)\n\n        if ascend:\n            Y = np.concatenate([Y1, Y_m, Y2, Y_n, Y3], axis=0)\n        else:\n            Y = np.concatenate([Y1, Y_n, Y2, Y_m, Y3], axis=0)\n\n    return Y\n\n\ndef update_by_ip2_one_pair(\n    demix_filter: np.ndarray,\n    weighted_covariance_pair: np.ndarray,\n    pair: Tuple[int],\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n) -> np.ndarray:\n    r\"\"\"Update demixing filters by pairwise iterative projection.\n\n    Args:\n        demix_filter (numpy.ndarray):\n            Demixing filters.\n            The shape is (n_bins, n_sources, n_channels).\n        weighted_covariance_pair (numpy.ndarray):\n            Weighted covariance matrix.\n            The shape is (n_bins, 2, n_channels, n_channels).\n        pair (tuple):\n            Pair of source index to be updated.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            
the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n\n    Returns:\n        numpy.ndarray of updated demixing filter pair.\n        The shape is (n_bins, 2, n_channels).\n    \"\"\"\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    m, n = pair\n    W = demix_filter\n    U_m, U_n = weighted_covariance_pair.transpose(1, 0, 2, 3)\n\n    n_bins, n_sources, n_channels = W.shape\n\n    E = np.eye(n_channels, n_sources)\n    E_mn = E[:, (m, n)]\n    E_mn = np.tile(E_mn, reps=(n_bins, 1, 1))\n\n    WU_m = W @ U_m\n    WU_n = W @ U_n\n\n    P_m = solve(WU_m, E_mn)\n    P_n = solve(WU_n, E_mn)\n\n    PUP_m = P_m.transpose(0, 2, 1).conj() @ U_m @ P_m\n    PUP_n = P_n.transpose(0, 2, 1).conj() @ U_n @ P_n\n\n    _, H_mn = eigh2(PUP_m, PUP_n)\n    H_mn = H_mn[..., ::-1]\n\n    H_mn = H_mn.transpose(2, 0, 1)\n    h_m, h_n = H_mn\n\n    hUh_m = h_m[:, np.newaxis, :].conj() @ PUP_m @ h_m[:, :, np.newaxis]\n    hUh_m = np.real(hUh_m[..., 0])\n    hUh_m = np.maximum(hUh_m, 0)\n    denom = np.sqrt(hUh_m)\n    denom = flooring_fn(denom)\n    h_m = h_m / denom\n\n    hUh_n = h_n[:, np.newaxis, :].conj() @ PUP_n @ h_n[:, :, np.newaxis]\n    hUh_n = np.real(hUh_n[..., 0])\n    hUh_n = np.maximum(hUh_n, 0)\n    denom = np.sqrt(hUh_n)\n    denom = flooring_fn(denom)\n    h_n = h_n / denom\n\n    w_m = P_m @ h_m[..., np.newaxis]\n    w_n = P_n @ h_n[..., np.newaxis]\n\n    W_mn_conj = np.concatenate([w_m, w_n], axis=-1)\n    W_mn = W_mn_conj.transpose(0, 2, 1).conj()\n\n    return W_mn\n\n\ndef update_by_ipa(\n    separated: np.ndarray,\n    weight: np.ndarray,\n    normalization: bool = True,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n    max_iter: int = 1,\n) -> np.ndarray:\n    r\"\"\"Update estimated spectrogram by iterative projection with adjustment (IPA).\n\n    Args:\n        separated (numpy.ndarray):\n            
Estimated spectrograms to be updated.\n            The shape is (n_sources, n_bins, n_frames).\n        weight (numpy.ndarray):\n            Weights for estimated spectrogram.\n            The shape is (n_sources, n_bins, n_frames).\n        normalization (bool):\n            If ``normalization=True``, normalization is applied to LQPQM.\n            Default: ``True``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        max_iter (int):\n            Maximum number of Newton-Raphson method. Default: ``1``.\n\n    Returns:\n        numpy.ndarray of estimated spectrograms of shape (n_sources, n_bins, n_frames).\n\n    \"\"\"\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    Y = separated\n    varphi = weight\n\n    n_sources = Y.shape[0]\n\n    E = np.eye(n_sources)\n\n    for source_idx in range(n_sources):\n        YY_conj = Y[:, np.newaxis] * Y[np.newaxis, :].conj()\n        U_tilde = np.mean(varphi[:, np.newaxis, np.newaxis] * YY_conj, axis=-1)\n        U_tilde = U_tilde.transpose(3, 0, 1, 2)\n        U_tilde = to_psd(U_tilde, axis1=-2, axis2=-1, flooring_fn=flooring_fn)\n\n        E_n_left, e_n, E_n_right = np.split(E, [source_idx, source_idx + 1], axis=-1)\n        E_n = np.concatenate([E_n_left, E_n_right], axis=-1)\n\n        U_tilde_n = U_tilde[:, source_idx, :, :]\n        U_tilde_n_inverse = _psd_inv(U_tilde_n, flooring_fn=flooring_fn)\n        a_n = U_tilde[:, :, source_idx, source_idx]\n        a_n = np.real(a_n)\n        a_n = a_n @ E_n\n        b_n = np.diagonal(U_tilde[:, :, source_idx, :], axis1=-2, axis2=-1)\n        b_n = b_n @ E_n\n        d_n = E_n.transpose(1, 0) @ U_tilde_n_inverse.conj()\n        C_n = 
d_n @ E_n\n        d_n = d_n[:, :, source_idx]\n\n        Cd_n = solve(C_n, d_n)\n        dCd_n = np.sum(d_n.conj() * Cd_n, axis=-1)\n        dCd_n = np.real(dCd_n)\n        eUe_n = U_tilde_n_inverse[:, source_idx, source_idx]\n        eUe_n = np.real(eUe_n)\n        z_n = eUe_n - dCd_n\n\n        a_sqrt_n = np.sqrt(a_n)\n        aa_n = a_sqrt_n[:, :, np.newaxis] * a_sqrt_n[:, np.newaxis, :]\n        H_n = C_n / aa_n\n        v_n = -b_n / a_sqrt_n - a_sqrt_n * Cd_n\n\n        if normalization:\n            trace = np.trace(H_n, axis1=-2, axis2=-1)\n            trace = np.real(trace)\n\n            H_n = H_n / trace[..., np.newaxis, np.newaxis]\n            z_n = z_n / trace\n\n        q_check_n = lqpqm2(\n            H_n,\n            v_n,\n            z_n,\n            flooring_fn=flooring_fn,\n            singular_fn=lambda x: x < flooring_fn(0),\n            max_iter=max_iter,\n        )\n\n        q_n = q_check_n / a_sqrt_n - b_n / a_n\n\n        Eq_n = q_n.conj() @ E_n.transpose(1, 0)\n        q_tilde_n = e_n.transpose(1, 0) - Eq_n\n\n        Uq_n = solve(U_tilde_n, q_tilde_n)\n        qUq_n = np.sum(q_tilde_n.conj() * Uq_n, axis=-1, keepdims=True)\n\n        qUq_n = np.real(qUq_n)\n        qUq_n = np.maximum(qUq_n, 0)\n        denom = np.sqrt(qUq_n)\n        denom = flooring_fn(denom)\n        p_n = Uq_n / denom\n\n        Y_n = Y[source_idx]\n        p_n_conj = p_n.transpose(1, 0).conj()\n        PY_n = np.sum(p_n_conj[..., np.newaxis] * Y, axis=0)\n        PY_n = e_n[:, np.newaxis] * (PY_n - Y_n)\n        Eq_n = Eq_n.transpose(1, 0)\n        QY_n = Eq_n[:, :, np.newaxis] * Y_n\n\n        Y = Y + PY_n + QY_n\n\n    return Y\n\n\ndef update_by_block_decomposition_vcd(\n    demix_filter: np.ndarray,\n    weighted_covariance: np.ndarray,\n    singular_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,\n    overwrite: bool = True,\n) -> np.ndarray:\n    r\"\"\"\n    Args:\n        demix_filter (numpy.ndarray):\n            Demixing filters to be updated.\n 
           The shape is (n_blocks, n_neighbors, n_sources, n_channels).\n        weighted_covariance (numpy.ndarray):\n            Weighted covariance matrix.\n            The shape is (n_blocks, n_neighbors, n_neighbors, n_sources, n_channels, n_channels).\n        singular_fn (callable, optional):\n            A flooring function to return singular condition.\n            This function is expected to return the same shape bool tensor as the input.\n            If ``singular_fn=None``,``lambda x: x == 0`` is used.\n        overwrite (bool):\n            Overwrite ``demix_filter`` if ``overwrite=True``.\n            Default: ``True``.\n\n    Returns:\n        numpy.ndarray of updated demixing filters.\n        The shape is (n_blocks, n_neighbors, n_sources, n_channels).\n    \"\"\"\n    na = np.newaxis\n\n    if singular_fn is None:\n\n        def _is_zero(x: np.ndarray) -> np.ndarray:\n            return x == 0\n\n        singular_fn = _is_zero\n\n    if overwrite:\n        W = demix_filter\n    else:\n        W = demix_filter.copy()\n\n    RXX = weighted_covariance\n    U = np.diagonal(RXX, axis1=1, axis2=2)\n\n    n_blocks, n_neighbors, n_sources, n_channels = W.shape\n\n    E_i = np.eye(n_neighbors)\n    E_n = np.eye(n_sources)\n    E_n = np.tile(E_n, reps=(n_blocks, 1, 1))\n\n    for neighbor_idx in range(n_neighbors):\n        pad_mask_i = 1 - E_i[neighbor_idx]\n\n        U_i = U[:, :, :, :, neighbor_idx]\n        RXX_i = RXX[:, neighbor_idx]\n\n        for source_idx in range(n_sources):\n            e_n = E_n[:, source_idx, :]\n            U_in = U_i[:, source_idx, :, :]\n            RXX_in = RXX_i[:, :, source_idx]\n            w_n_conj = W[:, :, source_idx, :].conj()\n\n            RXY_in = RXX_in @ w_n_conj[:, :, :, na]\n\n            gamma_in = np.sum(pad_mask_i[:, na] * RXY_in[..., 0], axis=1)\n\n            WU_in = W[:, neighbor_idx, :, :] @ U_in\n            eta_in = solve(WU_in, e_n)\n            eta_hat_in = solve(U_in, gamma_in)\n            
eta_U_in = eta_in[:, na, :].conj() @ U_in\n\n            xi_in = eta_U_in @ eta_in[:, :, na]\n            xi_hat_in = eta_U_in @ eta_hat_in[:, :, na]\n\n            xi_in = np.real(xi_in[..., 0])\n            xi_in = np.maximum(xi_in, 0)\n            xi_hat_in = xi_hat_in[..., 0]\n\n            singular_condition = singular_fn(xi_hat_in)\n\n            # to avoid zero division, but these will be ignored.\n            xi_hat_in[singular_condition] = 1\n\n            coeff = (xi_hat_in / (2 * xi_in)) * (\n                1 - np.sqrt(1 + 4 * xi_in / (np.abs(xi_hat_in) ** 2))\n            )\n            coeff_singular = 1 / np.sqrt(xi_in)\n            coeff = np.where(singular_condition, coeff_singular, coeff)\n\n            w_in = coeff * eta_in - eta_hat_in\n\n            W[:, neighbor_idx, source_idx, :] = w_in.conj()\n\n    return W\n\n\ndef _psd_inv(\n    X: np.ndarray,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n) -> np.ndarray:\n    \"\"\"Compute inversion of positive semidefinite matrix.\n\n    Args:\n        X (np.ndarray): Positive semidefinite matrix of shape (*, N, N).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n\n    Returns:\n        np.ndarray: Inversion of input matrix.\n\n    \"\"\"\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    Lamb, P = np.linalg.eigh(X)\n\n    P_Hermite = P.swapaxes(-2, -1)\n\n    if np.iscomplexobj(X):\n        P_Hermite = P_Hermite.conj()\n\n    Lamb_inv = 1 / flooring_fn(Lamb)\n    Lamb_inv = Lamb_inv[..., np.newaxis] * np.eye(Lamb.shape[-1])\n\n    return P @ Lamb_inv @ P_Hermite\n"
  },
  {
    "path": "ssspy/bss/admmbss.py",
    "content": "import warnings\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\n\nfrom ..linalg import prox\nfrom ..linalg._solve import solve\nfrom .proxbss import ProxBSSBase\n\nEPS = 1e-10\n\n__all__ = [\"ADMMBSS\", \"MaskingADMMBSS\"]\n\n\nclass ADMMBSSBase(ProxBSSBase):\n    \"\"\"Base class of blind source separation via alternating direction method of multipliers.\n\n    Args:\n        penalty_fn (callable):\n            Penalty function that determines source model.\n        prox_penalty (callable):\n            Proximal operator of penalty function.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n\n    \"\"\"\n\n    def __repr__(self) -> str:\n        s = \"ADMMBSS(\"\n        s += \"n_penalties={n_penalties}\".format(n_penalties=self.n_penalties)\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n\nclass ADMMBSS(ADMMBSSBase):\n    \"\"\"Blind source separation via alternating direction method of multipliers.\n\n    Args:\n        rho (float):\n            Penalty parameter. 
Default: ``1``.\n        alpha (float):\n            Relaxation parameter (deprecated). Set ``relaxation`` instead.\n        relaxation (float):\n            Relaxation parameter. Default: ``1``.\n        penalty_fn (callable):\n            Penalty function that determines source model.\n        prox_penalty (callable):\n            Proximal operator of penalty function.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        rho: float = 1,\n        alpha: float = None,\n        relaxation: float = 1,\n        penalty_fn: Callable[[np.ndarray, np.ndarray], float] = None,\n        prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"ADMMBSS\"], None], List[Callable[[\"ADMMBSS\"], None]]]\n        ] = None,\n        scale_restoration: bool = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            penalty_fn=penalty_fn,\n            prox_penalty=prox_penalty,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        
self.rho = rho\n\n        if alpha is None:\n            self.relaxation = relaxation\n        else:\n            assert relaxation == 1, \"You cannot specify relaxation and alpha simultaneously.\"\n\n            warnings.warn(\"alpha is deprecated. Set relaxation instead.\", DeprecationWarning)\n\n            self.relaxation = alpha\n\n    def __call__(self, input, n_iter=100, initial_call: bool = True, **kwargs) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                Mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                Number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of ADMMBSSBase's parent, i.e. 
__call__ of IterativeMethodBase\n        super(ADMMBSSBase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"ADMMBSS(\"\n        s += \"rho={rho}\"\n        s += \", relaxation={relaxation}\"\n        s += \", n_penalties={n_penalties}\".format(n_penalties=self.n_penalties)\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of ADMMBSS.\n        \"\"\"\n        if \"aux1\" in kwargs.keys():\n            warnings.warn(\"aux1 is deprecated. Use auxiliary1 instead.\", DeprecationWarning)\n\n            kwargs[\"auxiliary1\"] = kwargs.pop(\"aux1\")\n\n        if \"aux2\" in kwargs.keys():\n            warnings.warn(\"aux2 is deprecated. 
Use auxiliary2 instead.\", DeprecationWarning)\n\n            kwargs[\"auxiliary2\"] = kwargs.pop(\"aux2\")\n\n        super()._reset(**kwargs)\n\n        n_penalties = self.n_penalties\n        n_sources, n_channels = self.n_sources, self.n_channels\n        n_bins, n_frames = self.n_bins, self.n_frames\n\n        if not hasattr(self, \"auxiliary1\"):\n            auxiliary1 = np.zeros((n_bins, n_sources, n_channels), dtype=np.complex128)\n        else:\n            # To avoid overwriting ``auxiliary1`` given by keyword arguments.\n            auxiliary1 = self.auxiliary1.copy()\n\n        if not hasattr(self, \"auxiliary2\"):\n            auxiliary2 = np.zeros((n_penalties, n_sources, n_bins, n_frames), dtype=np.complex128)\n        else:\n            # To avoid overwriting ``auxiliary2`` given by keyword arguments.\n            auxiliary2 = self.auxiliary2.copy()\n\n        if not hasattr(self, \"dual1\"):\n            dual1 = np.zeros((n_bins, n_sources, n_channels), dtype=np.complex128)\n        else:\n            # To avoid overwriting ``dual1`` given by keyword arguments.\n            dual1 = self.dual1.copy()\n\n        if not hasattr(self, \"dual2\"):\n            dual2 = np.zeros((n_penalties, n_sources, n_bins, n_frames), dtype=np.complex128)\n        else:\n            # To avoid overwriting ``dual2`` given by keyword arguments.\n            dual2 = self.dual2.copy()\n\n        self.auxiliary1 = auxiliary1\n        self.auxiliary2 = auxiliary2\n        self.dual1 = dual1\n        self.dual2 = dual2\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters, auxiliary parameters, and dual parameters once.\"\"\"\n        n_penalties = self.n_penalties\n        n_channels = self.n_channels\n        rho, alpha = self.rho, self.relaxation\n\n        V, V_tilde = self.auxiliary1, self.auxiliary2\n        Y, Y_tilde = self.dual1, self.dual2\n        X, W = self.input, self.demix_filter\n\n        XX = X.transpose(1, 0, 2).conj() @ 
X.transpose(1, 2, 0)\n        E = np.eye(n_channels)\n        VY = V - Y\n        VY_tilde = np.sum(V_tilde - Y_tilde, axis=0)\n        XVY_tilde = X.transpose(1, 0, 2).conj() @ VY_tilde.transpose(1, 2, 0)\n\n        W = solve(n_penalties * XX + E, VY + XVY_tilde.transpose(0, 2, 1))\n        XW = self.separate(X, demix_filter=W)\n\n        U = alpha * W + (1 - alpha) * V\n        U_tilde = alpha * XW + (1 - alpha) * V_tilde\n\n        V = prox.neg_logdet(U + Y, step_size=1 / rho)\n\n        V_tilde = []\n\n        for U_tilde_q, Y_tilde_q, prox_penalty in zip(U_tilde, Y_tilde, self.prox_penalty):\n            V_tilde_q = prox_penalty(U_tilde_q + Y_tilde_q, step_size=1 / rho)\n            V_tilde.append(V_tilde_q)\n\n        V_tilde = np.stack(V_tilde, axis=0)\n\n        Y = Y + U - V\n        Y_tilde = Y_tilde + U_tilde - V_tilde\n\n        self.auxiliary1, self.auxiliary2 = V, V_tilde\n        self.dual1, self.dual2 = Y, Y_tilde\n        self.demix_filter = W\n\n\nclass MaskingADMMBSS(ADMMBSSBase):\n    \"\"\"Blind source separation via alternating direction method of multipliers\n    with masking function.\n\n    Args:\n        rho (float): Penalty parameter. Default: ``1``.\n        alpha (float): Relaxation parameter (deprecated). Set ``relaxation`` instead.\n        relaxation (float): Relaxation parameter. Default: ``1``.\n        penalty_fn (callable): Penalty function that determines source model.\n        mask_fn (callable): Masking function. Given a spectrogram ``Z``, the auxiliary\n            variable is updated to ``mask_fn(Z) * Z``. This argument is required.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str): Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. 
You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool, optional): Record the loss at each iteration of the update algorithm if\n            ``record_loss=True``. Default: ``None``.\n        reference_id (int): Reference channel for projection back. Default: ``0``.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        rho: float = 1,\n        alpha: float = None,\n        relaxation: float = 1,\n        penalty_fn: Callable[[np.ndarray, np.ndarray], float] = None,\n        mask_fn: Callable[[np.ndarray], float] = None,\n        callbacks: Optional[\n            Union[Callable[[\"MaskingADMMBSS\"], None], List[Callable[[\"MaskingADMMBSS\"], None]]]\n        ] = None,\n        scale_restoration: bool = True,\n        record_loss: Optional[bool] = None,\n        reference_id: int = 0,\n    ) -> None:\n        super(ProxBSSBase, self).__init__(\n            callbacks=callbacks,\n            record_loss=record_loss,\n        )\n\n        if penalty_fn is None:\n            # Since penalty_fn is not necessarily written in closed form,\n            # None is acceptable.\n            if record_loss is None:\n                record_loss = False\n\n            assert not record_loss, \"To record loss, set penalty_fn.\"\n        else:\n            assert callable(penalty_fn), \"penalty_fn should be callable.\"\n\n            if record_loss is None:\n                record_loss = True\n\n        if mask_fn is None:\n            raise ValueError(\"Specify masking function.\")\n        else:\n            assert callable(mask_fn), \"mask_fn should be callable.\"\n\n        self.penalty_fn = penalty_fn\n        self.mask_fn = mask_fn\n\n        self.input = None\n        self.scale_restoration = scale_restoration\n\n        if reference_id is None and scale_restoration:\n            raise ValueError(\"Specify 'reference_id' if scale_restoration=True.\")\n        else:\n            self.reference_id = reference_id\n\n    
    self.rho = rho\n\n        if alpha is None:\n            self.relaxation = relaxation\n        else:\n            assert relaxation == 1, \"You cannot specify relaxation and alpha simultaneously.\"\n\n            warnings.warn(\"alpha is deprecated. Set relaxation instead.\", DeprecationWarning)\n\n            self.relaxation = alpha\n\n    def __call__(self, input, n_iter=100, initial_call: bool = True, **kwargs) -> np.ndarray:\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of ADMMBSSBase's parent, i.e. __call__ of IterativeMethodBase\n        super(ADMMBSSBase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of ADMMBSS.\n        \"\"\"\n        if \"aux1\" in kwargs.keys():\n            warnings.warn(\"aux1 is deprecated. Use auxiliary1 instead.\", DeprecationWarning)\n\n            kwargs[\"auxiliary1\"] = kwargs.pop(\"aux1\")\n\n        if \"aux2\" in kwargs.keys():\n            warnings.warn(\"aux2 is deprecated. 
Use auxiliary2 instead.\", DeprecationWarning)\n\n            kwargs[\"auxiliary2\"] = kwargs.pop(\"aux2\")\n\n        super()._reset(**kwargs)\n\n        assert self.n_penalties == 1, \"Number of penalty function should be one.\"\n\n        n_sources, n_channels = self.n_sources, self.n_channels\n        n_bins, n_frames = self.n_bins, self.n_frames\n\n        if not hasattr(self, \"auxiliary1\"):\n            auxiliary1 = np.zeros((n_bins, n_sources, n_channels), dtype=np.complex128)\n        else:\n            # To avoid overwriting ``auxiliary1`` given by keyword arguments.\n            auxiliary1 = self.auxiliary1.copy()\n\n        if not hasattr(self, \"auxiliary2\"):\n            auxiliary2 = np.zeros((n_sources, n_bins, n_frames), dtype=np.complex128)\n        else:\n            # To avoid overwriting ``auxiliary2`` given by keyword arguments.\n            auxiliary2 = self.auxiliary2.copy()\n\n        if not hasattr(self, \"dual1\"):\n            dual1 = np.zeros((n_bins, n_sources, n_channels), dtype=np.complex128)\n        else:\n            # To avoid overwriting ``dual1`` given by keyword arguments.\n            dual1 = self.dual1.copy()\n\n        if not hasattr(self, \"dual2\"):\n            dual2 = np.zeros((n_sources, n_bins, n_frames), dtype=np.complex128)\n        else:\n            # To avoid overwriting ``dual2`` given by keyword arguments.\n            dual2 = self.dual2.copy()\n\n        self.auxiliary1 = auxiliary1\n        self.auxiliary2 = auxiliary2\n        self.dual1 = dual1\n        self.dual2 = dual2\n\n    @property\n    def n_penalties(self) -> int:\n        r\"\"\"Return number of penalty terms.\"\"\"\n        return 1\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters, auxiliary parameters, and dual parameters once.\"\"\"\n        n_channels = self.n_channels\n        rho, alpha = self.rho, self.relaxation\n\n        V, V_tilde = self.auxiliary1, self.auxiliary2\n        Y, Y_tilde = self.dual1, 
self.dual2\n        X, W = self.input, self.demix_filter\n\n        XX = X.transpose(1, 0, 2).conj() @ X.transpose(1, 2, 0)\n        E = np.eye(n_channels)\n        VY = V - Y\n        VY_tilde = V_tilde - Y_tilde\n        XVY_tilde = X.transpose(1, 0, 2).conj() @ VY_tilde.transpose(1, 2, 0)\n\n        W = solve(XX + E, VY + XVY_tilde.transpose(0, 2, 1))\n        XW = self.separate(X, demix_filter=W)\n\n        U = alpha * W + (1 - alpha) * V\n        U_tilde = alpha * XW + (1 - alpha) * V_tilde\n        V = prox.neg_logdet(U + Y, step_size=1 / rho)\n        V_tilde = self.mask_fn(U_tilde + Y_tilde) * (U_tilde + Y_tilde)\n        Y = Y + U - V\n        Y_tilde = Y_tilde + U_tilde - V_tilde\n\n        self.auxiliary1, self.auxiliary2 = V, V_tilde\n        self.dual1, self.dual2 = Y, Y_tilde\n        self.demix_filter = W\n"
  },
  {
    "path": "ssspy/bss/base.py",
    "content": "from typing import Callable, List, Optional, Union\n\nimport numpy as np\n\n__all__ = [\n    \"IterativeMethodBase\",\n]\n\n\nclass IterativeMethodBase:\n    r\"\"\"Base class of iterative method.\n\n    This class provides prototype of iterative updates.\n\n    Args:\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n    \"\"\"\n\n    def __init__(\n        self,\n        callbacks: Optional[\n            Union[\n                Callable[[\"IterativeMethodBase\"], None],\n                List[Callable[[\"IterativeMethodBase\"], None]],\n            ]\n        ] = None,\n        record_loss: bool = True,\n    ) -> None:\n        if callbacks is not None:\n            if callable(callbacks):\n                callbacks = [callbacks]\n            self.callbacks = callbacks\n        else:\n            self.callbacks = None\n\n        self.record_loss = record_loss\n\n        if self.record_loss:\n            self.loss = []\n        else:\n            self.loss = None\n\n    def __call__(self, *args, n_iter: int = 100, initial_call: bool = True, **kwargs) -> np.ndarray:\n        r\"\"\"Iteratively call ``update_once``.\n\n        Args:\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n        \"\"\"\n        if initial_call:\n            if self.record_loss:\n                loss = self.compute_loss()\n                self.loss.append(loss)\n\n            if self.callbacks is not None:\n                for callback in 
self.callbacks:\n                    callback(self)\n\n        for _ in range(n_iter):\n            self.update_once()\n\n            if self.record_loss:\n                loss = self.compute_loss()\n                self.loss.append(loss)\n\n            if self.callbacks is not None:\n                for callback in self.callbacks:\n                    callback(self)\n\n    def update_once(self) -> None:\n        r\"\"\"Update parameters once.\"\"\"\n        raise NotImplementedError(\"Implement 'update_once' method.\")\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss.\n\n        Returns:\n            Computed loss. The type is expected ``float``.\n        \"\"\"\n        raise NotImplementedError(\"Implement 'compute_loss' method.\")\n"
  },
  {
    "path": "ssspy/bss/cacgmm.py",
    "content": "import functools\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\n\nfrom ..algorithm.permutation_alignment import (\n    correlation_based_permutation_solver,\n    score_based_permutation_solver,\n)\nfrom ..linalg.quadratic import quadratic\nfrom ..special.flooring import identity, max_flooring\nfrom ..special.logsumexp import logsumexp\nfrom ..special.psd import to_psd\nfrom ..special.softmax import softmax\nfrom ..utils.flooring import choose_flooring_fn\nfrom .base import IterativeMethodBase\n\nEPS = 1e-10\n\n\nclass CACGMMBase(IterativeMethodBase):\n    r\"\"\"Base class of complex angular central Gaussian mixture model (cACGMM).\n\n    Args:\n        n_sources (int, optional):\n            Number of sources to be separated.\n            If ``None`` is given, ``n_sources`` is determined by number of channels\n            in input spectrogram. Default: ``None``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        rng (numpy.random.Generator, optioinal):\n            Random number generator. This is mainly used to randomly initialize parameters\n            of cACGMM. 
If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_sources: Optional[int] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[\n                Callable[[\"CACGMMBase\"], None],\n                List[Callable[[\"CACGMMBase\"], None]],\n            ]\n        ] = None,\n        record_loss: bool = True,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        self.normalization: bool\n        self.permutation_alignment: bool\n\n        super().__init__(callbacks=callbacks, record_loss=record_loss)\n\n        self.n_sources = n_sources\n\n        if flooring_fn is None:\n            self.flooring_fn = identity\n        else:\n            self.flooring_fn = flooring_fn\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        self.rng = rng\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        raise NotImplementedError(\"Implement '__call__' method.\")\n\n    
def __repr__(self) -> str:\n        s = \"CACGMM(\"\n\n        if self.n_sources is not None:\n            s += \"n_sources={n_sources}, \"\n\n        s += \"record_loss={record_loss}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        **kwargs,\n    ) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            kwargs:\n                Keyword arguments to set as attributes of CACGMM.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        norm = np.linalg.norm(X, axis=0)\n        Z = X / flooring_fn(norm)\n        self.unit_input = Z\n\n        n_sources = self.n_sources\n        n_channels, n_bins, n_frames = X.shape\n\n        if n_sources is None:\n            n_sources = n_channels\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_bins, self.n_frames = n_bins, n_frames\n\n        self._init_parameters(rng=self.rng)\n\n    def _init_parameters(self, rng: Optional[np.random.Generator] = None) -> None:\n        r\"\"\"Initialize parameters of cACGMM.\n\n        Args:\n            rng (numpy.random.Generator, optional):\n                Random number generator. 
If ``None`` is given,\n                ``np.random.default_rng()`` is used.\n                Default: ``None``.\n\n        .. note::\n\n            Custom initialization is not supported now.\n\n        \"\"\"\n        n_sources, n_channels = self.n_sources, self.n_channels\n        n_bins = self.n_bins\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        alpha = rng.random((n_sources, n_bins))\n        alpha = alpha / alpha.sum(axis=0)\n\n        eye = np.eye(n_channels, dtype=np.complex128)\n        B_diag = self.rng.random((n_sources, n_bins, n_channels))\n        B_diag = B_diag / B_diag.sum(axis=-1, keepdims=True)\n        B = B_diag[:, :, :, np.newaxis] * eye\n\n        self.mixing = alpha\n        self.covariance = B\n\n        # The shape of posterior is (n_sources, n_bins, n_frames).\n        # This is always required to satisfy posterior.sum(axis=0) = 1\n        self.posterior = None\n\n    def separate(self, input: np.ndarray) -> np.ndarray:\n        r\"\"\"Separate ``input``.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        raise NotImplementedError(\"Implement 'separate' method.\")\n\n    def normalize_covariance(self) -> None:\n        r\"\"\"Normalize covariance of cACG.\n\n        .. 
math::\n            \\boldsymbol{B}_{in}\n            \\leftarrow\\frac{\\boldsymbol{B}_{in}}{\\mathrm{tr}(\\boldsymbol{B}_{in})}\n        \"\"\"\n        assert self.normalization, \"Set normalization.\"\n\n        B = self.covariance\n\n        trace = np.trace(B, axis1=-2, axis2=-1)\n        trace = np.real(trace)\n        B = B / trace[..., np.newaxis, np.newaxis]\n\n        self.covariance = B\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        raise NotImplementedError(\"Implement 'compute_loss' method.\")\n\n    def compute_logdet(self, covariance: np.ndarray) -> np.ndarray:\n        r\"\"\"Compute log-determinant of input.\n\n        Args:\n            covariance (numpy.ndarray):\n                Covariance matrix with shape of (n_sources, n_bins, n_channels, n_channels).\n\n        Returns:\n            numpy.ndarray of log-determinant.\n        \"\"\"\n        _, logdet = np.linalg.slogdet(covariance)\n\n        return logdet\n\n    def solve_permutation(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Align posteriors and separated spectrograms.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n\n        permutation_alignment = self.permutation_alignment\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        assert permutation_alignment, \"Set permutation_alignment=True.\"\n\n        if 
type(permutation_alignment) is bool:\n            # when permutation_alignment is True\n            permutation_alignment = \"posterior_score\"\n\n        if permutation_alignment in [\"posterior_score\", \"posterior_correlation\"]:\n            target = \"posterior\"\n        elif permutation_alignment in [\"amplitude_score\", \"amplitude_correlation\"]:\n            target = \"amplitude\"\n        else:\n            raise NotImplementedError(\n                \"permutation_alignment {} is not implemented.\".format(permutation_alignment)\n            )\n\n        if permutation_alignment in [\"posterior_score\", \"amplitude_score\"]:\n            self.solve_permutation_by_score(target=target, flooring_fn=flooring_fn)\n        elif permutation_alignment in [\"posterior_correlation\", \"amplitude_correlation\"]:\n            self.solve_permutation_by_correlation(target=target, flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\n                \"permutation_alignment {} is not implemented.\".format(permutation_alignment)\n            )\n\n    def solve_permutation_by_score(\n        self,\n        target: str = \"posterior\",\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Align posteriors and amplitudes of separated spectrograms by score value.\n\n        Args:\n            target (str):\n                Target to compute score values. 
Choose ``posterior`` or ``amplitude``.\n                Default: ``posterior``.\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n        \"\"\"\n\n        assert target in [\"posterior\", \"amplitude\"], \"Invalid target {} is specified.\".format(\n            target\n        )\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X = self.input\n        alpha = self.mixing\n        B = self.covariance\n        gamma = self.posterior\n\n        if hasattr(self, \"global_iter\"):\n            global_iter = self.global_iter\n        else:\n            global_iter = 1\n\n        if hasattr(self, \"local_iter\"):\n            local_iter = self.local_iter\n        else:\n            local_iter = 1\n\n        Y = self.separate(X, posterior=gamma)\n\n        alpha = alpha.transpose(1, 0)\n        B = B.transpose(1, 0, 2, 3)\n        gamma = gamma.transpose(1, 0, 2)\n\n        if target == \"posterior\":\n            gamma, (alpha, B) = score_based_permutation_solver(\n                gamma,\n                alpha,\n                B,\n                global_iter=global_iter,\n                local_iter=local_iter,\n                flooring_fn=flooring_fn,\n            )\n        elif target == \"amplitude\":\n            Y = Y.transpose(1, 0, 2)\n            amplitude = np.abs(Y)\n\n            _, (alpha, B, gamma) = score_based_permutation_solver(\n                amplitude,\n                alpha,\n                B,\n                gamma,\n                global_iter=global_iter,\n                local_iter=local_iter,\n                
flooring_fn=flooring_fn,\n            )\n        else:\n            raise ValueError(\"Invalid target {} is specified.\".format(target))\n\n        alpha = alpha.transpose(1, 0)\n        B = B.transpose(1, 0, 2, 3)\n        gamma = gamma.transpose(1, 0, 2)\n\n        Y = self.separate(X, posterior=gamma)\n\n        self.mixing = alpha\n        self.covariance = B\n        self.posterior = gamma\n        self.output = Y\n\n    def solve_permutation_by_correlation(\n        self,\n        target: str = \"amplitude\",\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Align posteriors and amplitudes of separated spectrograms by correlation.\n\n        Args:\n            target (str):\n                Target to compute correlations. Choose ``posterior`` or ``amplitude``.\n                Default: ``amplitude``.\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n\n        assert target == \"amplitude\", \"Only amplitude is supported as target.\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X = self.input\n        alpha = self.mixing\n        B = self.covariance\n        gamma = self.posterior\n\n        Y = self.separate(X, posterior=self.posterior)\n\n        alpha = alpha.transpose(1, 0)\n        B = B.transpose(1, 0, 2, 3)\n        gamma = gamma.transpose(1, 0, 2)\n        Y = Y.transpose(1, 0, 2)\n        Y, (alpha, B, gamma) = correlation_based_permutation_solver(\n            Y, alpha, B, gamma, flooring_fn=flooring_fn\n        )\n        alpha = 
alpha.transpose(1, 0)\n        B = B.transpose(1, 0, 2, 3)\n        gamma = gamma.transpose(1, 0, 2)\n        Y = Y.transpose(1, 0, 2)\n\n        self.mixing = alpha\n        self.covariance = B\n        self.posterior = gamma\n        self.output = Y\n\n\nclass CACGMM(CACGMMBase):\n    r\"\"\"Complex angular central Gaussian mixture model (cACGMM) [#ito2016complex]_.\n\n    Args:\n        n_sources (int, optional):\n            Number of sources to be separated.\n            If ``None`` is given, ``n_sources`` is determined by number of channels\n            in input spectrogram. Default: ``None``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        normalization (bool):\n            If ``True`` is given, normalization is applied to covariance in cACG.\n        permutation_alignment (bool):\n            If ``permutation_alignment=True``, a permutation solver is used to align\n            estimated spectrograms. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel to extract separated signals. Default: ``0``.\n        rng (numpy.random.Generator, optioinal):\n            Random number generator. This is mainly used to randomly initialize parameters\n            of cACGMM. If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n\n    .. [#ito2016complex] N. Ito, S. Araki, and T. Nakatani. 
\\\n        \"Complex angular central Gaussian mixture model for directional statistics \\\n        in mask-based microphone array signal processing,\"\n        in *Proc. EUSIPCO*, 2016, pp. 1153-1157.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_sources: Optional[int] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[\n                Callable[[\"CACGMM\"], None],\n                List[Callable[[\"CACGMM\"], None]],\n            ]\n        ] = None,\n        normalization: bool = True,\n        permutation_alignment: bool = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            n_sources=n_sources,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            record_loss=record_loss,\n            rng=rng,\n        )\n\n        self.normalization = normalization\n        self.permutation_alignment = permutation_alignment\n        self.reference_id = reference_id\n\n        if type(permutation_alignment) is bool and permutation_alignment:\n            valid_keys = {\"global_iter\", \"local_iter\"}\n        elif type(permutation_alignment) is str and permutation_alignment in [\n            \"posterior_score\",\n            \"amplitude_score\",\n        ]:\n            valid_keys = {\"global_iter\", \"local_iter\"}\n        else:\n            valid_keys = set()\n\n        invalid_keys = set(kwargs) - valid_keys\n\n        assert invalid_keys == set(), \"Invalid keywords {} are given.\".format(invalid_keys)\n\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a 
frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(flooring_fn=self.flooring_fn, **kwargs)\n\n        # Call __call__ of CACGMMBase's parent, i.e. __call__ of IterativeMethodBase\n        super(CACGMMBase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        # posterior should be updated\n        self.update_posterior(flooring_fn=self.flooring_fn)\n\n        if self.permutation_alignment:\n            self.solve_permutation(flooring_fn=self.flooring_fn)\n\n        X = self.input\n        self.output = self.separate(X, posterior=self.posterior)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"CACGMM(\"\n\n        if self.n_sources is not None:\n            s += \"n_sources={n_sources}, \"\n\n        s += \"record_loss={record_loss}\"\n        s += \", normalization={normalization}\"\n        s += \", permutation_alignment={permutation_alignment}\"\n        s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def separate(self, input: np.ndarray, posterior: Optional[np.ndarray] = None) -> np.ndarray:\n        r\"\"\"Separate ``input`` using posterior probabilities.\n\n        In this method, ``self.posterior`` is not updated.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture 
signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            posterior (numpy.ndarray, optional):\n                Posterior probability. If not specified, ``posterior`` is computed by current\n                parameters.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        X = input\n\n        if posterior is None:\n            alpha = self.mixing\n            Z = self.unit_input\n            B = self.covariance\n\n            Z = Z.transpose(1, 2, 0)\n            B_inverse = np.linalg.inv(B)\n            ZBZ = quadratic(Z, B_inverse[:, :, np.newaxis])\n            ZBZ = np.real(ZBZ)\n            ZBZ = np.maximum(ZBZ, 0)\n            ZBZ = self.flooring_fn(ZBZ)\n\n            log_alpha = np.log(alpha)\n            _, logdet = np.linalg.slogdet(B)\n            log_prob = log_alpha - logdet\n            log_gamma = log_prob[:, :, np.newaxis] - self.n_channels * np.log(ZBZ)\n\n            gamma = softmax(log_gamma, axis=0)\n        else:\n            gamma = posterior\n\n        return gamma * X[self.reference_id]\n\n    def update_once(\n        self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\"\n    ) -> None:\n        r\"\"\"Perform E and M step once.\n\n        In ``update_posterior``, posterior probabilities are updated, which corresponds to E step.\n        In ``update_parameters``, parameters of cACGMM are updated, which corresponds to M step.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n    
            Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_posterior(flooring_fn=flooring_fn)\n        self.update_parameters(flooring_fn=flooring_fn)\n\n        if self.normalization:\n            self.normalize_covariance()\n\n    def update_posterior(\n        self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\"\n    ) -> None:\n        r\"\"\"Update posteriors.\n\n        This method corresponds to E step in EM algorithm for cACGMM.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        alpha = self.mixing\n        Z = self.unit_input\n        B = self.covariance\n\n        Z = Z.transpose(1, 2, 0)\n        B_inverse = np.linalg.inv(B)\n        ZBZ = quadratic(Z, B_inverse[:, :, np.newaxis])\n        ZBZ = np.real(ZBZ)\n        ZBZ = np.maximum(ZBZ, 0)\n        ZBZ = flooring_fn(ZBZ)\n\n        log_prob = np.log(alpha) - self.compute_logdet(B)\n        log_gamma = log_prob[:, :, np.newaxis] - self.n_channels * np.log(ZBZ)\n\n        gamma = softmax(log_gamma, axis=0)\n\n        self.posterior = gamma\n\n    def update_parameters(\n        self, flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\"\n    ) -> None:\n        r\"\"\"Update parameters of mixture of complex angular central Gaussian distributions.\n\n        This method corresponds to M step in EM algorithm for cACGMM.\n\n        Args:\n            flooring_fn 
(callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Z = self.unit_input\n        B = self.covariance\n        gamma = self.posterior\n\n        Z = Z.transpose(1, 2, 0)\n        B_inverse = np.linalg.inv(B)\n        ZBZ = quadratic(Z, B_inverse[:, :, np.newaxis])\n        ZBZ = np.real(ZBZ)\n        ZBZ = np.maximum(ZBZ, 0)\n        ZBZ = flooring_fn(ZBZ)\n        ZZ = Z[:, :, :, np.newaxis] * Z[:, :, np.newaxis, :].conj()\n\n        alpha = np.mean(gamma, axis=-1)\n\n        GZBZ = gamma / ZBZ\n        num = np.sum(GZBZ[:, :, :, np.newaxis, np.newaxis] * ZZ, axis=2)\n        denom = np.sum(gamma, axis=2)\n        B = self.n_channels * (num / denom[:, :, np.newaxis, np.newaxis])\n        B = to_psd(B, flooring_fn=flooring_fn)\n\n        self.mixing = alpha\n        self.covariance = B\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss of cACGMM :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is defined as follows:\n\n        .. 
math::\n            \\mathcal{L}\n            = -\\frac{1}{J}\\sum_{i,j}\\log\\left(\n            \\sum_{n}\\frac{\\alpha_{in}}{\\det\\boldsymbol{B}_{in}}\n            \\frac{1}{(\\boldsymbol{z}_{ij}^{\\mathsf{H}}\\boldsymbol{B}_{in}^{-1}\\boldsymbol{z}_{ij})^{M}}\n            \\right).\n        \"\"\"\n        alpha = self.mixing\n        Z = self.unit_input\n        B = self.covariance\n\n        Z = Z.transpose(1, 2, 0)\n        B_inverse = np.linalg.inv(B)\n        ZBZ = quadratic(Z, B_inverse[:, :, np.newaxis])\n        ZBZ = np.real(ZBZ)\n        ZBZ = np.maximum(ZBZ, 0)\n        ZBZ = self.flooring_fn(ZBZ)\n\n        log_prob = np.log(alpha) - self.compute_logdet(B)\n        log_gamma = log_prob[:, :, np.newaxis] - self.n_channels * np.log(ZBZ)\n\n        loss = -logsumexp(log_gamma, axis=0)\n        loss = np.mean(loss, axis=-1)\n        loss = loss.sum(axis=0)\n        loss = loss.item()\n\n        return loss\n"
  },
  {
    "path": "ssspy/bss/fdica.py",
    "content": "import functools\nfrom typing import Callable, Iterable, List, Optional, Tuple, Union\n\nimport numpy as np\n\nfrom ..algorithm import (\n    MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS,\n    PROJECTION_BACK_KEYWORDS,\n    minimal_distortion_principle,\n    projection_back,\n)\nfrom ..algorithm.permutation_alignment import correlation_based_permutation_solver\nfrom ..special.flooring import identity, max_flooring\nfrom ..utils.flooring import choose_flooring_fn\nfrom ..utils.select_pair import sequential_pair_selector\nfrom ._update_spatial_model import update_by_ip1, update_by_ip2_one_pair\nfrom .base import IterativeMethodBase\n\n__all__ = [\n    \"GradFDICA\",\n    \"NaturalGradFDICA\",\n    \"AuxFDICA\",\n    \"GradLaplaceFDICA\",\n    \"NaturalGradLaplaceFDICA\",\n    \"AuxLaplaceFDICA\",\n]\n\nspatial_algorithms = [\"IP\", \"IP1\", \"IP2\"]\nEPS = 1e-10\n\n\nclass FDICABase(IterativeMethodBase):\n    r\"\"\"Base class of frequency-domain independent component analysis (FDICA).\n\n    Args:\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(y_{ijn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. 
Each function is called before separation and at each iteration.\n            Default: ``None``.\n        permutation_alignment (bool):\n            If ``permutation_alignment=True``, a permutation solver is used to align\n            estimated spectrograms. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n    \"\"\"\n\n    def __init__(\n        self,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"FDICABase\"], None], List[Callable[[\"FDICABase\"], None]]]\n        ] = None,\n        permutation_alignment: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(callbacks=callbacks, record_loss=record_loss)\n\n        if contrast_fn is None:\n            raise ValueError(\"Specify contrast function.\")\n        else:\n            self.contrast_fn = contrast_fn\n\n        if flooring_fn is None:\n            self.flooring_fn = identity\n        else:\n            self.flooring_fn = flooring_fn\n\n        self.input = None\n        self.permutation_alignment = permutation_alignment\n        self.scale_restoration = scale_restoration\n\n        if reference_id is 
None and scale_restoration:\n            raise ValueError(\"Specify 'reference_id' if scale_restoration=True.\")\n        else:\n            self.reference_id = reference_id\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                Mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                Number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        super().__call__(n_iter=n_iter, initial_call=initial_call)\n\n        raise NotImplementedError(\"Implement '__call__' method.\")\n\n    def __repr__(self) -> str:\n        s = \"FDICA(\"\n        s += \", permutation_alignment={permutation_alignment}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of FDICA.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_channels, 
n_bins, n_frames = X.shape\n        n_sources = n_channels  # n_channels == n_sources\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_bins, self.n_frames = n_bins, n_frames\n\n        if not hasattr(self, \"demix_filter\"):\n            W = np.eye(n_sources, n_channels, dtype=np.complex128)\n            W = np.tile(W, reps=(n_bins, 1, 1))\n        else:\n            if self.demix_filter is None:\n                W = None\n            else:\n                # To avoid overwriting ``demix_filter`` given by keyword arguments.\n                W = self.demix_filter.copy()\n\n        self.demix_filter = W\n        self.output = self.separate(X, demix_filter=W)\n\n    def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Separate ``input`` using ``demixing_filter``.\n\n        .. math::\n            \\boldsymbol{y}_{ij}\n            = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij}\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            demix_filter (numpy.ndarray):\n                The demixing filters to separate ``input``.\n                The shape is (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        X, W = input, demix_filter\n        Y = W @ X.transpose(1, 0, 2)\n        output = Y.transpose(1, 0, 2)\n\n        return output\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. 
math::\n            \\mathcal{L}\n            &= \\sum_{i}\\mathcal{L}^{[i]}, \\\\\n            \\mathcal{L}^{[i]}\n            &= \\frac{1}{J}\\sum_{j,n}G(y_{ijn})\n            - 2\\log|\\det\\boldsymbol{W}_{i}|, \\\\\n            G(y_{ijn}) \\\n            &= - \\log p(y_{ijn})\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)  # (n_sources, n_bins, n_frames)\n        logdet = self.compute_logdet(W)  # (n_bins,)\n        G = self.contrast_fn(Y)  # (n_sources, n_bins, n_frames)\n        loss = np.sum(np.mean(G, axis=2), axis=0) - 2 * logdet\n        loss = loss.sum(axis=0).item()\n\n        return loss\n\n    def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Compute log-determinant of demixing filter.\n\n        Args:\n            demix_filter (numpy.ndarray):\n                Demixing filters with shape of (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of computed log-determinant values.\n        \"\"\"\n        _, logdet = np.linalg.slogdet(demix_filter)  # (n_bins,)\n\n        return logdet\n\n    def solve_permutation(self) -> None:\n        r\"\"\"Align demixing filters and separated spectrograms\"\"\"\n\n        permutation_alignment = self.permutation_alignment\n\n        assert permutation_alignment, \"Set permutation_alignment=True.\"\n\n        if type(permutation_alignment) is bool:\n            # when permutation_alignment is True\n            permutation_alignment = \"spectrogram_correlation\"\n\n        if permutation_alignment == \"spectrogram_correlation\":\n            self.solve_permutation_by_correlation()\n        else:\n            raise NotImplementedError(\n                \"permutation_alignment {} is not implemented.\".format(permutation_alignment)\n            )\n\n    def solve_permutation_by_correlation(\n        self,\n        flooring_fn: Optional[Union[str, 
Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Align posteriors and separated spectrograms by correlation.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n        \"\"\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n        X, W = self.input, self.demix_filter\n\n        Y = self.separate(X, demix_filter=W)\n        Y = Y.transpose(1, 0, 2)\n        Y, W = correlation_based_permutation_solver(Y, W, flooring_fn=flooring_fn)\n        Y = Y.transpose(1, 0, 2)\n\n        self.output, self.demix_filter = Y, W\n\n    def restore_scale(self) -> None:\n        r\"\"\"Restore scale ambiguity.\n\n        If ``self.scale_restoration=projection_back``, we use projection back technique.\n        If ``self.scale_restoration=minimal_distortion_principle``,\n        we use minimal distortion principle.\n        \"\"\"\n        scale_restoration = self.scale_restoration\n\n        assert scale_restoration, \"Set self.scale_restoration=True.\"\n\n        if type(scale_restoration) is bool:\n            scale_restoration = PROJECTION_BACK_KEYWORDS[0]\n\n        if scale_restoration in PROJECTION_BACK_KEYWORDS:\n            self.apply_projection_back()\n        elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS:\n            self.apply_minimal_distortion_principle()\n        else:\n            raise ValueError(\"{} is not supported for scale restoration.\".format(scale_restoration))\n\n    def apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated 
spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        W_scaled = projection_back(W, reference_id=self.reference_id)\n        Y_scaled = self.separate(X, demix_filter=W_scaled)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n    def apply_minimal_distortion_principle(self) -> None:\n        r\"\"\"Apply minimal distortion principle to estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n        Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id)\n        X = X.transpose(1, 0, 2)\n        Y = Y_scaled.transpose(1, 0, 2)\n        X_Hermite = X.transpose(0, 2, 1).conj()\n        W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n\nclass GradFDICABase(FDICABase):\n    r\"\"\"Base class of frequency-domain independent component analysis (FDICA) \\\n    using the gradient descent.\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. 
Default: ``1e-1``.\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(y_{ijn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        score_fn (callable):\n            A score function which corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        permutation_alignment (bool):\n            If ``permutation_alignment=True``, a permutation solver is used to align\n            estimated spectrograms. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. 
Default: ``0``.\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"GradFDICABase\"], None], List[Callable[[\"GradFDICABase\"], None]]]\n        ] = None,\n        permutation_alignment: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            contrast_fn=contrast_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            permutation_alignment=permutation_alignment,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        self.step_size = step_size\n\n        if score_fn is None:\n            raise ValueError(\"Specify score function.\")\n        else:\n            self.score_fn = score_fn\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, 
n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of FDICABase's parent, i.e. __call__ of IterativeMethodBase\n        super(FDICABase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.permutation_alignment:\n            self.solve_permutation()\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"GradFDICA(\"\n        s += \"step_size={step_size}\"\n        s += \", permutation_alignment={permutation_alignment}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once.\"\"\"\n        raise NotImplementedError(\"Implement 'update_once' method.\")\n\n\nclass GradFDICA(GradFDICABase):\n    r\"\"\"Frequency-domain independent component analysis (FDICA) \\\n    using the gradient descent.\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. 
Default: ``1e-1``.\n        contrast_fn (callable):\n            A contrast function corresponds to :math:`-\\log p(y_{ijn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        score_fn (callable):\n            A score function corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``False``.\n        permutation_alignment (bool):\n            If ``permutation_alignment=True``, a permutation solver is used to align\n            estimated spectrograms. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.abs(y)\n\n            >>> def score_fn(y):\n            ...     denom = np.maximum(np.abs(y), 1e-10)\n            ...     return y / denom\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = GradFDICA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=True,\n            ... )\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.abs(y)\n\n            >>> def score_fn(y):\n            ...     denom = np.maximum(np.abs(y), 1e-10)\n            ...     return y / denom\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = GradFDICA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=False,\n            ... 
)\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"GradFDICA\"], None], List[Callable[[\"GradFDICA\"], None]]]\n        ] = None,\n        is_holonomic: bool = False,\n        permutation_alignment: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            permutation_alignment=permutation_alignment,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        self.is_holonomic = is_holonomic\n\n    def __repr__(self) -> str:\n        s = \"GradFDICA(\"\n        s += \"step_size={step_size}\"\n        s += \", is_holonomic={is_holonomic}\"\n        s += \", permutation_alignment={permutation_alignment}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are updated as 
follows:\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i} - \\eta\\left(\\frac{1}{J}\\sum_{j}\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{ij})\\boldsymbol{y}_{ij}^{\\mathsf{H}}\n            -\\boldsymbol{I}\\right)\\boldsymbol{W}_{i}^{-\\mathsf{H}},\n\n        where\n\n        .. math::\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{ij})\n            &= \\left(\\phi(y_{ij1}),\\ldots,\\phi(y_{ijn}),\\ldots,\\phi(y_{ijN})\n            \\right)^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n            \\phi(y_{ijn})\n            &= \\frac{\\partial G(y_{ijn})}{\\partial y_{ijn}^{*}}, \\\\\n            G(y_{ijn})\n            &= -\\log p(y_{ijn}).\n\n        Otherwise (``is_holonomic=False``),\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i}\n            - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{J}\\sum_{j}\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{ij})\\boldsymbol{y}_{ij}^{\\mathsf{H}}\\right)\n            \\boldsymbol{W}_{i}^{-\\mathsf{H}}.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        Phi = self.score_fn(Y)\n        Y_conj = Y.conj()\n        PhiY = np.mean(Phi[:, np.newaxis, :, :] * Y_conj[np.newaxis, :, :, :], axis=-1)\n        PhiY = PhiY.transpose(2, 0, 1)  # (n_bins, n_sources, n_sources)\n        W_inv = np.linalg.inv(W)\n        W_inv_Hermite = W_inv.transpose(0, 2, 1).conj()\n        eye = np.eye(self.n_sources)\n\n        if self.is_holonomic:\n            delta = (PhiY - eye) @ W_inv_Hermite\n        else:\n            delta = ((1 - eye) * PhiY) @ W_inv_Hermite\n\n        W = W - self.step_size * delta\n\n        Y = self.separate(X, demix_filter=W)\n\n        self.demix_filter = W\n        self.output = Y\n\n\nclass NaturalGradFDICA(GradFDICABase):\n    r\"\"\"Frequency-domain independent component analysis (FDICA) \\\n    using the natural gradient descent.\n\n    
Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        contrast_fn (callable):\n            A contrast function corresponds to :math:`-\\log p(y_{ijn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        score_fn (callable):\n            A score function corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``False``.\n        permutation_alignment (bool):\n            If ``permutation_alignment=True``, a permutation solver is used to align\n            estimated spectrograms. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.abs(y)\n\n            >>> def score_fn(y):\n            ...     denom = np.maximum(np.abs(y), 1e-10)\n            ...     return y / denom\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = NaturalGradFDICA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=True,\n            ... )\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.abs(y)\n\n            >>> def score_fn(y):\n            ...     denom = np.maximum(np.abs(y), 1e-10)\n            ...     return y / denom\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = NaturalGradFDICA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=False,\n            ... 
)\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"NaturalGradFDICA\"], None], List[Callable[[\"NaturalGradFDICA\"], None]]]\n        ] = None,\n        is_holonomic: bool = False,\n        permutation_alignment: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            permutation_alignment=permutation_alignment,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        self.is_holonomic = is_holonomic\n\n    def __repr__(self) -> str:\n        s = \"NaturalGradFDICA(\"\n        s += \"step_size={step_size}\"\n        s += \", is_holonomic={is_holonomic}\"\n        s += \", permutation_alignment={permutation_alignment}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are 
updated as follows:\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i} - \\eta\\left(\\frac{1}{J}\\sum_{j}\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{ij})\\boldsymbol{y}_{ij}^{\\mathsf{H}}\n            -\\boldsymbol{I}\\right)\\boldsymbol{W}_{i},\n\n        where\n\n        .. math::\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{ij})\n            &= \\left(\\phi(y_{ij1}),\\ldots,\\phi(y_{ijn}),\\ldots,\\phi(y_{ijN})\n            \\right)^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n            \\phi(y_{ijn})\n            &= \\frac{\\partial G(y_{ijn})}{\\partial y_{ijn}^{*}}, \\\\\n            G(y_{ijn})\n            &= -\\log p(y_{ijn}).\n\n        Otherwise (``is_holonomic=False``),\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i}\n            - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{J}\\sum_{j}\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{ij})\\boldsymbol{y}_{ij}^{\\mathsf{H}}\\right)\n            \\boldsymbol{W}_{i}.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        Phi = self.score_fn(Y)\n        Y_conj = Y.conj()\n        PhiY = np.mean(Phi[:, np.newaxis, :, :] * Y_conj[np.newaxis, :, :, :], axis=-1)\n        PhiY = PhiY.transpose(2, 0, 1)  # (n_bins, n_sources, n_sources)\n        eye = np.eye(self.n_sources)\n\n        if self.is_holonomic:\n            delta = (PhiY - eye) @ W\n        else:\n            delta = ((1 - eye) * PhiY) @ W\n\n        W = W - self.step_size * delta\n\n        Y = self.separate(X, demix_filter=W)\n\n        self.demix_filter = W\n        self.output = Y\n\n\nclass AuxFDICA(FDICABase):\n    r\"\"\"Auxiliary-function-based frequency-domain independent component analysis \\\n    (AuxFDICA) [#ono2010auxiliary]_.\n\n    Args:\n        spatial_algorithm (str):\n            Algorithm to update demixing filters.\n            Choose ``IP``, ``IP1``, or 
``IP2``. Default: ``IP``.\n        contrast_fn (callable):\n            A contrast function corresponds to :math:`-\\log p(y_{ijn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        d_contrast_fn (callable):\n            A partial derivative of the real contrast function.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        pair_selector (callable, optional):\n            Selector to choose updaing pair in ``IP2`` and ``ISS2``.\n            If ``None`` is given, ``partial(sequential_pair_selector, sort=True)`` is used.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        permutation_alignment (bool):\n            If ``permutation_alignment=True``, a permutation solver is used to align\n            estimated spectrograms. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the demixing filter update if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters by IP:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.abs(y)\n\n            >>> def d_contrast_fn(y):\n            ...     return 2 * np.ones_like(y)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = AuxFDICA(\n            ...     spatial_algorithm=\"IP\",\n            ...     contrast_fn=contrast_fn,\n            ...     d_contrast_fn=d_contrast_fn,\n            ... )\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IP2:\n\n        .. code-block:: python\n\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.abs(y)\n\n            >>> def d_contrast_fn(y):\n            ...     return 2 * np.ones_like(y)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = AuxFDICA(\n            ...     spatial_algorithm=\"IP2\",\n            ...     contrast_fn=contrast_fn,\n            ...     d_contrast_fn=d_contrast_fn,\n            ...     pair_selector=sequential_pair_selector,\n            ... 
)\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n    .. [#ono2010auxiliary]\n        N. Ono and S. Miyabe,\n        \"Auxiliary-function-based independent component analysis for super-Gaussian sources,\"\n        in *Proc. LVA/ICA*, 2010, pp.165-172.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_algorithm: str = \"IP\",\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n        callbacks: Optional[\n            Union[Callable[[\"AuxFDICA\"], None], List[Callable[[\"AuxFDICA\"], None]]]\n        ] = None,\n        permutation_alignment: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            contrast_fn=contrast_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            permutation_alignment=permutation_alignment,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n        assert spatial_algorithm in spatial_algorithms, \"Not support {}.\".format(spatial_algorithms)\n\n        self.spatial_algorithm = spatial_algorithm\n        self.d_contrast_fn = d_contrast_fn\n\n        if pair_selector is None:\n            if spatial_algorithm == \"IP2\":\n                self.pair_selector = sequential_pair_selector\n        else:\n            self.pair_selector = pair_selector\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, 
**kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of FDICABase's parent, i.e. __call__ of IterativeMethodBase\n        super(FDICABase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.permutation_alignment:\n            self.solve_permutation()\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        if self.demix_filter is not None:\n            self.output = self.separate(self.input, demix_filter=self.demix_filter)\n        else:\n            # TODO: implement demixing-filter-free algorithms (e.g. 
ISS, IPA, etc.)\n            pass\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"AuxFDICA(\"\n        s += \"spatial_algorithm={spatial_algorithm}\"\n        s += \", permutation_alignment={permutation_alignment}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once.\n\n        - If ``self.spatial_algorithm`` is ``IP`` or ``IP1``, ``update_once_ip1`` is called.\n        - If ``self.spatial_algorithm`` is ``IP2``, ``update_once_ip2`` is called.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.spatial_algorithm in [\"IP\", \"IP1\"]:\n            self.update_once_ip1(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"IP2\"]:\n            self.update_once_ip2(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.spatial_algorithm))\n\n    def update_once_ip1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using iterative projection.\n\n        Args:\n 
           flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Demixing filters are updated sequentially for :math:`n=1,\\ldots,N` as follows:\n\n        .. math::\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\left(\\boldsymbol{W}_{in}^{\\mathsf{H}}\\boldsymbol{U}_{in}\\right)^{-1}\n            \\boldsymbol{e}_{n}, \\\\\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\frac{\\boldsymbol{w}_{in}}\n            {\\sqrt{\\boldsymbol{w}_{in}^{\\mathsf{H}}\\boldsymbol{U}_{in}\\boldsymbol{w}_{in}}}, \\\\\n\n        where\n\n        .. math::\n            \\boldsymbol{U}_{in}\n            &= \\frac{1}{J}\\sum_{j}\n            \\frac{G'_{\\mathbb{R}}(|y_{ijn}|)}{2|y_{ijn}|}\n            \\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}, \\\\\n            G(y_{ijn})\n            &= -\\log p(y_{ijn}), \\\\\n            G_{\\mathbb{R}}(|y_{ijn}|)\n            &= G(y_{ijn}).\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)  # (n_bins, n_channels, n_channels, n_frames)\n        Y_abs = np.abs(Y)\n        denom = flooring_fn(2 * Y_abs)\n        varphi = self.d_contrast_fn(Y_abs) / denom  # (n_sources, n_bins, n_frames)\n        varphi = varphi.transpose(1, 0, 2)  # (n_bins, n_sources, n_frames)\n        GXX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n        U = 
np.mean(GXX, axis=-1)  # (n_bins, n_sources, n_channels, n_channels)\n\n        self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn)\n\n    def update_once_ip2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using pairwise iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\\neq n_{2}`),\n        compute auxiliary variables:\n\n        .. math::\n            \\bar{r}_{ijn_{1}}\n            &\\leftarrow|y_{ijn_{1}}| \\\\\n            \\bar{r}_{ijn_{2}}\n            &\\leftarrow|y_{ijn_{2}}|\n\n        Then, for :math:`n=n_{1},n_{2}`, compute weighted covariance matrix as follows:\n\n        .. math::\n            \\boldsymbol{U}_{in_{1}}\n            &= \\frac{1}{J}\\sum_{j}\n            \\frac{G'_{\\mathbb{R}}(\\bar{r}_{ijn_{1}})}{2\\bar{r}_{ijn_{1}}}\n            \\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}, \\\\\n            \\boldsymbol{U}_{in_{2}}\n            &= \\frac{1}{J}\\sum_{j}\n            \\frac{G'_{\\mathbb{R}}(\\bar{r}_{ijn_{2}})}{2\\bar{r}_{ijn_{2}}}\n            \\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}},\n\n        where\n\n        .. 
math::\n            G(y_{ijn})\n            &= -\\log p(y_{ijn}), \\\\\n            G_{\\mathbb{R}}(|y_{ijn}|)\n            &= G(y_{ijn}).\n\n        Using :math:`\\boldsymbol{U}_{in_{1}}` and\n        :math:`\\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors.\n\n        .. math::\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i}\n            = \\lambda_{i}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i},\n\n        where\n\n        .. math::\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{1}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ), \\\\\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{2}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ).\n\n        After that, we standardize two eigenvectors :math:`\\boldsymbol{h}_{in_{1}}`\n        and :math:`\\boldsymbol{h}_{in_{2}}`.\n\n        .. 
math::\n            \\boldsymbol{h}_{in_{1}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{1}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{1}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{1}}}}, \\\\\n            \\boldsymbol{h}_{in_{2}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{2}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{2}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{2}}}}.\n\n        Then, update :math:`\\boldsymbol{w}_{in_{1}}` and :math:`\\boldsymbol{w}_{in_{2}}`\n        simultaneously.\n\n        .. math::\n            \\boldsymbol{w}_{in_{1}}\n            &\\leftarrow \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{1}} \\\\\n            \\boldsymbol{w}_{in_{2}}\n            &\\leftarrow \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{2}}\n\n        At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{2}`\n        for :math:`n_{1}\\neq n_{2}`.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        n_sources = self.n_sources\n        X, W = self.input, self.demix_filter\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        for m, n in self.pair_selector(n_sources):\n            W_mn = W[:, (m, n), :]\n            Y_mn = self.separate(X, demix_filter=W_mn)\n\n            Y_abs_mn = np.abs(Y_mn)\n            denom = flooring_fn(2 * Y_abs_mn)\n            varphi_mn = self.d_contrast_fn(Y_abs_mn) / denom\n            varphi_mn = varphi_mn.transpose(1, 0, 2)\n            GXX_mn = varphi_mn[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, 
:, :, :]\n            U_mn = np.mean(GXX_mn, axis=-1)\n\n            W[:, (m, n), :] = update_by_ip2_one_pair(\n                W,\n                U_mn,\n                pair=(m, n),\n                flooring_fn=flooring_fn,\n            )\n\n        self.demix_filter = W\n\n\nclass GradLaplaceFDICA(GradFDICA):\n    r\"\"\"Frequency-domain independent component analysis (FDICA) \\\n    using the gradient descent on a Laplace distribution.\n\n    We assume :math:`y_{ijn}` follows a Laplace distribution.\n\n    .. math::\n        p(y_{ijn})\\propto\\exp(-|y_{ijn}|)\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``False``.\n        permutation_alignment (bool):\n            If ``permutation_alignment=True``, a permutation solver is used to align\n            estimated spectrograms. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = GradLaplaceFDICA(is_holonomic=True)\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = GradLaplaceFDICA(is_holonomic=False)\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"GradLaplaceFDICA\"], None], List[Callable[[\"GradLaplaceFDICA\"], None]]]\n        ] = None,\n        is_holonomic: bool = False,\n        permutation_alignment: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        def contrast_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            return 2 * np.abs(y)\n\n        def score_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Score function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            denom = self.flooring_fn(np.abs(y))\n            return y / denom\n\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            is_holonomic=is_holonomic,\n            permutation_alignment=permutation_alignment,\n            scale_restoration=scale_restoration,\n     
       record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def __repr__(self) -> str:\n        s = \"GradLaplaceFDICA(\"\n        s += \"step_size={step_size}\"\n        s += \", is_holonomic={is_holonomic}\"\n        s += \", permutation_alignment={permutation_alignment}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n\nclass NaturalGradLaplaceFDICA(NaturalGradFDICA):\n    r\"\"\"Frequency-domain independent component analysis (FDICA) \\\n    using the natural gradient descent on a Laplace distribution.\n\n    We assume :math:`y_{ijn}` follows a Laplace distribution.\n\n    .. math::\n        p(y_{ijn})\\propto\\exp(-|y_{ijn}|)\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``False``.\n        permutation_alignment (bool):\n            If ``permutation_alignment=True``, a permutation solver is used to align\n            estimated spectrograms. 
Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = NaturalGradLaplaceFDICA(is_holonomic=True)\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = NaturalGradLaplaceFDICA(is_holonomic=False)\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[\n                Callable[[\"NaturalGradLaplaceFDICA\"], None],\n                List[Callable[[\"NaturalGradLaplaceFDICA\"], None]],\n            ]\n        ] = None,\n        is_holonomic: bool = False,\n        permutation_alignment: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        def contrast_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            return 2 * np.abs(y)\n\n        def score_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Score function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            denom = self.flooring_fn(np.abs(y))\n            return y / denom\n\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            is_holonomic=is_holonomic,\n            
permutation_alignment=permutation_alignment,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def __repr__(self) -> str:\n        s = \"NaturalGradLaplaceFDICA(\"\n        s += \"step_size={step_size}\"\n        s += \", is_holonomic={is_holonomic}\"\n        s += \", permutation_alignment={permutation_alignment}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n\nclass AuxLaplaceFDICA(AuxFDICA):\n    r\"\"\"Auxiliary-function-based frequency-domain independent component analysis \\\n    on a Laplace distribution.\n\n    We assume :math:`y_{ijn}` follows a Laplace distribution.\n\n    .. math::\n        p(y_{ijn})\\propto\\exp(-|y_{ijn}|)\n\n    Args:\n        spatial_algorithm (str):\n            Algorithm to update demixing filters.\n            Choose ``IP``, ``IP1``, or ``IP2``. Default: ``IP``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        pair_selector (callable, optional):\n            Selector to choose the pair to update in ``IP2`` and ``ISS2``.\n            If ``None`` is given, ``partial(sequential_pair_selector, sort=True)`` is used.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. 
Each function is called before separation and at each iteration.\n            Default: ``None``.\n        permutation_alignment (bool):\n            If ``permutation_alignment=True``, a permutation solver is used to align\n            estimated spectrograms. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the demixing filter update if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters by IP:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = AuxLaplaceFDICA(spatial_algorithm=\"IP\")\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IP2:\n\n        .. code-block:: python\n\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = \\\n            ...     np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> fdica = AuxLaplaceFDICA(\n            ...     spatial_algorithm=\"IP2\",\n            ...  
   pair_selector=sequential_pair_selector,\n            ... )\n            >>> spectrogram_est = fdica(spectrogram_mix, n_iter=1000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_algorithm: str = \"IP\",\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n        callbacks: Optional[\n            Union[Callable[[\"AuxLaplaceFDICA\"], None], List[Callable[[\"AuxLaplaceFDICA\"], None]]]\n        ] = None,\n        permutation_alignment: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        def contrast_fn(y: np.ndarray):\n            r\"\"\"Contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            return 2 * np.abs(y)\n\n        def d_contrast_fn(y: np.ndarray):\n            r\"\"\"Derivative of the contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            return 2 * np.ones_like(y)\n\n        super().__init__(\n            spatial_algorithm=spatial_algorithm,\n            contrast_fn=contrast_fn,\n            d_contrast_fn=d_contrast_fn,\n            flooring_fn=flooring_fn,\n            pair_selector=pair_selector,\n            callbacks=callbacks,\n            permutation_alignment=permutation_alignment,\n            scale_restoration=scale_restoration,\n            
record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def __repr__(self) -> str:\n        s = \"AuxLaplaceFDICA(\"\n        s += \"spatial_algorithm={spatial_algorithm}\"\n        s += \", permutation_alignment={permutation_alignment}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n"
  },
  {
    "path": "ssspy/bss/hva.py",
    "content": "import functools\nimport math\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\n\nfrom ..special.flooring import identity, max_flooring\nfrom .admmbss import MaskingADMMBSS\nfrom .pdsbss import MaskingPDSBSS\n\n__all__ = [\n    \"MaskingPDSHVA\",\n    \"MaskingADMMHVA\",\n    \"HVA\",\n]\n\nEPS = 1e-10\n\n\nclass MaskingPDSHVA(MaskingPDSBSS):\n    r\"\"\"Harmonic vector analysis proposed in [#yatabe2021determined]_.\n\n    Args:\n        mu1 (float):\n            Step size. Default: ``1``.\n        mu2 (float):\n            Step size. Default: ``1``.\n        alpha (float):\n            Relaxation parameter (deprecated). Set ``relaxation`` instead.\n        relaxation (float):\n            Relaxation parameter. Default: ``1``.\n        attenuation (float, optional):\n            Attenuation parameter in masking. Default: ``1 / n_sources``.\n        mask_iter (int):\n            Number of iterations in application of cosine shrinkage operator.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. 
You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool, optional):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``None``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n\n    .. [#yatabe2021determined] K. Yatabe and D. Kitamura,\n        \"Determined BSS based on time-frequency masking and its application to \\\n        harmonic vector analysis,\" *IEEE/ACM Trans. ASLP*, vol. 29, pp. 1609-1625, 2021.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        mu1: float = 1,\n        mu2: float = 1,\n        alpha: float = None,\n        relaxation: float = 1,\n        attenuation: Optional[float] = None,\n        mask_iter: int = 1,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"MaskingPDSHVA\"], None], List[Callable[[\"MaskingPDSHVA\"], None]]]\n        ] = None,\n        scale_restoration: bool = True,\n        record_loss: Optional[bool] = None,\n        reference_id: int = 0,\n    ) -> None:\n        def mask_fn(y: np.ndarray) -> np.ndarray:\n            \"\"\"Masking function to emphasize harmonic components.\n\n            Args:\n                y (np.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                np.ndarray of mask. 
The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            n_sources, n_bins, _ = y.shape\n\n            if self.attenuation is None:\n                self.attenuation = 1 / n_sources\n\n            gamma = self.attenuation\n\n            y = self.flooring_fn(np.abs(y))\n            zeta = np.log(y)\n            zeta_mean = zeta.mean(axis=1, keepdims=True)\n            rho = zeta - zeta_mean\n            nu = np.fft.irfft(rho, axis=1, norm=\"backward\")\n            nu = nu[:, :n_bins]\n            varsigma = np.minimum(1, nu)\n\n            for _ in range(mask_iter):\n                varsigma = (1 - np.cos(math.pi * varsigma)) / 2\n\n            xi = np.fft.irfft(varsigma * nu, axis=1, norm=\"forward\")\n            xi = xi[:, :n_bins]\n            varrho = xi + zeta_mean\n            v = np.exp(2 * varrho)\n            mask = (v / v.sum(axis=0)) ** gamma\n\n            return mask\n\n        super().__init__(\n            mu1=mu1,\n            mu2=mu2,\n            alpha=alpha,\n            relaxation=relaxation,\n            penalty_fn=None,\n            mask_fn=mask_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        self.attenuation = attenuation\n        self.mask_iter = mask_iter\n\n        if flooring_fn is None:\n            self.flooring_fn = identity\n        else:\n            self.flooring_fn = flooring_fn\n\n    def __repr__(self) -> str:\n        s = \"MaskingPDSHVA(\"\n        s += \"mu1={mu1}, mu2={mu2}\"\n        s += \", relaxation={relaxation}\"\n\n        if self.attenuation is not None:\n            s += \", attenuation={attenuation}\"\n\n        s += \", mask_iter={mask_iter}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += 
\")\"\n\n        return s.format(**self.__dict__)\n\n\nclass MaskingADMMHVA(MaskingADMMBSS):\n    \"\"\"Harmonic vector analysis using ADMM with masking.\n\n    Args:\n        rho (float): Penalty parameter. Default: ``1``.\n        alpha (float): Relaxation parameter (deprecated). Set ``relaxation`` instead.\n        relaxation (float): Relaxation parameter. Default: ``1``.\n        attenuation (float, optional): Attenuation parameter.\n        mask_iter (int): Number of iterations in application of cosine shrinkage operator.\n        flooring_fn (callable, optional): A flooring function for numerical stability.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames). If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str): Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool, optional): Record the loss at each iteration of the update algorithm if\n            ``record_loss=True``. Default: ``None``.\n        reference_id (int): Reference channel for projection back. 
Default: ``0``.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        rho: float = 1,\n        alpha: float = None,\n        relaxation: float = 1,\n        attenuation: Optional[float] = None,\n        mask_iter: int = 1,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"MaskingADMMHVA\"], None], List[Callable[[\"MaskingADMMHVA\"], None]]]\n        ] = None,\n        scale_restoration: bool = True,\n        record_loss: Optional[bool] = None,\n        reference_id: int = 0,\n    ) -> None:\n        def mask_fn(y: np.ndarray) -> np.ndarray:\n            \"\"\"Masking function to emphasize harmonic components.\n\n            Args:\n                y (np.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                np.ndarray of mask. The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            n_sources, n_bins, _ = y.shape\n\n            if self.attenuation is None:\n                self.attenuation = 1 / n_sources\n\n            gamma = self.attenuation\n\n            y = self.flooring_fn(np.abs(y))\n            zeta = np.log(y)\n            zeta_mean = zeta.mean(axis=1, keepdims=True)\n            rho = zeta - zeta_mean\n            nu = np.fft.irfft(rho, axis=1, norm=\"backward\")\n            nu = nu[:, :n_bins]\n            varsigma = np.minimum(1, nu)\n\n            for _ in range(mask_iter):\n                varsigma = (1 - np.cos(math.pi * varsigma)) / 2\n\n            xi = np.fft.irfft(varsigma * nu, axis=1, norm=\"forward\")\n            xi = xi[:, :n_bins]\n            varrho = xi + zeta_mean\n            v = np.exp(2 * varrho)\n            mask = (v / v.sum(axis=0)) ** gamma\n\n            return mask\n\n        super().__init__(\n            rho=rho,\n            alpha=alpha,\n            relaxation=relaxation,\n            
penalty_fn=None,\n            mask_fn=mask_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        self.attenuation = attenuation\n        self.mask_iter = mask_iter\n\n        if flooring_fn is None:\n            self.flooring_fn = identity\n        else:\n            self.flooring_fn = flooring_fn\n\n    def __repr__(self) -> str:\n        s = \"MaskingADMMHVA(\"\n        s += \"rho={rho}\"\n        s += \", relaxation={relaxation}\"\n\n        if self.attenuation is not None:\n            s += \", attenuation={attenuation}\"\n\n        s += \", mask_iter={mask_iter}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n\nclass HVA(MaskingPDSHVA):\n    \"\"\"Alias of MaskingPDSHVA.\"\"\"\n\n    def __repr__(self) -> str:\n        s = \"HVA(\"\n        s += \"mu1={mu1}, mu2={mu2}\"\n        s += \", relaxation={relaxation}\"\n\n        if self.attenuation is not None:\n            s += \", attenuation={attenuation}\"\n\n        s += \", mask_iter={mask_iter}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n"
  },
  {
    "path": "ssspy/bss/ica.py",
    "content": "from typing import Callable, List, Optional, Union\n\nimport numpy as np\n\nfrom ..transform import whiten\nfrom .base import IterativeMethodBase\n\n__all__ = [\"GradICA\", \"NaturalGradICA\", \"FastICA\", \"GradLaplaceICA\", \"NaturalGradLaplaceICA\"]\n\n\nclass GradICABase(IterativeMethodBase):\n    r\"\"\"Base class of independent component analysis (ICA) using the gradient descent.\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(y_{tn})`.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        score_fn (callable):\n            A score function which corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        callbacks (callable or list[callable], optional):\n            Callback functions. 
Each function is called before separation and at each iteration.\n            Default: ``None``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"GradICABase\"], None], List[Callable[[\"GradICABase\"], None]]]\n        ] = None,\n        record_loss: bool = True,\n    ) -> None:\n        super().__init__(callbacks=callbacks, record_loss=record_loss)\n\n        self.step_size = step_size\n\n        if contrast_fn is None:\n            raise ValueError(\"Specify contrast function.\")\n        else:\n            self.contrast_fn = contrast_fn\n\n        if score_fn is None:\n            raise ValueError(\"Specify score function.\")\n        else:\n            self.score_fn = score_fn\n\n        self.input = None\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a time-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                Mixture signal in time-domain.\n                The shape is (n_channels, n_samples).\n            n_iter (int):\n                Number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of separated signal in time-domain.\n            The shape is (n_sources, n_samples).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        super().__call__(n_iter=n_iter, 
initial_call=initial_call)\n\n        self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"GradICA(\"\n        s += \"step_size={step_size}\"\n        s += \", record_loss={record_loss}\"\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of ICA.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_channels, n_samples = X.shape\n        n_sources = n_channels  # n_channels == n_sources\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_samples = n_samples\n\n        if not hasattr(self, \"demix_filter\"):\n            W = np.eye(n_sources, n_channels, dtype=np.float64)\n        else:\n            if self.demix_filter is None:\n                W = None\n            else:\n                # To avoid overwriting ``demix_filter`` given by keyword arguments.\n                W = self.demix_filter.copy()\n\n        self.demix_filter = W\n        self.output = self.separate(X, demix_filter=W)\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once.\"\"\"\n        raise NotImplementedError(\"Implement 'update_once' method.\")\n\n    def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Separate ``input`` using ``demixing_filter``.\n\n        .. 
math::\n            \\boldsymbol{y}_{t}\n            = \\boldsymbol{W}\\boldsymbol{x}_{t}\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in time-domain.\n                The shape is (n_channels, n_samples).\n            demix_filter (numpy.ndarray):\n                The demixing filters to separate ``input``.\n                The shape is (n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of the separated signal in time-domain.\n            The shape is (n_sources, n_samples).\n        \"\"\"\n        output = demix_filter @ input\n\n        return output\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. math::\n            \\mathcal{L} \\\n            &= \\frac{1}{T}\\sum_{t,n}G(y_{tn}) \\\n                - \\log|\\det\\boldsymbol{W}| \\\\\n            G(y_{tn}) \\\n            &= - \\log p(y_{tn})\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)  # (n_channels, n_samples)\n        logdet = self.compute_logdet(W)\n        G = self.contrast_fn(Y)\n        loss = np.sum(np.mean(G, axis=1)) - logdet\n        loss = loss.item()\n\n        return loss\n\n    def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Compute log-determinant of demixing filter.\n\n        Args:\n            demix_filter (numpy.ndarray):\n                Demixing filter with shape of (n_sources, n_channels).\n\n        Returns:\n            Computed log-determinant value.\n            For a single square demixing matrix, this is a NumPy scalar.\n        \"\"\"\n        _, logdet = np.linalg.slogdet(demix_filter)  # scalar\n\n        return logdet\n\n\nclass FastICABase(IterativeMethodBase):\n    r\"\"\"Base class of fast independent component analysis (FastICA).\n\n    Args:\n        contrast_fn (callable):\n 
           A contrast function which corresponds to :math:`-\\log p(y_{tn})`.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        score_fn (callable):\n            A score function which corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        d_score_fn (callable):\n            A partial derivative of the score function.\n            This function is expected to return the same shape tensor as the input.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        record_loss (bool):\n            Record the loss at each of the fixed-point iteration if ``record_loss=True``.\n            Default: ``True``.\n    \"\"\"\n\n    def __init__(\n        self,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        d_score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"FastICABase\"], None], List[Callable[[\"FastICABase\"], None]]]\n        ] = None,\n        record_loss: bool = True,\n    ) -> None:\n        super().__init__(callbacks=callbacks, record_loss=record_loss)\n\n        if contrast_fn is None:\n            raise ValueError(\"Specify contrast function.\")\n        else:\n            self.contrast_fn = contrast_fn\n\n        if score_fn is None:\n            raise ValueError(\"Specify score function.\")\n        else:\n            self.score_fn = score_fn\n\n        if d_score_fn is None:\n            raise ValueError(\"Specify derivative of score function.\")\n        else:\n            self.d_score_fn = d_score_fn\n\n        self.input = None\n\n    def __call__(\n        
self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a time-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                Mixture signal in time-domain.\n                The shape is (n_channels, n_samples).\n            n_iter (int):\n                Number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in time-domain.\n            The shape is (n_sources, n_samples).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        super().__call__(n_iter=n_iter, initial_call=initial_call)\n\n        self.output = self.separate(\n            self.whitened_input, demix_filter=self.demix_filter, use_whitening=False\n        )\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"FastICA(\"\n        s += \"record_loss={record_loss}\"\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of ICA.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_channels, n_samples = X.shape\n        n_sources = n_channels  # n_channels == n_sources\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_samples = n_samples\n\n        if not hasattr(self, \"demix_filter\"):\n            W = np.eye(n_sources, n_channels, dtype=np.float64)\n        else:\n            if self.demix_filter is None:\n   
             W = None\n            else:\n                # To avoid overwriting ``demix_filter`` given by keyword arguments.\n                W = self.demix_filter.copy()\n\n        Z = whiten(X)\n\n        self.whitened_input = Z\n        self.demix_filter = W\n\n        self.output = self.separate(Z, demix_filter=W, use_whitening=False)\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once.\"\"\"\n        raise NotImplementedError(\"Implement 'update_once' method.\")\n\n    def separate(\n        self, input: np.ndarray, demix_filter: np.ndarray, use_whitening: bool = True\n    ) -> np.ndarray:\n        r\"\"\"Separate ``input`` using ``demixing_filter``.\n\n        If ``use_whitening=True``, we apply whitening to input mixture :math:`\\boldsymbol{x}_{t}`.\n\n        .. math::\n            \\boldsymbol{y}_{t}\n            &= \\boldsymbol{W}\\boldsymbol{z}_{t}, \\\\\n            \\boldsymbol{z}_{t}\n            &= \\boldsymbol{\\Lambda}^{-\\frac{1}{2}} \\\n            \\boldsymbol{\\Gamma}^{\\mathsf{T}}\\boldsymbol{x}_{t}, \\\\\n            \\boldsymbol{\\Lambda}\n            &:= \\mathrm{diag}(\\lambda_{1},\\ldots,\\lambda_{m},\\ldots,\\lambda_{M}) \\\n            \\in\\mathbb{R}^{M\\times M}, \\\\\n            \\boldsymbol{\\Gamma}\n            &:= (\\boldsymbol{\\gamma}_{1}, \\ldots,\n            \\boldsymbol{\\gamma}_{m}, \\ldots, \\boldsymbol{\\gamma}_{M}) \\\n            \\in\\mathbb{R}^{M\\times M},\n\n        where :math:`\\lambda_{m}` and :math:`\\boldsymbol{\\gamma}_{m}` are\n        an eigenvalue and eigenvector of\n        :math:`\\sum_{t}\\boldsymbol{x}_{t}\\boldsymbol{x}_{t}^{\\mathsf{T}}`,\n        respectively.\n\n        Otherwise (``use_whitening=False``), we do not apply whitening.\n\n        .. 
math::\n            \\boldsymbol{y}_{t}\n            = \\boldsymbol{W}\\boldsymbol{x}_{t}.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in time-domain.\n                The shape is (n_channels, n_samples).\n            demix_filter (numpy.ndarray):\n                The demixing filters to separate ``input``.\n                The shape is (n_sources, n_channels).\n            use_whitening (bool):\n                If ``use_whitening=True``, whitening (sphering) is applied to ``input``.\n                Default: ``True``.\n\n        Returns:\n            numpy.ndarray of the separated signal in time-domain.\n            The shape is (n_sources, n_samples).\n        \"\"\"\n        if use_whitening:\n            whitened_input = whiten(input)\n        else:\n            whitened_input = input\n\n        output = demix_filter @ whitened_input\n\n        return output\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. math::\n            \\mathcal{L} \\\n            &= \\frac{1}{T}\\sum_{t,n}G(y_{tn}) \\\\\n            G(y_{tn}) \\\n            &= - \\log p(y_{tn})\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        Z, W = self.whitened_input, self.demix_filter\n        Y = self.separate(Z, demix_filter=W, use_whitening=False)\n\n        loss = np.mean(self.contrast_fn(Y), axis=-1)\n        loss = loss.sum().item()\n\n        return loss\n\n\nclass GradICA(GradICABase):\n    r\"\"\"Independent component analysis (ICA) using the gradient descent.\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. 
Default: ``1e-1``.\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(y_{tn})`.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        score_fn (callable):\n            A score function which corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``False``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return np.abs(y)\n\n            >>> def score_fn(y):\n            ...     return np.sign(y)\n\n            >>> n_channels, n_samples = 2, 160000\n            >>> waveform_mix = np.random.randn(n_channels, n_samples)\n\n            >>> ica = GradICA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=True,\n            ... )\n            >>> waveform_est = ica(waveform_mix, n_iter=1000)\n            >>> print(waveform_mix.shape, waveform_est.shape)\n            (2, 160000), (2, 160000)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return np.abs(y)\n\n            >>> def score_fn(y):\n            ...     
return np.sign(y)\n\n            >>> n_channels, n_samples = 2, 160000\n            >>> waveform_mix = np.random.randn(n_channels, n_samples)\n\n            >>> ica = GradICA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=False,\n            ... )\n            >>> waveform_est = ica(waveform_mix, n_iter=1000)\n            >>> print(waveform_mix.shape, waveform_est.shape)\n            (2, 160000), (2, 160000)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"GradICA\"], None], List[Callable[[\"GradICA\"], None]]]\n        ] = None,\n        is_holonomic: bool = False,\n        record_loss: bool = True,\n    ) -> None:\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            callbacks=callbacks,\n            record_loss=record_loss,\n        )\n\n        self.is_holonomic = is_holonomic\n\n    def __repr__(self) -> str:\n        s = \"GradICA(\"\n        s += \"step_size={step_size}\"\n        s += \", is_holonomic={is_holonomic}\"\n        s += \", record_loss={record_loss}\"\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are updated as follows:\n\n        .. math::\n            \\boldsymbol{W}\n            \\leftarrow\\boldsymbol{W} - \\eta\\left(\\frac{1}{T}\\sum_{t} \\\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\\boldsymbol{y}_{t}^{\\mathsf{T}} \\\n            -\\boldsymbol{I}\\right)\\boldsymbol{W}^{-\\mathsf{T}},\n\n        where\n\n        .. 
math::\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\n            &= \\left(\\phi(y_{t1}),\\ldots,\\phi(y_{tN})\\right)^{\\mathsf{T}}\\in\\mathbb{R}^{N}, \\\\\n            \\phi(y_{tn})\n            &= \\frac{\\partial G(y_{tn})}{\\partial y_{tn}}, \\\\\n            G(y_{tn})\n            &= -\\log p(y_{tn}).\n\n        Otherwise (``is_holonomic=False``),\n\n        .. math::\n            \\boldsymbol{W}\n            \\leftarrow\\boldsymbol{W} - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{T}\\sum_{t} \\\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\\boldsymbol{y}_{t}^{\\mathsf{T}}\\right) \\\n            \\boldsymbol{W}^{-\\mathsf{T}}.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        Phi = self.score_fn(Y)\n        PhiY = np.mean(Phi[:, np.newaxis, :] * Y[np.newaxis, :, :], axis=-1)\n        W_inv = np.linalg.inv(W)\n        W_inv_trans = W_inv.transpose(1, 0)\n        eye = np.eye(self.n_sources)\n\n        if self.is_holonomic:\n            delta = (PhiY - eye) @ W_inv_trans\n        else:\n            delta = ((1 - eye) * PhiY) @ W_inv_trans\n\n        W = W - self.step_size * delta\n\n        Y = self.separate(X, demix_filter=W)\n\n        self.demix_filter = W\n        self.output = Y\n\n\nclass NaturalGradICA(GradICABase):\n    r\"\"\"Independent component analysis (ICA) using the natural gradient descent [#amari1995new]_.\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. 
Default: ``1e-1``.\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(y_{tn})`.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        score_fn (callable):\n            A score function which corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``False``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return np.abs(y)\n\n            >>> def score_fn(y):\n            ...     return np.sign(y)\n\n            >>> n_channels, n_samples = 2, 160000\n            >>> waveform_mix = np.random.randn(n_channels, n_samples)\n\n            >>> ica = NaturalGradICA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=True,\n            ... )\n            >>> waveform_est = ica(waveform_mix, n_iter=100)\n            >>> print(waveform_mix.shape, waveform_est.shape)\n            (2, 160000), (2, 160000)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     
return np.abs(y)\n\n            >>> def score_fn(y):\n            ...     return np.sign(y)\n\n            >>> n_channels, n_samples = 2, 160000\n            >>> waveform_mix = np.random.randn(n_channels, n_samples)\n\n            >>> ica = NaturalGradICA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=False,\n            ... )\n            >>> waveform_est = ica(waveform_mix, n_iter=100)\n            >>> print(waveform_mix.shape, waveform_est.shape)\n            (2, 160000), (2, 160000)\n\n    .. [#amari1995new] S. Amari, A. Cichocki, and H. H. Yang,\n        \"A new learning algorithm for blind signal separation,\"\n        in *Proc. NIPS.*, pp. 757-763, 1996.\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"GradICA\"], None], List[Callable[[\"GradICA\"], None]]]\n        ] = None,\n        is_holonomic: bool = False,\n        record_loss: bool = True,\n    ) -> None:\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            callbacks=callbacks,\n            record_loss=record_loss,\n        )\n\n        self.is_holonomic = is_holonomic\n\n    def __repr__(self) -> str:\n        s = \"NaturalGradICA(\"\n        s += \"step_size={step_size}\"\n        s += \", is_holonomic={is_holonomic}\"\n        s += \", record_loss={record_loss}\"\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the natural gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are updated as follows:\n\n        .. 
math::\n            \\boldsymbol{W}\n            \\leftarrow\\boldsymbol{W} - \\eta\\left(\\frac{1}{T}\\sum_{t} \\\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\\boldsymbol{y}_{t}^{\\mathsf{T}} \\\n            -\\boldsymbol{I}\\right)\\boldsymbol{W},\n\n        where\n\n        .. math::\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\n            &= \\left(\\phi(y_{t1}),\\ldots,\\phi(y_{tN})\\right)^{\\mathsf{T}}\\in\\mathbb{R}^{N}, \\\\\n            \\phi(y_{tn})\n            &= \\frac{\\partial G(y_{tn})}{\\partial y_{tn}}, \\\\\n            G(y_{tn})\n            &= -\\log p(y_{tn}).\n\n        Otherwise (``is_holonomic=False``),\n\n        .. math::\n            \\boldsymbol{W}\n            \\leftarrow\\boldsymbol{W} - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{T}\\sum_{t} \\\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\\boldsymbol{y}_{t}^{\\mathsf{T}}\\right) \\\n            \\boldsymbol{W}.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        Phi = self.score_fn(Y)\n        PhiY = np.mean(Phi[:, np.newaxis, :] * Y[np.newaxis, :, :], axis=-1)\n        eye = np.eye(self.n_sources)\n\n        if self.is_holonomic:\n            delta = (PhiY - eye) @ W\n        else:\n            delta = ((1 - eye) * PhiY) @ W\n\n        W = W - self.step_size * delta\n\n        Y = self.separate(X, demix_filter=W)\n\n        self.demix_filter = W\n        self.output = Y\n\n\nclass FastICA(FastICABase):\n    r\"\"\"Fast independent component analysis (FastICA) [#hyvarinen1999fast]_.\n\n    In FastICA, a whitening (sphering) is applied to input signal.\n\n    .. 
math::\n        \\boldsymbol{z}_{t}\n        &= \\boldsymbol{\\Lambda}^{-\\frac{1}{2}} \\\n        \\boldsymbol{\\Gamma}^{\\mathsf{T}}\\boldsymbol{x}_{t}, \\\\\n        \\boldsymbol{\\Lambda}\n        &:= \\mathrm{diag}(\\lambda_{1},\\ldots,\\lambda_{m},\\ldots,\\lambda_{M}) \\\n        \\in\\mathbb{R}^{M\\times M}, \\\\\n        \\boldsymbol{\\Gamma}\n        &:= (\\boldsymbol{\\gamma}_{1}, \\ldots,\n        \\boldsymbol{\\gamma}_{m}, \\ldots, \\boldsymbol{\\gamma}_{M}) \\\n        \\in\\mathbb{R}^{M\\times M},\n\n    where :math:`\\lambda_{m}` and :math:`\\boldsymbol{\\gamma}_{m}` are\n    an eigenvalue and eigenvector of\n    :math:`\\sum_{t}\\boldsymbol{x}_{t}\\boldsymbol{x}_{t}^{\\mathsf{T}}`,\n    respectively.\n\n    Furthermore, :math:`\\boldsymbol{W}` is constrained to be orthogonal.\n\n    .. math::\n        \\boldsymbol{W}\\boldsymbol{W}^{\\mathsf{T}}\n        = \\boldsymbol{I}\n\n    Args:\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(y_{tn})`.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        score_fn (callable):\n            A score function which corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_samples)\n            and return (n_channels, n_samples).\n        d_score_fn (callable):\n            A partial derivative of the score function.\n            This function is expected to return the same shape tensor as the input.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        record_loss (bool):\n            Record the loss at each of the fixed-point iteration if ``record_loss=True``.\n            Default: ``True``.\n\n    Examples:\n        .. 
code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return np.log(1 + np.exp(y))\n\n            >>> def score_fn(y):\n            ...     return 1 / (1 + np.exp(-y))\n\n            >>> def d_score_fn(y):\n            ...     sigmoid_y = 1 / (1 + np.exp(-y))\n            ...     return sigmoid_y * (1 - sigmoid_y)\n\n            >>> n_channels, n_samples = 2, 160000\n            >>> waveform_mix = np.random.randn(n_channels, n_samples)\n\n            >>> ica = FastICA(contrast_fn=contrast_fn, score_fn=score_fn, d_score_fn=d_score_fn)\n            >>> waveform_est = ica(waveform_mix, n_iter=10)\n            >>> print(waveform_mix.shape, waveform_est.shape)\n            (2, 160000), (2, 160000)\n\n    .. [#hyvarinen1999fast] A. Hyvärinen,\n        \"Fast and robust fixed-point algorithms for independent component analysis,\"\n        *IEEE Trans. on Neural Netw.*, vol. 10, no. 3, pp. 626-634, 1999.\n    \"\"\"\n\n    def __init__(\n        self,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        d_score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"FastICA\"], None], List[Callable[[\"FastICA\"], None]]]\n        ] = None,\n        record_loss: bool = True,\n    ) -> None:\n        super().__init__(\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            d_score_fn=d_score_fn,\n            callbacks=callbacks,\n            record_loss=record_loss,\n        )\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the fixed-point iteration algorithm.\n\n        For :math:`n=1,\\dots,N`, the demixing flter :math:`\\boldsymbol{w}_{n}`\n        is updated sequentially,\n\n        .. 
math::\n            y_{tn}\n            &=\\boldsymbol{w}_{n}^{\\mathsf{T}}\\boldsymbol{z}_{t}, \\\\\n            \\boldsymbol{w}_{n}^{+}\n            &\\leftarrow \\frac{1}{T}\\sum_{t}\\phi(y_{tn})\\boldsymbol{z}_{t} \\\n            - \\frac{1}{T}\\sum_{t}\\frac{\\partial\\phi(y_{tn})}{\\partial y_{tn}} \\\n            \\boldsymbol{w}_{n}, \\\\\n            \\boldsymbol{w}_{n}^{+}\n            &\\leftarrow\\boldsymbol{w}_{n}^{+} \\\n            - \\sum_{n'=1}^{n-1}\\left(\\boldsymbol{w}_{n'}^{\\mathsf{T}}\\boldsymbol{w}_{n}^{+}\\right) \\\n            \\boldsymbol{w}_{n'}, \\\\\n            \\boldsymbol{w}_{n}\n            &\\leftarrow \\frac{\\boldsymbol{w}_{n}^{+}}{\\|\\boldsymbol{w}_{n}^{+}\\|}.\n\n        The implementation computes the first update with the opposite sign,\n        which only flips the sign of :math:`\\boldsymbol{w}_{n}` after normalization.\n        \"\"\"\n        Z, W = self.whitened_input, self.demix_filter\n\n        for src_idx in range(self.n_sources):\n            w_n = W[src_idx]  # (n_channels,)\n            y_n = w_n @ Z  # (n_samples,)\n            Gw_n = np.mean(self.d_score_fn(y_n), axis=-1) * w_n\n            Gz = np.mean(self.score_fn(y_n) * Z, axis=-1)\n            w_n = Gw_n - Gz\n\n            if src_idx > 0:\n                W_n = W[:src_idx]  # (src_idx, n_channels)\n                scale = np.sum(W_n * w_n, axis=-1, keepdims=True)\n                w_n = w_n - np.sum(scale * W_n, axis=0)\n\n            norm = np.linalg.norm(w_n)\n            W[src_idx] = w_n / norm\n\n        Y = self.separate(Z, demix_filter=W, use_whitening=False)\n\n        self.demix_filter = W\n        self.output = Y\n\n\nclass GradLaplaceICA(GradICA):\n    r\"\"\"Independent component analysis (ICA) using the gradient descent on a Laplace distribution.\n\n    We assume :math:`y_{tn}` follows a Laplace distribution.\n\n    .. math::\n        p(y_{tn})\\propto\\exp(-|y_{tn}|)\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. 
Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``False``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent \\\n            if ``record_loss=True``.\n            Default: ``True``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_samples = 2, 160000\n            >>> waveform_mix = np.random.randn(n_channels, n_samples)\n\n            >>> ica = GradLaplaceICA(is_holonomic=True)\n            >>> waveform_est = ica(waveform_mix, n_iter=1000)\n            >>> print(waveform_mix.shape, waveform_est.shape)\n            (2, 160000), (2, 160000)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_samples = 2, 160000\n            >>> waveform_mix = np.random.randn(n_channels, n_samples)\n\n            >>> ica = GradLaplaceICA(is_holonomic=False)\n            >>> waveform_est = ica(waveform_mix, n_iter=1000)\n            >>> print(waveform_mix.shape, waveform_est.shape)\n            (2, 160000), (2, 160000)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        callbacks: Optional[\n            Union[Callable[[\"GradLaplaceICA\"], None], List[Callable[[\"GradLaplaceICA\"], None]]]\n        ] = None,\n        is_holonomic: bool = False,\n        record_loss: bool = True,\n    ) -> None:\n        def contrast_fn(input):\n            return np.abs(input)\n\n        def score_fn(input):\n            return np.sign(input)\n\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            callbacks=callbacks,\n            
is_holonomic=is_holonomic,\n            record_loss=record_loss,\n        )\n\n    def __repr__(self) -> str:\n        s = \"GradLaplaceICA(\"\n        s += \"step_size={step_size}\"\n        s += \", is_holonomic={is_holonomic}\"\n        s += \", record_loss={record_loss}\"\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are updated as follows:\n\n        .. math::\n            \\boldsymbol{W}\n            \\leftarrow\\boldsymbol{W} - \\eta\\left(\\frac{1}{T}\\sum_{t} \\\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\\boldsymbol{y}_{t}^{\\mathsf{T}} \\\n            -\\boldsymbol{I}\\right)\\boldsymbol{W}^{-\\mathsf{T}},\n\n        where\n\n        .. math::\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\n            = \\left(\\mathrm{sign}(y_{t1}),\\ldots,\\mathrm{sign}(y_{tN})\\right)^{\\mathsf{T}} \\\n            \\in\\mathbb{R}^{N}.\n\n        Otherwise (``is_holonomic=False``),\n\n        .. math::\n            \\boldsymbol{W}\n            \\leftarrow\\boldsymbol{W} - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{T}\\sum_{t} \\\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\\boldsymbol{y}_{t}^{\\mathsf{T}}\\right) \\\n            \\boldsymbol{W}^{-\\mathsf{T}}.\n        \"\"\"\n        super().update_once()\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. 
math::\n            \\mathcal{L} \\\n            &= \\frac{1}{T}\\sum_{t,n}|y_{tn}| \\\n                - \\log|\\det\\boldsymbol{W}| \\\\\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        return super().compute_loss()\n\n\nclass NaturalGradLaplaceICA(NaturalGradICA):\n    r\"\"\"Independent component analysis (ICA) using the natural gradient descent \\\n    on a Laplace distribution.\n\n    We assume :math:`y_{tn}` follows a Laplace distribution.\n\n    .. math::\n        p(y_{tn})\\propto\\exp(-|y_{tn}|)\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``False``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent \\\n            if ``record_loss=True``.\n            Default: ``True``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_samples = 2, 160000\n            >>> waveform_mix = np.random.randn(n_channels, n_samples)\n\n            >>> ica = NaturalGradLaplaceICA(is_holonomic=True)\n            >>> waveform_est = ica(waveform_mix, n_iter=100)\n            >>> print(waveform_mix.shape, waveform_est.shape)\n            (2, 160000), (2, 160000)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. 
code-block:: python\n\n            >>> n_channels, n_samples = 2, 160000\n            >>> waveform_mix = np.random.randn(n_channels, n_samples)\n\n            >>> ica = NaturalGradLaplaceICA(is_holonomic=False)\n            >>> waveform_est = ica(waveform_mix, n_iter=100)\n            >>> print(waveform_mix.shape, waveform_est.shape)\n            (2, 160000), (2, 160000)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        callbacks: Optional[\n            Union[\n                Callable[[\"NaturalGradLaplaceICA\"], None],\n                List[Callable[[\"NaturalGradLaplaceICA\"], None]],\n            ]\n        ] = None,\n        is_holonomic: bool = False,\n        record_loss: bool = True,\n    ) -> None:\n        def contrast_fn(input):\n            return np.abs(input)\n\n        def score_fn(input):\n            return np.sign(input)\n\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            callbacks=callbacks,\n            is_holonomic=is_holonomic,\n            record_loss=record_loss,\n        )\n\n    def __repr__(self) -> str:\n        s = \"NaturalGradLaplaceICA(\"\n        s += \"step_size={step_size}\"\n        s += \", is_holonomic={is_holonomic}\"\n        s += \", record_loss={record_loss}\"\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the natural gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are updated as follows:\n\n        .. math::\n            \\boldsymbol{W}\n            \\leftarrow\\boldsymbol{W} - \\eta\\left(\\frac{1}{T}\\sum_{t} \\\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\\boldsymbol{y}_{t}^{\\mathsf{T}} \\\n            -\\boldsymbol{I}\\right)\\boldsymbol{W},\n\n        where\n\n        .. 
math::\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\n            = \\left(\\mathrm{sign}(y_{t1}),\\ldots,\\mathrm{sign}(y_{tN})\\right)^{\\mathsf{T}} \\\n            \\in\\mathbb{R}^{N}.\n\n        Otherwise (``is_holonomic=False``),\n\n        .. math::\n            \\boldsymbol{W}\n            \\leftarrow\\boldsymbol{W} - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{T}\\sum_{t} \\\n            \\boldsymbol{\\phi}(\\boldsymbol{y}_{t})\\boldsymbol{y}_{t}^{\\mathsf{T}}\\right) \\\n            \\boldsymbol{W}.\n        \"\"\"\n        super().update_once()\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. math::\n            \\mathcal{L} \\\n            &= \\frac{1}{T}\\sum_{t,n}|y_{tn}| \\\n                - \\log|\\det\\boldsymbol{W}| \\\\\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        return super().compute_loss()\n"
  },
  {
    "path": "ssspy/bss/ilrma.py",
    "content": "import functools\nimport warnings\nfrom typing import Callable, Iterable, List, Optional, Tuple, Union\n\nimport numpy as np\n\nfrom ..algorithm import (\n    MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS,\n    PROJECTION_BACK_KEYWORDS,\n    minimal_distortion_principle,\n    projection_back,\n)\nfrom ..special.flooring import identity, max_flooring\nfrom ..utils.flooring import choose_flooring_fn\nfrom ..utils.select_pair import sequential_pair_selector\nfrom ._update_spatial_model import (\n    update_by_ip1,\n    update_by_ip2,\n    update_by_ipa,\n    update_by_iss1,\n    update_by_iss2,\n)\nfrom .base import IterativeMethodBase\n\n__all__ = [\"GaussILRMA\", \"TILRMA\", \"GGDILRMA\"]\n\nspatial_algorithms = [\"IP\", \"IP1\", \"IP2\", \"ISS\", \"ISS1\", \"ISS2\", \"IPA\"]\nsource_algorithms = [\"MM\", \"ME\"]\nEPS = 1e-10\n\n\nclass ILRMABase(IterativeMethodBase):\n    r\"\"\"Base class of independent low-rank matrix analysis (ILRMA).\n\n    Args:\n        n_basis (int):\n            Number of NMF bases.\n        partitioning (bool):\n            Whether to use partioning function. Default: ``False``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. 
You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n        rng (numpy.random.Generator, optioinal):\n            Random number generator. This is mainly used to randomly initialize NMF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        partitioning: bool = False,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"ILRMABase\"], None], List[Callable[[\"ILRMABase\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(callbacks=callbacks, record_loss=record_loss)\n\n        self.n_basis = n_basis\n        self.partitioning = partitioning\n\n        if flooring_fn is None:\n            self.flooring_fn = identity\n        else:\n            self.flooring_fn = flooring_fn\n\n        self.input = None\n        self.scale_restoration = scale_restoration\n\n        if reference_id is None and scale_restoration:\n            raise ValueError(\"Specify 'reference_id' if scale_restoration=True.\")\n        else:\n            self.reference_id = reference_id\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        self.rng = rng\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a 
frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(flooring_fn=self.flooring_fn, **kwargs)\n\n        super().__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"ILRMA(\"\n        s += \"n_basis={n_basis}\"\n        s += \", partitioning={partitioning}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        **kwargs,\n    ) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        We also set variance of Gaussian distribution.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                
the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            kwargs:\n                Keyword arguments to set as attributes of ILRMA.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_channels, n_bins, n_frames = X.shape\n        n_sources = n_channels  # n_channels == n_sources\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_bins, self.n_frames = n_bins, n_frames\n\n        if not hasattr(self, \"demix_filter\"):\n            W = np.eye(n_sources, n_channels, dtype=np.complex128)\n            W = np.tile(W, reps=(n_bins, 1, 1))\n        else:\n            if self.demix_filter is None:\n                W = None\n            else:\n                # To avoid overwriting ``demix_filter`` given by keyword arguments.\n                W = self.demix_filter.copy()\n\n        self.demix_filter = W\n        self.output = self.separate(X, demix_filter=W)\n\n        self._init_nmf(flooring_fn=flooring_fn, rng=self.rng)\n\n    def _init_nmf(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        r\"\"\"Initialize NMF.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            rng 
(numpy.random.Generator, optional):\n                Random number generator. If ``None`` is given,\n                ``np.random.default_rng()`` is used.\n                Default: ``None``.\n        \"\"\"\n        n_basis = self.n_basis\n        n_sources = self.n_sources\n        n_bins, n_frames = self.n_bins, self.n_frames\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        if self.partitioning:\n            if not hasattr(self, \"latent\"):\n                Z = rng.random((n_sources, n_basis))\n                Z = Z / Z.sum(axis=0)\n                Z = flooring_fn(Z)\n            else:\n                # To avoid overwriting.\n                Z = self.latent.copy()\n\n            if not hasattr(self, \"basis\"):\n                T = rng.random((n_bins, n_basis))\n                T = flooring_fn(T)\n            else:\n                # To avoid overwriting.\n                T = self.basis.copy()\n\n            if not hasattr(self, \"activation\"):\n                V = rng.random((n_basis, n_frames))\n                V = flooring_fn(V)\n            else:\n                # To avoid overwriting.\n                V = self.activation.copy()\n\n            self.latent = Z\n            self.basis, self.activation = T, V\n        else:\n            if not hasattr(self, \"basis\"):\n                T = rng.random((n_sources, n_bins, n_basis))\n                T = flooring_fn(T)\n            else:\n                # To avoid overwriting.\n                T = self.basis.copy()\n\n            if not hasattr(self, \"activation\"):\n                V = rng.random((n_sources, n_basis, n_frames))\n                V = flooring_fn(V)\n            else:\n                # To avoid overwriting.\n                V = self.activation.copy()\n\n            self.basis, self.activation = T, V\n\n    def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray:\n        
r\"\"\"Separate ``input`` using ``demix_filter``.\n\n        .. math::\n            \\boldsymbol{y}_{ij}\n            = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij}\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            demix_filter (numpy.ndarray):\n                The demixing filters to separate ``input``.\n                The shape is (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        X, W = input, demix_filter\n        Y = W @ X.transpose(1, 0, 2)\n        output = Y.transpose(1, 0, 2)\n\n        return output\n\n    def reconstruct_nmf(\n        self, basis: np.ndarray, activation: np.ndarray, latent: Optional[np.ndarray] = None\n    ) -> np.ndarray:\n        r\"\"\"Reconstruct NMF.\n\n        Args:\n            basis (numpy.ndarray):\n                Basis matrix.\n                The shape is (n_bins, n_basis) if latent is given.\n                Otherwise, (n_sources, n_bins, n_basis).\n            activation (numpy.ndarray):\n                Activation matrix.\n                The shape is (n_basis, n_frames) if latent is given.\n                Otherwise, (n_sources, n_basis, n_frames).\n            latent (numpy.ndarray, optional):\n                Latent variable that determines number of bases per source.\n                The shape is (n_sources, n_basis).\n\n        Returns:\n            numpy.ndarray of the reconstructed NMF.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        if latent is None:\n            T, V = basis, activation\n            R = T @ V\n        else:\n            Z = latent\n            T, V = basis, activation\n            TV = T[:, :, np.newaxis] * V[np.newaxis, :, :]\n            R = np.sum(Z[:, np.newaxis, :, np.newaxis] * TV[np.newaxis, :, :, :], axis=2)\n\n        
return R\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once.\"\"\"\n        raise NotImplementedError(\"Implement 'update_once' method.\")\n\n    def normalize(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Normalize demixing filters and NMF parameters.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        normalization = self.normalization\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        assert normalization, \"Set normalization.\"\n\n        if type(normalization) is bool:\n            # when normalization is True\n            normalization = \"power\"\n\n        if normalization == \"power\":\n            self.normalize_by_power(flooring_fn=flooring_fn)\n        elif normalization == \"projection_back\":\n            self.normalize_by_projection_back()\n        else:\n            raise NotImplementedError(\"Normalization {} is not implemented.\".format(normalization))\n\n    def normalize_by_power(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Normalize demixing filters and NMF parameters by power.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                
the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Demixing filters are normalized by\n\n        .. math::\n            \\boldsymbol{w}_{in}\n            \\leftarrow\\frac{\\boldsymbol{w}_{in}}{\\psi_{n}},\n\n        where\n\n        .. math::\n            \\psi_{n}\n            = \\sqrt{\\frac{1}{IJ}\\sum_{i,j}|\\boldsymbol{w}_{in}^{\\mathsf{H}}\n            \\boldsymbol{x}_{ij}|^{2}}.\n\n        If ``self.demix_filter`` is ``None``, ``self.output`` is divided by\n        :math:`\\psi_{n}` instead of the demixing filters.\n\n        For NMF parameters,\n\n        .. math::\n            t_{ik}\n            &\\leftarrow t_{ik}\\sum_{n}\\frac{z_{nk}}{\\psi_{n}^{p}}, \\\\\n            z_{nk}\n            &\\leftarrow \\frac{\\frac{z_{nk}}{\\psi_{n}^{p}}}\n            {\\sum_{n'}\\frac{z_{n'k}}{\\psi_{n'}^{p}}},\n\n        if ``self.partitioning=True``. Otherwise,\n\n        .. math::\n            t_{ikn}\n            \\leftarrow\\frac{t_{ikn}}{\\psi_{n}^{p}}.\n        \"\"\"\n        p = self.domain\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.mean(np.abs(Y) ** 2, axis=(-2, -1))\n        psi = np.sqrt(Y2)\n        psi = flooring_fn(psi)\n\n        if self.partitioning:\n            Z, T = self.latent, self.basis\n\n            Z_psi = Z / (psi[:, np.newaxis] ** p)\n            scale = np.sum(Z_psi, axis=0)\n            T = T * scale[np.newaxis, :]\n            Z = Z_psi / scale\n\n            self.latent, self.basis = Z, T\n        else:\n            T = self.basis\n\n            T = T / (psi[:, np.newaxis, np.newaxis] ** p)\n\n            self.basis = T\n\n        if self.demix_filter is None:\n            Y = Y / psi[:, np.newaxis, np.newaxis]\n            self.output = Y\n        else:\n            W = self.demix_filter\n            W = W / psi[np.newaxis, :, 
np.newaxis]\n            self.demix_filter = W\n\n    def normalize_by_projection_back(self) -> None:\n        r\"\"\"Normalize demixing filters and NMF parameters by projection back.\n\n        Demixing filters are normalized by\n\n        .. math::\n            \\boldsymbol{w}_{in}\n            \\leftarrow\\psi_{in}\\boldsymbol{w}_{in},\n\n        where\n\n        .. math::\n            \\boldsymbol{\\psi}_{i}\n            = (\\psi_{i1},\\ldots,\\psi_{iN})^{\\mathsf{T}}\n            = \\left(\\boldsymbol{W}_{i}^{-1}\\right)^{\\mathsf{T}}\\boldsymbol{e}_{m_{\\mathrm{ref}}}.\n\n        If ``self.demix_filter`` is ``None``, :math:`\\boldsymbol{W}_{i}^{-1}` is replaced by a\n        least-squares estimate of the mixing matrix from the input and ``self.output``,\n        and ``self.output`` is rescaled instead of the demixing filters.\n\n        For NMF parameters,\n\n        .. math::\n            t_{ikn}\n            \\leftarrow|\\psi_{in}|^{p}t_{ikn}.\n        \"\"\"\n        p = self.domain\n        reference_id = self.reference_id\n\n        X = self.input\n\n        if reference_id is None:\n            warnings.warn(\n                \"channel 0 is used for reference_id \\\n                    of projection-back-based normalization.\",\n                UserWarning,\n            )\n            reference_id = 0\n\n        if self.partitioning:\n            raise NotImplementedError(\n                \"Projection-back-based normalization is not applicable with partitioning function.\"\n            )\n        else:\n            T = self.basis\n\n            if self.demix_filter is None:\n                Y = self.output\n\n                Y = Y.transpose(1, 0, 2)  # (n_bins, n_sources, n_frames)\n                X = X.transpose(1, 0, 2)  # (n_bins, n_channels, n_frames)\n                Y_Hermite = Y.transpose(0, 2, 1).conj()  # (n_bins, n_frames, n_sources)\n                XY_Hermite = X @ Y_Hermite  # (n_bins, n_channels, n_sources)\n                YY_Hermite = Y @ Y_Hermite  # (n_bins, n_sources, n_sources)\n                scale = XY_Hermite @ np.linalg.inv(YY_Hermite)  # (n_bins, n_channels, n_sources)\n                scale = scale[..., reference_id, :]  # (n_bins, n_sources)\n                Y_scaled = Y * scale[..., np.newaxis]  # (n_bins, n_sources, n_frames)\n                Y = 
Y_scaled.swapaxes(-3, -2)  # (n_sources, n_bins, n_frames)\n\n                self.output = Y\n            else:\n                W = self.demix_filter\n\n                scale = np.linalg.inv(W)\n                scale = scale[:, reference_id, :]\n                W = W * scale[:, :, np.newaxis]\n\n                self.demix_filter = W\n\n            scale = scale.transpose(1, 0)\n            scale = np.abs(scale) ** p\n            T = T * scale[:, :, np.newaxis]\n\n            self.basis = T\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        raise NotImplementedError(\"Implement 'compute_loss' method.\")\n\n    def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Compute log-determinant of demixing filter\n\n        Args:\n            demix_filter (numpy.ndarray):\n                Demixing filters with shape of (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of computed log-determinant values.\n        \"\"\"\n        _, logdet = np.linalg.slogdet(demix_filter)  # (n_bins,)\n\n        return logdet\n\n    def restore_scale(self) -> None:\n        r\"\"\"Restore scale ambiguity.\n\n        If ``self.scale_restoration=\"projection_back``, we use projection back technique.\n        \"\"\"\n        scale_restoration = self.scale_restoration\n\n        assert scale_restoration, \"Set self.scale_restoration=True.\"\n\n        if type(scale_restoration) is bool:\n            scale_restoration = PROJECTION_BACK_KEYWORDS[0]\n\n        if scale_restoration in PROJECTION_BACK_KEYWORDS:\n            self.apply_projection_back()\n        elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS:\n            self.apply_minimal_distortion_principle()\n        else:\n            raise ValueError(\"{} is not supported for scale restoration.\".format(scale_restoration))\n\n    def 
apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        W_scaled = projection_back(W, reference_id=self.reference_id)\n        Y_scaled = self.separate(X, demix_filter=W_scaled)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n    def apply_minimal_distortion_principle(self) -> None:\n        r\"\"\"Apply minimal distortion principle to estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n        Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id)\n        X = X.transpose(1, 0, 2)\n        Y = Y_scaled.transpose(1, 0, 2)\n        X_Hermite = X.transpose(0, 2, 1).conj()\n        W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n\nclass GaussILRMA(ILRMABase):\n    r\"\"\"Independent low-rank matrix analysis (ILRMA) [#kitamura2016determined]_ \\\n    on Gaussian distribution.\n\n    We assume :math:`y_{ijn}` follows a Gaussian distribution.\n\n    .. math::\n        p(y_{ijn})\n        = \\frac{1}{\\pi r_{ijn}}\\exp\\left(-\\frac{|y_{ijn}|^{2}}{r_{ijn}}\\right),\n\n    where\n\n    .. math::\n        r_{ijn}\n        = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}},\n\n    if ``partitioning=True``. Otherwise,\n\n    .. 
math::\n        r_{ijn}\n        = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n\n    Args:\n        n_basis (int):\n            Number of NMF bases.\n        spatial_algorithm (str):\n            Algorithm for demixing filter updates.\n            Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, ``ISS2``, or ``IPA``.\n            Default: ``IP``.\n        source_algorithm (str):\n            Algorithm for source model updates.\n            Choose ``MM`` or ``ME``. Default: ``MM``.\n        domain (float):\n            Domain parameter :math:`p`, which should satisfy :math:`0 < p \\leq 2`.\n            It must be ``2`` if ``source_algorithm=\"ME\"``. Default: ``2``.\n        partitioning (bool):\n            Whether to use partitioning function. Default: ``False``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        pair_selector (callable, optional):\n            Selector to choose updating pair in ``IP2`` and ``ISS2``.\n            If ``None`` is given, ``sequential_pair_selector`` is used.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        normalization (bool or str, optional):\n            Normalization of demixing filters and NMF parameters.\n            Choose ``power`` or ``projection_back``.\n            Default: ``power``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n        rng (numpy.random.Generator, optional):\n            Random number generator. This is mainly used to randomly initialize NMF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n\n    Examples:\n        Update demixing filters by IP:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = GaussILRMA(\n            ...     n_basis=2,\n            ...     spatial_algorithm=\"IP\",\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IP2:\n\n        .. code-block:: python\n\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = GaussILRMA(\n            ...     n_basis=2,\n            ...     spatial_algorithm=\"IP2\",\n            ...     pair_selector=sequential_pair_selector,\n            ...     rng=np.random.default_rng(42),\n            ... 
)\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = GaussILRMA(\n            ...     n_basis=2,\n            ...     spatial_algorithm=\"ISS\",\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS2:\n\n        .. code-block:: python\n\n            >>> import functools\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = GaussILRMA(\n            ...     n_basis=2,\n            ...     spatial_algorithm=\"ISS2\",\n            ...     pair_selector=functools.partial(sequential_pair_selector, step=2),\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IPA:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = GaussILRMA(\n            ...     n_basis=2,\n            ...     spatial_algorithm=\"IPA\",\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n    .. [#kitamura2016determined] D. Kitamura, N. Ono, H. Sawada, H. Kameoka, and H. Saruwatari, \\\n        \"Determined blind source separation unifying independent vector analysis and \\\n        nonnegative matrix factorization,\" \\\n        *IEEE/ACM Trans. ASLP*, vol. 24, no. 9, pp. 1626-1641, 2016.\n    \"\"\"\n\n    _ipa_default_kwargs = {\"lqpqm_normalization\": True, \"newton_iter\": 1}\n    _default_kwargs = _ipa_default_kwargs\n\n    def __init__(\n        self,\n        n_basis: int,\n        spatial_algorithm: str = \"IP\",\n        source_algorithm: str = \"MM\",\n        domain: float = 2,\n        partitioning: bool = False,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n        callbacks: Optional[\n            Union[Callable[[\"GaussILRMA\"], None], List[Callable[[\"GaussILRMA\"], None]]]\n        ] = None,\n        normalization: Optional[Union[bool, str]] = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            n_basis=n_basis,\n            partitioning=partitioning,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            
reference_id=reference_id,\n            rng=rng,\n        )\n\n        assert spatial_algorithm in spatial_algorithms, \"Not support {}.\".format(spatial_algorithm)\n        assert source_algorithm in source_algorithms, \"Not support {}.\".format(source_algorithm)\n        assert 0 < domain <= 2, \"domain parameter should be chosen from [0, 2].\"\n\n        if source_algorithm == \"ME\":\n            assert domain == 2, \"domain parameter should be 2 when you specify ME algorithm.\"\n\n        self.spatial_algorithm = spatial_algorithm\n        self.source_algorithm = source_algorithm\n        self.domain = domain\n        self.normalization = normalization\n\n        if pair_selector is None:\n            if spatial_algorithm in [\"IP2\", \"ISS2\"]:\n                self.pair_selector = sequential_pair_selector\n        else:\n            self.pair_selector = pair_selector\n\n        if spatial_algorithm == \"IPA\":\n            valid_keys = set(self.__class__._ipa_default_kwargs.keys())\n        else:\n            valid_keys = set()\n\n        invalid_keys = set(kwargs) - valid_keys\n\n        assert invalid_keys == set(), \"Invalid keywords {} are given.\".format(invalid_keys)\n\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n\n        # set default values if necessary\n        for key in valid_keys:\n            if not hasattr(self, key):\n                value = self.__class__._default_kwargs[key]\n                setattr(self, key, value)\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: 
``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(flooring_fn=self.flooring_fn, **kwargs)\n\n        # Call __call__ of ILRMABase's parent, i.e. __call__ of IterativeMethodBase\n        super(ILRMABase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        if self.demix_filter is None:\n            pass\n        else:\n            self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"GaussILRMA(\"\n        s += \"n_basis={n_basis}\"\n        s += \", spatial_algorithm={spatial_algorithm}\"\n        s += \", source_algorithm={source_algorithm}\"\n        s += \", domain={domain}\"\n        s += \", partitioning={partitioning}\"\n        s += \", normalization={normalization}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        **kwargs,\n    ) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the 
identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            kwargs:\n                Keyword arguments to set as attributes of ILRMA.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        super()._reset(flooring_fn=flooring_fn, **kwargs)\n\n        if self.spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\", \"IPA\"]:\n            self.demix_filter = None\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF parameters and demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_source_model(flooring_fn=flooring_fn)\n        self.update_spatial_model(flooring_fn=flooring_fn)\n\n        if self.normalization:\n            self.normalize(flooring_fn=flooring_fn)\n\n    def update_source_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases, activations, and latent variables.\n\n        - If ``source_algorithm`` is ``MM``, ``update_source_model_mm`` is called.\n        - If ``source_algorithm`` is ``ME``, ``update_source_model_me`` is called.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical 
stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.source_algorithm == \"MM\":\n            self.update_source_model_mm(flooring_fn=flooring_fn)\n        elif self.source_algorithm == \"ME\":\n            self.update_source_model_me(flooring_fn=flooring_fn)\n        else:\n            raise ValueError(\n                \"{}-algorithm-based source model updates are not supported.\".format(\n                    self.source_algorithm\n                )\n            )\n\n    def update_source_model_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases, activations, and latent variables by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.partitioning:\n            self.update_latent_mm()\n\n        self.update_basis_mm(flooring_fn=flooring_fn)\n        self.update_activation_mm(flooring_fn=flooring_fn)\n\n    def update_source_model_me(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n    
    r\"\"\"Update NMF bases, activations, and latent variables by ME algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.domain != 2:\n            raise ValueError(\"Domain parameter is expected 2, but given {}.\".format(self.domain))\n\n        if self.partitioning:\n            self.update_latent_me()\n\n        self.update_basis_me(flooring_fn=flooring_fn)\n        self.update_activation_me(flooring_fn=flooring_fn)\n\n    def update_latent_mm(self) -> None:\n        r\"\"\"Update latent variables in NMF by MM algorithm.\n\n        Update :math:`z_{nk}` as follows:\n\n        .. 
math::\n            z_{nk}\n            &\\leftarrow\\left[\\frac{\\displaystyle\\sum_{i,j}\\frac{t_{ik}v_{kj}}\n            {(\\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\\frac{p+2}{p}}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{i,j}\\dfrac{t_{ik}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]^{\\frac{p}{p+2}}z_{nk} \\\\\n            z_{nk}\n            &\\leftarrow\\frac{z_{nk}}{\\sum_{n'}z_{n'k}}.\n        \"\"\"\n        p = self.domain\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n        p2_p = (p + 2) / p\n        p_p2 = p / (p + 2)\n\n        Z = self.latent\n        T, V = self.basis, self.activation\n\n        TV = T[:, :, np.newaxis] * V[np.newaxis, :, :]\n        ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n        ZTVp2p = ZTV**p2_p\n        TV_ZTVp2p = TV[np.newaxis, :, :, :] / ZTVp2p[:, :, np.newaxis, :]\n        num = np.sum(TV_ZTVp2p * Y2[:, :, np.newaxis, :], axis=(1, 3))\n\n        TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :]\n        denom = np.sum(TV_ZTV, axis=(1, 3))\n\n        Z = ((num / denom) ** p_p2) * Z\n        Z = Z / Z.sum(axis=0)\n\n        self.latent = Z\n\n    def update_basis_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`t_{ikn}` as follows:\n\n 
       .. math::\n            t_{ik}\n            \\leftarrow\\left[\n            \\frac{\\displaystyle\\sum_{j,n}\\frac{z_{nk}v_{kj}}\n            {(\\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\\frac{p+2}{p}}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{j,n}\n            \\dfrac{z_{nk}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]^{\\frac{p}{p+2}}t_{ik},\n\n        if ``partitioning=True``. Otherwise\n\n        .. math::\n            t_{ikn}\n            \\leftarrow \\left[\\frac{\\displaystyle\\sum_{j}\n            \\dfrac{v_{kjn}}{(\\sum_{k'}t_{ik'n}v_{k'jn})^{\\frac{p+2}{p}}}|y_{ijn}|^{2}}\n            {\\displaystyle\\sum_{j}\\frac{v_{kjn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\\right]\n            ^{\\frac{p}{p+2}}t_{ikn}.\n        \"\"\"\n        p = self.domain\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n        p2_p = (p + 2) / p\n        p_p2 = p / (p + 2)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            ZTVp2p = ZTV**p2_p\n            ZV_ZTVp2p = ZV[:, np.newaxis, :, :] / ZTVp2p[:, :, np.newaxis, :]\n            num = np.sum(ZV_ZTVp2p * Y2[:, :, np.newaxis, :], axis=(0, 3))\n\n            ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZV_ZTV, axis=(0, 3))\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            TVp2p = TV**p2_p\n            V_TVp2p = V[:, np.newaxis, :, :] / TVp2p[:, :, np.newaxis, :]\n            num = np.sum(V_TVp2p * Y2[:, :, np.newaxis, :], axis=3)\n\n            V_TV = V[:, np.newaxis, :, :] / 
TV[:, :, np.newaxis, :]\n            denom = np.sum(V_TV, axis=3)\n\n        T = ((num / denom) ** p_p2) * T\n        T = flooring_fn(T)\n\n        self.basis = T\n\n    def update_activation_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF activations by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`v_{kjn}` as follows:\n\n        .. math::\n            v_{kj}\n            \\leftarrow\\left[\\frac{\\displaystyle\\sum_{i,n}\\frac{z_{nk}t_{ik}}\n            {(\\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\\frac{p+2}{p}}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{i,n}\\dfrac{z_{nk}t_{ik}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]^{\\frac{p}{p+2}}v_{kj},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            v_{kjn}\n            \\leftarrow \\left[\\frac{\\displaystyle\\sum_{i}\n            \\dfrac{t_{ikn}}{(\\sum_{k'}t_{ik'n}v_{k'jn})^{\\frac{p+2}{p}}}|y_{ijn}|^{2}}\n            {\\displaystyle\\sum_{i}\\frac{t_{ikn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\n            \\right]^{\\frac{p}{p+2}}v_{kjn}.\n        \"\"\"\n        p = self.domain\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n        p2_p = (p + 2) / p\n        p_p2 = p / (p + 2)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            ZTVp2p = ZTV**p2_p\n            ZT_ZTVp2p = ZT[:, :, :, np.newaxis] / ZTVp2p[:, :, np.newaxis, :]\n            num = np.sum(ZT_ZTVp2p * Y2[:, :, np.newaxis, :], axis=(0, 1))\n\n            ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZT_ZTV, axis=(0, 1))\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            TVp2p = TV**p2_p\n            T_TVp2p = T[:, :, :, np.newaxis] / TVp2p[:, :, np.newaxis, :]\n            num = np.sum(T_TVp2p * Y2[:, :, np.newaxis, :], axis=1)\n\n            T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :]\n            denom = np.sum(T_TV, axis=1)\n\n        V = ((num / denom) ** p_p2) * V\n        V = flooring_fn(V)\n\n        self.activation = V\n\n    def update_latent_me(self) -> None:\n        r\"\"\"Update latent variables in NMF by ME algorithm.\n\n        Update :math:`z_{nk}` as follows:\n\n        .. 
math::\n            z_{nk}\n            &\\leftarrow\\left[\\frac{\\displaystyle\\sum_{i,j}\\frac{t_{ik}v_{kj}}\n            {(\\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{2}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{i,j}\\dfrac{t_{ik}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]z_{nk} \\\\\n            z_{nk}\n            &\\leftarrow\\frac{z_{nk}}{\\sum_{n'}z_{n'k}}.\n        \"\"\"\n        if self.domain != 2:\n            raise ValueError(\"Domain parameter is expected 2, but given {}.\".format(self.domain))\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n\n        Z = self.latent\n        T, V = self.basis, self.activation\n\n        TV = T[:, :, np.newaxis] * V[np.newaxis, :, :]\n        ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n        ZTV2 = ZTV**2\n        TV_ZTV2 = TV[np.newaxis, :, :, :] / ZTV2[:, :, np.newaxis, :]\n        num = np.sum(TV_ZTV2 * Y2[:, :, np.newaxis, :], axis=(1, 3))\n\n        TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :]\n        denom = np.sum(TV_ZTV, axis=(1, 3))\n\n        Z = (num / denom) * Z\n        Z = Z / Z.sum(axis=0)\n\n        self.latent = Z\n\n    def update_basis_me(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases by ME algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`t_{ikn}` 
as follows:\n\n        .. math::\n            t_{ik}\n            \\leftarrow\\left[\n            \\frac{\\displaystyle\\sum_{j,n}\\frac{z_{nk}v_{kj}}\n            {(\\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{2}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{j,n}\n            \\dfrac{z_{nk}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]t_{ik},\n\n        if ``partitioning=True``. Otherwise\n\n        .. math::\n            t_{ikn}\n            \\leftarrow\\left[\\frac{\\displaystyle\\sum_{j}\n            \\dfrac{v_{kjn}}{(\\sum_{k'}t_{ik'n}v_{k'jn})^{2}}|y_{ijn}|^{2}}\n            {\\displaystyle\\sum_{j}\\frac{v_{kjn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\\right]\n            t_{ikn}.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.domain != 2:\n            raise ValueError(\"Domain parameter is expected 2, but given {}.\".format(self.domain))\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            ZTV2 = ZTV**2\n            ZV_ZTV2 = ZV[:, np.newaxis, :, :] / ZTV2[:, :, np.newaxis, :]\n            num = np.sum(ZV_ZTV2 * Y2[:, :, np.newaxis, :], axis=(0, 3))\n\n            ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZV_ZTV, axis=(0, 3))\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            TV2 = TV**2\n            V_TV2 = V[:, np.newaxis, :, :] / TV2[:, :, np.newaxis, :]\n            num = np.sum(V_TV2 * Y2[:, :, np.newaxis, :], axis=3)\n\n            V_TV = V[:, np.newaxis, :, :] / TV[:, :, 
np.newaxis, :]\n            denom = np.sum(V_TV, axis=3)\n\n        T = (num / denom) * T\n        T = flooring_fn(T)\n\n        self.basis = T\n\n    def update_activation_me(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF activations by ME algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`v_{kjn}` as follows:\n\n        .. math::\n            v_{kj}\n            \\leftarrow\\left[\\frac{\\displaystyle\\sum_{i,n}\\frac{z_{nk}t_{ik}}\n            {(\\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{2}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{i,n}\\dfrac{z_{nk}t_{ik}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]v_{kj},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            v_{kjn}\n            \\leftarrow \\left[\\frac{\\displaystyle\\sum_{i}\n            \\dfrac{t_{ikn}}{(\\sum_{k'}t_{ik'n}v_{k'jn})^{2}}|y_{ijn}|^{2}}\n            {\\displaystyle\\sum_{i}\\frac{t_{ikn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\n            \\right]v_{kjn}.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.domain != 2:\n            raise ValueError(\"Domain parameter is expected 2, but given {}.\".format(self.domain))\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            ZTV2 = ZTV**2\n            ZT_ZTV2 = ZT[:, :, :, np.newaxis] / ZTV2[:, :, np.newaxis, :]\n            num = np.sum(ZT_ZTV2 * Y2[:, :, np.newaxis, :], axis=(0, 1))\n\n            ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZT_ZTV, axis=(0, 1))\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            TV2 = TV**2\n            T_TV2 = T[:, :, :, np.newaxis] / TV2[:, :, np.newaxis, :]\n            num = np.sum(T_TV2 * Y2[:, :, np.newaxis, :], axis=1)\n\n            T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :]\n            denom = np.sum(T_TV, axis=1)\n\n        V = (num / denom) * V\n        V = flooring_fn(V)\n\n        self.activation = V\n\n    def update_spatial_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once.\n\n        - If ``spatial_algorithm`` is ``IP`` or ``IP1``, 
``update_spatial_model_ip1`` is called.\n        - If ``spatial_algorithm`` is ``ISS`` or ``ISS1``, ``update_spatial_model_iss1`` is called.\n        - If ``spatial_algorithm`` is ``IP2``, ``update_spatial_model_ip2`` is called.\n        - If ``spatial_algorithm`` is ``ISS2``, ``update_spatial_model_iss2`` is called.\n        - If ``spatial_algorithm`` is ``IPA``, ``update_spatial_model_ipa`` is called.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.spatial_algorithm in [\"IP\", \"IP1\"]:\n            self.update_spatial_model_ip1(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"IP2\"]:\n            self.update_spatial_model_ip2(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"ISS\", \"ISS1\"]:\n            self.update_spatial_model_iss1(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"ISS2\"]:\n            self.update_spatial_model_iss2(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"IPA\"]:\n            self.update_spatial_model_ipa(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.spatial_algorithm))\n\n    def update_spatial_model_ip1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A 
flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Demixing filters are updated sequentially for :math:`n=1,\\ldots,N` as follows:\n\n        .. math::\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\left(\\boldsymbol{W}_{in}^{\\mathsf{H}}\\boldsymbol{U}_{in}\\right)^{-1} \\\n            \\boldsymbol{e}_{n}, \\\\\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\frac{\\boldsymbol{w}_{in}}\n            {\\sqrt{\\boldsymbol{w}_{in}^{\\mathsf{H}}\\boldsymbol{U}_{in}\\boldsymbol{w}_{in}}},\n\n        where\n\n        .. math::\n            \\boldsymbol{U}_{in}\n            = \\frac{1}{J}\\sum_{j}\n            \\frac{1}{\\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}}\n            \\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}\n\n        if ``partitioning=True``, otherwise\n\n        .. 
math::\n            \\boldsymbol{U}_{in}\n            = \\frac{1}{J}\\sum_{j}\n            \\frac{1}{\\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}}\n            \\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}.\n        \"\"\"\n        p = self.domain\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X, W = self.input, self.demix_filter\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            ZTV2p = ZTV ** (2 / p)\n            varphi = 1 / ZTV2p\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n            TV2p = TV ** (2 / p)\n            varphi = 1 / TV2p\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        varphi = varphi.transpose(1, 0, 2)\n        varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n        U = np.mean(varphi_XX, axis=-1)\n\n        self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn)\n\n    def update_spatial_model_ip2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using pairwise iterative projection \\\n        following [#nakashima2021faster]_.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        For :math:`n_{1}` and :math:`n_{2}` 
(:math:`n_{1}\\neq n_{2}`),\n        compute weighted covariance matrix as follows:\n\n        .. math::\n            \\boldsymbol{U}_{in}\n            = \\frac{1}{J}\\sum_{j}\n            \\frac{1}{r_{ijn}}\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}},\n\n        :math:`r_{ijn}` is computed by\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}\n\n        if ``partitioning=True``.\n        Otherwise,\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n\n        Using :math:`\\boldsymbol{U}_{in_{1}}` and\n        :math:`\\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors.\n\n        .. math::\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i}\n            = \\lambda_{i}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i},\n\n        where\n\n        .. math::\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{1}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ), \\\\\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{2}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ).\n\n        After that, we standardize two eigenvectors :math:`\\boldsymbol{h}_{in_{1}}`\n        and :math:`\\boldsymbol{h}_{in_{2}}`.\n\n        .. 
math::\n            \\boldsymbol{h}_{in_{1}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{1}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{1}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{1}}}}, \\\\\n            \\boldsymbol{h}_{in_{2}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{2}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{2}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{2}}}}.\n\n        Then, update :math:`\\boldsymbol{w}_{in_{1}}` and :math:`\\boldsymbol{w}_{in_{2}}`\n        simultaneously.\n\n        .. math::\n            \\boldsymbol{w}_{in_{1}}\n            &\\leftarrow \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{1}} \\\\\n            \\boldsymbol{w}_{in_{2}}\n            &\\leftarrow \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{2}}\n\n        At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{2}`\n        for :math:`n_{1}\\neq n_{2}`.\n\n        .. [#nakashima2021faster] T. Nakashima, R. Scheibler, Y. Wakabayashi, and N. Ono, \\\n            \"Faster independent low-rank matrix analysis with pairwise updates of demixing vectors,\"\n            in *Proc. EUSIPCO*, 2021, pp. 
301-305.\n        \"\"\"\n        p = self.domain\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X, W = self.input, self.demix_filter\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            R = ZTV ** (2 / p)\n        else:\n            T, V = self.basis, self.activation\n            TV = self.reconstruct_nmf(T, V)\n            R = TV ** (2 / p)\n\n        varphi = 1 / R\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        varphi = varphi.transpose(1, 0, 2)\n        varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n        U = np.mean(varphi_XX, axis=-1)\n\n        self.demix_filter = update_by_ip2(\n            W, U, flooring_fn=flooring_fn, pair_selector=self.pair_selector\n        )\n\n    def update_spatial_model_iss1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using iterative source steering.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`y_{ijn}` as follows:\n\n        .. 
math::\n            \\boldsymbol{y}_{ij}\n            & \\leftarrow\\boldsymbol{y}_{ij} - \\boldsymbol{d}_{in}y_{ijn} \\\\\n            d_{inn'}\n            &= \\begin{cases}\n                \\dfrac{\\displaystyle\\sum_{j}\\dfrac{1}{r_{ijn}}\n                y_{ijn'}y_{ijn}^{*}}{\\displaystyle\\sum_{j}\\dfrac{1}\n                {r_{ijn}}|y_{ijn}|^{2}}\n                & (n'\\neq n) \\\\\n                1 - \\dfrac{1}{\\sqrt{\\displaystyle\\dfrac{1}{J}\\sum_{j}\\dfrac{1}\n                {r_{ijn}}\n                |y_{ijn}|^{2}}} & (n'=n)\n            \\end{cases},\n\n        where\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}},\n\n        if ``partitioning=True``. Otherwise\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n        \"\"\"\n        p = self.domain\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            R = ZTV ** (2 / p)\n        else:\n            T, V = self.basis, self.activation\n            TV = self.reconstruct_nmf(T, V)\n            R = TV ** (2 / p)\n\n        varphi = 1 / R\n\n        self.output = update_by_iss1(Y, varphi, flooring_fn=flooring_fn)\n\n    def update_spatial_model_iss2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using pairwise iterative source steering.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                
the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Compute :math:`\\boldsymbol{G}_{in}^{(n_{1},n_{2})}`\n        and :math:`\\boldsymbol{f}_{in}^{(n_{1},n_{2})}` for :math:`n_{1}\\neq n_{2}`:\n\n        .. math::\n            \\begin{array}{rclc}\n                \\boldsymbol{G}_{in}^{(n_{1},n_{2})}\n                &=& {\\displaystyle\\frac{1}{J}\\sum_{j}}\\dfrac{1}{r_{ijn}}\n                \\boldsymbol{y}_{ij}^{(n_{1},n_{2})}{\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}}^{\\mathsf{H}}\n                &(n=1,\\ldots,N), \\\\\n                \\boldsymbol{f}_{in}^{(n_{1},n_{2})}\n                &=& {\\displaystyle\\frac{1}{J}\\sum_{j}}\n                \\dfrac{1}{r_{ijn}}y_{ijn}^{*}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}\n                &(n\\neq n_{1},n_{2}),\n            \\end{array}\n\n        where\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}\n\n        if ``partitioning=True``.\n        Otherwise,\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n\n        Using :math:`\\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and\n        :math:`\\boldsymbol{f}_{in}^{(n_{1},n_{2})}`, we compute\n\n        .. math::\n            \\begin{array}{rclc}\n                \\boldsymbol{p}_{in}\n                &=& \\dfrac{\\boldsymbol{h}_{in}}\n                {\\sqrt{\\boldsymbol{h}_{in}^{\\mathsf{H}}\\boldsymbol{G}_{in}^{(n_{1},n_{2})}\n                \\boldsymbol{h}_{in}}} & (n=n_{1},n_{2}), \\\\\n                \\boldsymbol{q}_{in}\n                &=& -{\\boldsymbol{G}_{in}^{(n_{1},n_{2})}}^{-1}\\boldsymbol{f}_{in}^{(n_{1},n_{2})}\n                & (n\\neq n_{1},n_{2}),\n            \\end{array}\n\n        where :math:`\\boldsymbol{h}_{in}` (:math:`n=n_{1},n_{2}`) is\n        a generalized eigenvector obtained from\n\n        .. 
math::\n            \\boldsymbol{G}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{i}\n            = \\lambda_{i}\\boldsymbol{G}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{i}.\n\n        Separated signal :math:`y_{ijn}` is updated as follows:\n\n        .. math::\n            y_{ijn}\n            &\\leftarrow\\begin{cases}\n            &\\boldsymbol{p}_{in}^{\\mathsf{H}}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}\n            & (n=n_{1},n_{2}) \\\\\n            &\\boldsymbol{q}_{in}^{\\mathsf{H}}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})} + y_{ijn}\n            & (n\\neq n_{1},n_{2})\n            \\end{cases}.\n        \"\"\"\n        p = self.domain\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            R = ZTV ** (2 / p)\n        else:\n            T, V = self.basis, self.activation\n            TV = self.reconstruct_nmf(T, V)\n            R = TV ** (2 / p)\n\n        varphi = 1 / R\n\n        self.output = update_by_iss2(\n            Y, varphi, flooring_fn=flooring_fn, pair_selector=self.pair_selector\n        )\n\n    def update_spatial_model_ipa(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using \\\n        iterative projection with adjustment [#scheibler2021independent]_.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        
Compute :math:`r_{ijn}` as follows:\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}},\n\n        if ``partitioning=True``. Otherwise\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n\n        Then, define :math:`\\tilde{\\boldsymbol{U}}_{in'}`,\n        :math:`\\boldsymbol{A}_{in}\\in\\mathbb{R}^{(N-1)\\times(N-1)}`,\n        :math:`\\boldsymbol{b}_{in}\\in\\mathbb{C}^{N-1}`,\n        :math:`\\boldsymbol{C}_{in}\\in\\mathbb{C}^{(N-1)\\times(N-1)}`,\n        :math:`\\boldsymbol{d}_{in}\\in\\mathbb{C}^{N-1}`,\n        and :math:`z_{in}\\in\\mathbb{R}_{\\geq 0}` as follows:\n\n        .. math::\n\n            \\tilde{\\boldsymbol{U}}_{in'}\n            &= \\frac{1}{J}\\sum_{j}\\frac{1}{r_{ijn'}}\n            \\boldsymbol{y}_{ij}\\boldsymbol{y}_{ij}^{\\mathsf{H}}, \\\\\n            \\boldsymbol{A}_{in}\n            &= \\mathrm{diag}(\\ldots,\n            \\boldsymbol{e}_{n}^{\\mathsf{T}}\\tilde{\\boldsymbol{U}}_{in'}\\boldsymbol{e}_{n}\n            ,\\ldots)~~(n'\\neq n), \\\\\n            \\boldsymbol{b}_{in}\n            &= (\\ldots,\n            \\boldsymbol{e}_{n}^{\\mathsf{T}}\\tilde{\\boldsymbol{U}}_{in'}\\boldsymbol{e}_{n'}\n            ,\\ldots)^{\\mathsf{T}}~~(n'\\neq n), \\\\\n            \\boldsymbol{C}_{in}\n            &= \\bar{\\boldsymbol{E}}_{n}^{\\mathsf{T}}(\\tilde{\\boldsymbol{U}}_{in}^{-1})^{*}\n            \\bar{\\boldsymbol{E}}_{n}, \\\\\n            \\boldsymbol{d}_{in}\n            &= \\bar{\\boldsymbol{E}}_{n}^{\\mathsf{T}}(\\tilde{\\boldsymbol{U}}_{in}^{-1})^{*}\n            \\boldsymbol{e}_{n}, \\\\\n            z_{in}\n            &= \\boldsymbol{e}_{n}^{\\mathsf{T}}\\tilde{\\boldsymbol{U}}_{in}^{-1}\\boldsymbol{e}_{n}\n            - \\boldsymbol{d}_{in}^{\\mathsf{H}}\\boldsymbol{C}_{in}^{-1}\\boldsymbol{d}_{in}.\n\n        :math:`\\boldsymbol{y}_{ij}` is updated via log-quadratically penalized\n        
quadratic minimization (LQPQM).\n\n        .. math::\n            \\check{\\boldsymbol{q}}_{in}\n            &\\leftarrow \\mathrm{LQPQM2}(\\boldsymbol{H}_{in},\\boldsymbol{v}_{in},z_{in}), \\\\\n            \\boldsymbol{q}_{in}\n            &\\leftarrow \\boldsymbol{G}_{in}^{-1}\\check{\\boldsymbol{q}}_{in}\n            - \\boldsymbol{A}_{in}^{-1}\\boldsymbol{b}_{in}, \\\\\n            \\tilde{\\boldsymbol{q}}_{in}\n            &\\leftarrow \\boldsymbol{e}_{n} - \\bar{\\boldsymbol{E}}_{n}\\boldsymbol{q}_{in}, \\\\\n            \\boldsymbol{p}_{in}\n            &\\leftarrow \\frac{\\tilde{\\boldsymbol{U}}_{in}^{-1}\\tilde{\\boldsymbol{q}}_{in}^{*}}\n            {\\sqrt{(\\tilde{\\boldsymbol{q}}_{in}^{*})^{\\mathsf{H}}\\tilde{\\boldsymbol{U}}_{in}^{-1}\n            \\tilde{\\boldsymbol{q}}_{in}^{*}}}, \\\\\n            \\boldsymbol{\\Upsilon}_{i}^{(n)}\n            &\\leftarrow \\boldsymbol{I}\n            + \\boldsymbol{e}_{n}(\\boldsymbol{p}_{in} - \\boldsymbol{e}_{n})^{\\mathsf{H}}\n            + \\bar{\\boldsymbol{E}}_{n}\\boldsymbol{q}_{in}^{*}\\boldsymbol{e}_{n}^{\\mathsf{T}}, \\\\\n            \\boldsymbol{y}_{ij}\n            &\\leftarrow \\boldsymbol{\\Upsilon}_{i}^{(n)}\\boldsymbol{y}_{ij},\n\n        .. [#scheibler2021independent]\n            R. Scheibler,\n            \"Independent vector analysis via log-quadratically penalized quadratic minimization,\"\n            *IEEE Trans. Signal Processing*, vol. 69, pp. 
2509-2524, 2021.\n        \"\"\"\n        self.lqpqm_normalization: bool\n        self.newton_iter: int\n\n        p = self.domain\n        normalization = self.lqpqm_normalization\n        max_iter = self.newton_iter\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            R = ZTV ** (2 / p)\n        else:\n            T, V = self.basis, self.activation\n            TV = self.reconstruct_nmf(T, V)\n            R = TV ** (2 / p)\n\n        varphi = 1 / R\n\n        self.output = update_by_ipa(\n            Y,\n            varphi,\n            normalization=normalization,\n            flooring_fn=flooring_fn,\n            max_iter=max_iter,\n        )\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. math::\n            \\mathcal{L}\n            = \\frac{1}{J}\\sum_{i,j,n}\\left(\\frac{|y_{ijn}|^{2}}{r_{ijn}}\n            + \\log r_{ijn}\\right)\n            - 2\\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|,\n\n        where\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            r_{ijn}\n            = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        p = self.domain\n\n        if self.demix_filter is None:\n            X, Y = self.input, self.output\n            Y2 = np.abs(Y) ** 2\n            X, Y = X.transpose(1, 0, 2), Y.transpose(1, 0, 2)\n            X_Hermite = X.transpose(0, 2, 1).conj()\n            XX_Hermite = X @ X_Hermite\n            W = Y @ X_Hermite @ np.linalg.inv(XX_Hermite)\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n            Y2 = np.abs(Y) ** 2\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            R = ZTV ** (2 / p)\n            loss = Y2 / R + (2 / p) * np.log(ZTV)\n        else:\n            T, V = self.basis, self.activation\n            TV = self.reconstruct_nmf(T, V)\n            R = TV ** (2 / p)\n            loss = Y2 / R + (2 / p) * np.log(TV)\n\n        logdet = self.compute_logdet(W)  # (n_bins,)\n\n        loss = np.sum(loss.mean(axis=-1), axis=0) - 2 * logdet\n        loss = loss.sum(axis=0).item()\n\n        return loss\n\n    def apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated spectrograms.\"\"\"\n        if self.demix_filter is None:\n            assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n            X, Y = self.input, self.output\n            Y_scaled = projection_back(Y, reference=X, reference_id=self.reference_id)\n\n            self.output = Y_scaled\n        else:\n            super().apply_projection_back()\n\n    def apply_minimal_distortion_principle(self) -> None:\n        r\"\"\"Apply minimal distortion principle to estimated spectrograms.\"\"\"\n        if self.demix_filter is None:\n            X, Y = self.input, 
self.output\n            Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id)\n\n            self.output = Y_scaled\n        else:\n            super().apply_minimal_distortion_principle()\n\n\nclass TILRMA(ILRMABase):\n    r\"\"\"Independent low-rank matrix analysis (ILRMA) on Student's *t* distribution.\n\n    We assume :math:`y_{ijn}` follows a Student's *t* distribution.\n\n    .. math::\n        p(y_{ijn})\n        = \\frac{1}{\\pi r_{ijn}}\n        \\left(1+\\frac{2}{\\nu}\\frac{|y_{ijn}|^{2}}{r_{ijn}}\\right)^{-\\frac{2+\\nu}{2}},\n\n    where\n\n    .. math::\n        r_{ijn}\n        = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}},\n\n    if ``partitioning=True``. Otherwise,\n\n    .. math::\n        r_{ijn}\n        = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n\n    :math:`\\nu` is a degree of freedom parameter.\n\n    Args:\n        n_basis (int):\n            Number of NMF bases.\n        dof (float):\n            Degree of freedom parameter in Student's *t* distribution.\n        spatial_algorithm (str):\n            Algorithm for demixing filter updates.\n            Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, or ``ISS2``.\n            Default: ``IP``.\n        source_algorithm (str):\n            Algorithm for source model updates.\n            Choose ``MM`` or ``ME``. Default: ``MM``.\n        domain (float):\n            Domain parameter. Default: ``2``.\n        partitioning (bool):\n            Whether to use partitioning function. 
Default: ``False``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        pair_selector (callable, optional):\n            Selector to choose updating pairs in ``IP2`` and ``ISS2``.\n            If ``None`` is given, ``sequential_pair_selector`` is used.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        normalization (bool or str, optional):\n            Normalization of demixing filters and NMF parameters.\n            Choose ``power`` or ``projection_back``.\n            Default: ``power``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n        rng (numpy.random.Generator, optional):\n            Random number generator. This is mainly used to randomly initialize NMF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n\n    Examples:\n        Update demixing filters by IP:\n\n        .. 
code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = TILRMA(\n            ...     n_basis=2,\n            ...     dof=1000,\n            ...     spatial_algorithm=\"IP\",\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IP2:\n\n        .. code-block:: python\n\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = TILRMA(\n            ...     n_basis=2,\n            ...     dof=1000,\n            ...     spatial_algorithm=\"IP2\",\n            ...     pair_selector=sequential_pair_selector,\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = TILRMA(\n            ...     n_basis=2,\n            ...     dof=1000,\n            ...     spatial_algorithm=\"ISS\",\n            ...     rng=np.random.default_rng(42),\n            ... 
)\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS2:\n\n        .. code-block:: python\n\n            >>> import functools\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = TILRMA(\n            ...     n_basis=2,\n            ...     dof=1000,\n            ...     spatial_algorithm=\"ISS2\",\n            ...     pair_selector=functools.partial(sequential_pair_selector, step=2),\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        dof: float,\n        spatial_algorithm: str = \"IP\",\n        source_algorithm: str = \"MM\",\n        domain: float = 2,\n        partitioning: bool = False,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n        callbacks: Optional[\n            Union[Callable[[\"TILRMA\"], None], List[Callable[[\"TILRMA\"], None]]]\n        ] = None,\n        normalization: Optional[Union[bool, str]] = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(\n            n_basis=n_basis,\n            
partitioning=partitioning,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n            rng=rng,\n        )\n\n        assert spatial_algorithm in spatial_algorithms, \"Not support {}.\".format(spatial_algorithms)\n        assert source_algorithm in source_algorithms, \"Not support {}.\".format(source_algorithm)\n        assert 0 < domain <= 2, \"domain parameter should be chosen from [0, 2].\"\n\n        if spatial_algorithm == \"IPA\":\n            raise ValueError(\"IPA is not supported for t-ILRMA.\")\n\n        if source_algorithm == \"ME\":\n            assert domain == 2, \"domain parameter should be 2 when you specify ME algorithm.\"\n\n        self.dof = dof\n        self.spatial_algorithm = spatial_algorithm\n        self.source_algorithm = source_algorithm\n        self.domain = domain\n        self.normalization = normalization\n\n        if pair_selector is None:\n            if spatial_algorithm in [\"IP2\", \"ISS2\"]:\n                self.pair_selector = sequential_pair_selector\n        else:\n            self.pair_selector = pair_selector\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape 
is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(flooring_fn=self.flooring_fn, **kwargs)\n\n        # Call __call__ of ILRMABase's parent, i.e. __call__ of IterativeMethodBase\n        super(ILRMABase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        if self.demix_filter is None:\n            pass\n        else:\n            self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"TILRMA(\"\n        s += \"n_basis={n_basis}\"\n        s += \", dof={dof}\"\n        s += \", spatial_algorithm={spatial_algorithm}\"\n        s += \", source_algorithm={source_algorithm}\"\n        s += \", domain={domain}\"\n        s += \", partitioning={partitioning}\"\n        s += \", normalization={normalization}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        **kwargs,\n    ) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            kwargs:\n                Keyword arguments to set as attributes of ILRMA.\n        \"\"\"\n 
       flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        super()._reset(flooring_fn=flooring_fn, **kwargs)\n\n        if self.spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n            self.demix_filter = None\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF parameters and demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_source_model(flooring_fn=flooring_fn)\n        self.update_spatial_model(flooring_fn=flooring_fn)\n\n        if self.normalization:\n            self.normalize(flooring_fn=flooring_fn)\n\n    def update_source_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases, activations, and latent variables.\n\n        - If ``source_algorithm`` is ``MM``, ``update_source_model_mm`` is called.\n        - If ``source_algorithm`` is ``ME``, ``update_source_model_me`` is called.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, 
``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.source_algorithm == \"MM\":\n            self.update_source_model_mm(flooring_fn=flooring_fn)\n        elif self.source_algorithm == \"ME\":\n            self.update_source_model_me(flooring_fn=flooring_fn)\n        else:\n            raise ValueError(\n                \"{}-algorithm-based source model updates are not supported.\".format(\n                    self.source_algorithm\n                )\n            )\n\n    def update_source_model_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases, activations, and latent variables by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.partitioning:\n            self.update_latent_mm()\n\n        self.update_basis_mm(flooring_fn=flooring_fn)\n        self.update_activation_mm(flooring_fn=flooring_fn)\n\n    def update_source_model_me(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases, activations, and latent variables by ME algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the 
same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.partitioning:\n            self.update_latent_me()\n\n        self.update_basis_me(flooring_fn=flooring_fn)\n        self.update_activation_me(flooring_fn=flooring_fn)\n\n    def update_latent_mm(self) -> None:\n        r\"\"\"Update latent variables in NMF by MM algorithm.\n\n        Update :math:`z_{nk}` as follows:\n\n        .. math::\n            z_{nk}\n            &\\leftarrow\\left[\\frac{\\displaystyle\\sum_{i,j}\\frac{t_{ik}v_{kj}}\n            {\\tilde{r}_{ijn}\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{i,j}\\dfrac{t_{ik}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]^{\\frac{p}{p+2}}z_{nk} \\\\\n            z_{nk}\n            &\\leftarrow\\frac{z_{nk}}{\\sum_{n'}z_{n'k}}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n        \"\"\"\n        p = self.domain\n        nu = self.dof\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n        p_p2 = p / (p + 2)\n        nu_nu2 = nu / (nu + 2)\n\n        Z = self.latent\n        T, V = self.basis, self.activation\n\n        TV = T[:, :, np.newaxis] * V[np.newaxis, :, :]\n        ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n        ZTV2p = ZTV ** (2 / p)\n        R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2\n        RZTV = R_tilde * ZTV\n        TV_RZTV = TV[np.newaxis, :, :, :] / RZTV[:, :, 
np.newaxis, :]\n        num = np.sum(TV_RZTV * Y2[:, :, np.newaxis, :], axis=(1, 3))\n\n        TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :]\n        denom = np.sum(TV_ZTV, axis=(1, 3))\n\n        Z = ((num / denom) ** p_p2) * Z\n        Z = Z / Z.sum(axis=0)\n\n        self.latent = Z\n\n    def update_basis_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`t_{ikn}` as follows:\n\n        .. math::\n            t_{ik}\n            &\\leftarrow\\left[\n            \\frac{\\displaystyle\\sum_{j,n}\\frac{z_{nk}v_{kj}}\n            {\\tilde{r}_{ijn}\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{j,n}\n            \\dfrac{z_{nk}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]^{\\frac{p}{p+2}}t_{ik}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            t_{ikn}\n            &\\leftarrow \\left[\\frac{\\displaystyle\\sum_{j}\n            \\dfrac{v_{kjn}}{\\tilde{r}_{ijn}\\sum_{k'}t_{ik'n}v_{k'jn}}|y_{ijn}|^{2}}\n            {\\displaystyle\\sum_{j}\\frac{v_{kjn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\\right]\n            ^{\\frac{p}{p+2}}t_{ikn}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n        \"\"\"\n        p = self.domain\n        nu = self.dof\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n        p_p2 = p / (p + 2)\n        nu_nu2 = nu / (nu + 2)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            ZTV2p = ZTV ** (2 / p)\n            R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2\n            RZTV = R_tilde * ZTV\n            ZV_RZTV = ZV[:, np.newaxis, :, :] / RZTV[:, :, np.newaxis, :]\n            num = np.sum(ZV_RZTV * Y2[:, :, np.newaxis, :], axis=(0, 3))\n\n            ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZV_ZTV, axis=(0, 3))\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            TV2p = TV ** (2 / p)\n            R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2\n            RTV = R_tilde * TV\n            V_RTV = V[:, np.newaxis, :, :] / RTV[:, :, np.newaxis, :]\n            num = np.sum(V_RTV * Y2[:, :, np.newaxis, :], axis=3)\n\n            V_TV = V[:, np.newaxis, :, :] / TV[:, :, np.newaxis, :]\n            denom = np.sum(V_TV, 
axis=3)\n\n        T = ((num / denom) ** p_p2) * T\n        T = flooring_fn(T)\n\n        self.basis = T\n\n    def update_activation_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF activations by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`v_{kjn}` as follows:\n\n        .. math::\n            v_{kj}\n            &\\leftarrow\\left[\\frac{\\displaystyle\\sum_{i,n}\\frac{z_{nk}t_{ik}}\n            {\\tilde{r}_{ijn}\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{i,n}\\dfrac{z_{nk}t_{ik}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]^{\\frac{p}{p+2}}v_{kj}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            v_{kjn}\n            &\\leftarrow \\left[\\frac{\\displaystyle\\sum_{i}\n            \\dfrac{t_{ikn}}{\\tilde{r}_{ijn}\\sum_{k'}t_{ik'n}v_{k'jn}}|y_{ijn}|^{2}}\n            {\\displaystyle\\sum_{i}\\frac{t_{ikn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\n            \\right]^{\\frac{p}{p+2}}v_{kjn}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n        \"\"\"\n        p = self.domain\n        nu = self.dof\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n        p_p2 = p / (p + 2)\n        nu_nu2 = nu / (nu + 2)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            ZTV2p = ZTV ** (2 / p)\n            R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2\n            RZTV = R_tilde * ZTV\n            ZT_RZTV = ZT[:, :, :, np.newaxis] / RZTV[:, :, np.newaxis, :]\n            num = np.sum(ZT_RZTV * Y2[:, :, np.newaxis, :], axis=(0, 1))\n\n            ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZT_ZTV, axis=(0, 1))\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            TV2p = TV ** (2 / p)\n            R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2\n            RTV = R_tilde * TV\n            T_RTV = T[:, :, :, np.newaxis] / RTV[:, :, np.newaxis, :]\n            num = np.sum(T_RTV * Y2[:, :, np.newaxis, :], axis=1)\n\n            T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :]\n            denom = np.sum(T_TV, 
axis=1)\n\n        V = ((num / denom) ** p_p2) * V\n        V = flooring_fn(V)\n\n        self.activation = V\n\n    def update_latent_me(self) -> None:\n        r\"\"\"Update latent variables in NMF by ME algorithm.\n\n        Update :math:`z_{nk}` as follows:\n\n        .. math::\n            z_{nk}\n            &\\leftarrow\\frac{\\displaystyle\\sum_{i,j}\\frac{t_{ik}v_{kj}}\n            {\\tilde{r}_{ijn}\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{i,j}\\dfrac{t_{ik}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            z_{nk} \\\\\n            z_{nk}\n            &\\leftarrow\\frac{z_{nk}}{\\sum_{n'}z_{n'k}}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\sum_{k}z_{nk}t_{ik}v_{kj}+\\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n        \"\"\"\n        nu = self.dof\n\n        if self.domain != 2:\n            raise ValueError(\"Domain parameter is expected 2, but given {}.\".format(self.domain))\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n        nu_nu2 = nu / (nu + 2)\n\n        Z = self.latent\n        T, V = self.basis, self.activation\n\n        TV = T[:, :, np.newaxis] * V[np.newaxis, :, :]\n        ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n        R_tilde = nu_nu2 * ZTV + (1 - nu_nu2) * Y2\n        RZTV = R_tilde * ZTV\n        TV_RZTV = TV[np.newaxis, :, :, :] / RZTV[:, :, np.newaxis, :]\n        num = np.sum(TV_RZTV * Y2[:, :, np.newaxis, :], axis=(1, 3))\n\n        TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :]\n        denom = np.sum(TV_ZTV, axis=(1, 3))\n\n        Z = (num / denom) * Z\n        Z = Z / Z.sum(axis=0)\n\n        self.latent = Z\n\n    def update_basis_me(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF 
bases by ME algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`t_{ikn}` as follows:\n\n        .. math::\n            t_{ik}\n            &\\leftarrow\n            \\frac{\\displaystyle\\sum_{j,n}\\frac{z_{nk}v_{kj}}\n            {\\tilde{r}_{ijn}\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{j,n}\n            \\dfrac{z_{nk}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            t_{ik}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\sum_{k}z_{nk}t_{ik}v_{kj}+\\frac{2}{\\nu+2}|y_{ijn}|^{2},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            t_{ikn}\n            &\\leftarrow\\frac{\\displaystyle\\sum_{j}\n            \\dfrac{v_{kjn}}{\\tilde{r}_{ijn}\\sum_{k'}t_{ik'n}v_{k'jn}}|y_{ijn}|^{2}}\n            {\\displaystyle\\sum_{j}\\frac{v_{kjn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\n            t_{ikn}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\sum_{k}t_{ikn}v_{kjn}+\\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n        \"\"\"\n        nu = self.dof\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.domain != 2:\n            raise ValueError(\"Domain parameter is expected 2, but given {}.\".format(self.domain))\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n        nu_nu2 = nu / (nu + 2)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            R_tilde = nu_nu2 * ZTV + (1 - nu_nu2) * Y2\n            RZTV = R_tilde * ZTV\n            ZV_RZTV = ZV[:, np.newaxis, :, :] / RZTV[:, :, np.newaxis, :]\n            num = np.sum(ZV_RZTV * Y2[:, :, np.newaxis, :], axis=(0, 3))\n\n            ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZV_ZTV, axis=(0, 3))\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            R_tilde = nu_nu2 * TV + (1 - nu_nu2) * Y2\n            RTV = R_tilde * TV\n            V_RTV = V[:, np.newaxis, :, :] / RTV[:, :, np.newaxis, :]\n            num = np.sum(V_RTV * Y2[:, :, np.newaxis, :], axis=3)\n\n            V_TV = V[:, np.newaxis, :, :] / TV[:, :, np.newaxis, :]\n            denom = np.sum(V_TV, axis=3)\n\n        T = (num / denom) * T\n        T = flooring_fn(T)\n\n  
      self.basis = T\n\n    def update_activation_me(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF activations by ME algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`v_{kjn}` as follows:\n\n        .. math::\n            v_{kj}\n            &\\leftarrow\\frac{\\displaystyle\\sum_{i,n}\\frac{z_{nk}t_{ik}}\n            {\\tilde{r}_{ijn}\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}\n            |y_{ijn}|^{2}}{\\displaystyle\\sum_{i,n}\\dfrac{z_{nk}t_{ik}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            v_{kj}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\sum_{k}z_{nk}t_{ik}v_{kj}+\\frac{2}{\\nu+2}|y_{ijn}|^{2},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            v_{kjn}\n            &\\leftarrow\\frac{\\displaystyle\\sum_{i}\n            \\dfrac{t_{ikn}}{\\tilde{r}_{ijn}\\sum_{k'}t_{ik'n}v_{k'jn}}|y_{ijn}|^{2}}\n            {\\displaystyle\\sum_{i}\\frac{t_{ikn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\n            v_{kjn}, \\\\\n            \\tilde{r}_{ijn}\n            &= \\frac{\\nu}{\\nu+2}\\sum_{k}t_{ikn}v_{kjn}+\\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n        \"\"\"\n        nu = self.dof\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.domain != 2:\n            raise ValueError(\"Domain parameter is expected 2, but given {}.\".format(self.domain))\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Y2 = np.abs(Y) ** 2\n        nu_nu2 = nu / (nu + 2)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            R_tilde = nu_nu2 * ZTV + (1 - nu_nu2) * Y2\n            RZTV = R_tilde * ZTV\n            ZT_RZTV = ZT[:, :, :, np.newaxis] / RZTV[:, :, np.newaxis, :]\n            num = np.sum(ZT_RZTV * Y2[:, :, np.newaxis, :], axis=(0, 1))\n\n            ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZT_ZTV, axis=(0, 1))\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            R_tilde = nu_nu2 * TV + (1 - nu_nu2) * Y2\n            RTV = R_tilde * TV\n            T_RTV = T[:, :, :, np.newaxis] / RTV[:, :, np.newaxis, :]\n            num = np.sum(T_RTV * Y2[:, :, np.newaxis, :], axis=1)\n\n            T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :]\n            denom = np.sum(T_TV, axis=1)\n\n        V = (num / denom) * V\n        V = flooring_fn(V)\n\n  
      self.activation = V\n\n    def update_spatial_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once.\n\n        - If ``spatial_algorithm`` is ``IP`` or ``IP1``, ``update_spatial_model_ip1`` is called.\n        - If ``spatial_algorithm`` is ``ISS`` or ``ISS1``, ``update_spatial_model_iss1`` is called.\n        - If ``spatial_algorithm`` is ``IP2``, ``update_spatial_model_ip2`` is called.\n        - If ``spatial_algorithm`` is ``ISS2``, ``update_spatial_model_iss2`` is called.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.spatial_algorithm in [\"IP\", \"IP1\"]:\n            self.update_spatial_model_ip1(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"IP2\"]:\n            self.update_spatial_model_ip2(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"ISS\", \"ISS1\"]:\n            self.update_spatial_model_iss1(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"ISS2\"]:\n            self.update_spatial_model_iss2(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.spatial_algorithm))\n\n    def update_spatial_model_ip1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using iterative projection.\n\n        
Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Demixing filters are updated sequentially for :math:`n=1,\\ldots,N` as follows:\n\n        .. math::\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\left(\\boldsymbol{W}_{i}\\boldsymbol{U}_{in}\\right)^{-1} \\\n            \\boldsymbol{e}_{n}, \\\\\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\frac{\\boldsymbol{w}_{in}}\n            {\\sqrt{\\boldsymbol{w}_{in}^{\\mathsf{H}}\\boldsymbol{U}_{in}\\boldsymbol{w}_{in}}},\n\n        where\n\n        .. math::\n            \\boldsymbol{U}_{in}\n            = \\frac{1}{J}\\sum_{j}\n            \\frac{1}{\\tilde{r}_{ijn}}\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}.\n\n        :math:`\\tilde{r}_{ijn}` is defined as\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            \\tilde{r}_{ijn}\n            = \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n        \"\"\"\n        p = self.domain\n        nu = self.dof\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y2 = np.abs(Y) ** 2\n        nu_nu2 = nu / (nu + 2)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            ZTV2p = ZTV ** (2 / p)\n            R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n            TV2p = TV ** (2 / p)\n            R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2\n\n        varphi = 1 / R_tilde\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        varphi = varphi.transpose(1, 0, 2)\n        varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n        U = np.mean(varphi_XX, axis=-1)\n\n        self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn)\n\n    def update_spatial_model_ip2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using pairwise iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, 
``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\\neq n_{2}`),\n        compute weighted covariance matrix as follows:\n\n        .. math::\n            \\boldsymbol{U}_{in}\n            = \\frac{1}{J}\\sum_{j}\n            \\frac{1}{\\tilde{r}_{ijn}}\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}},\n\n        :math:`\\tilde{r}_{ijn}` is computed by\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2},\n\n        if ``partitioning=True``. \\\n        Otherwise,\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n\n        Using :math:`\\boldsymbol{U}_{in_{1}}` and\n        :math:`\\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors.\n\n        .. math::\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i}\n            = \\lambda_{i}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i},\n\n        where\n\n        .. 
math::\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{1}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ), \\\\\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{2}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ).\n\n        After that, we standardize two eigenvectors :math:`\\boldsymbol{h}_{in_{1}}`\n        and :math:`\\boldsymbol{h}_{in_{2}}`.\n\n        .. math::\n            \\boldsymbol{h}_{in_{1}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{1}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{1}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{1}}}}, \\\\\n            \\boldsymbol{h}_{in_{2}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{2}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{2}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{2}}}}.\n\n        Then, update :math:`\\boldsymbol{w}_{in_{1}}` and :math:`\\boldsymbol{w}_{in_{2}}`\n        simultaneously.\n\n        .. 
math::\n            \\boldsymbol{w}_{in_{1}}\n            &\\leftarrow \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{1}} \\\\\n            \\boldsymbol{w}_{in_{2}}\n            &\\leftarrow \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{2}}\n\n        At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{2}`\n        for :math:`n_{1}\\neq n_{2}`.\n        \"\"\"\n        nu = self.dof\n        p = self.domain\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X, W = self.input, self.demix_filter\n\n        nu_nu2 = nu / (nu + 2)\n        Y = self.separate(X, demix_filter=W)\n        Y2 = np.abs(Y) ** 2\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            ZTV2p = ZTV ** (2 / p)\n            R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n            TV2p = TV ** (2 / p)\n            R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2\n\n        varphi = 1 / R_tilde\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        varphi = varphi.transpose(1, 0, 2)\n        varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n        U = np.mean(varphi_XX, axis=-1)\n\n        self.demix_filter = update_by_ip2(\n            W, U, flooring_fn=flooring_fn, pair_selector=self.pair_selector\n        )\n\n    def update_spatial_model_iss1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using iterative source steering.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical 
stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`y_{ijn}` as follows:\n\n        .. math::\n            \\boldsymbol{y}_{ij}\n            & \\leftarrow\\boldsymbol{y}_{ij} - \\boldsymbol{d}_{in}y_{ijn} \\\\\n            d_{inn'}\n            &= \\begin{cases}\n                \\dfrac{\\displaystyle\\sum_{j}\\dfrac{1}{\\tilde{r}_{ijn}}\n                y_{ijn'}y_{ijn}^{*}}{\\displaystyle\\sum_{j}\\dfrac{1}\n                {\\tilde{r}_{ijn}}|y_{ijn}|^{2}}\n                & (n'\\neq n) \\\\\n                1 - \\dfrac{1}{\\sqrt{\\displaystyle\\dfrac{1}{J}\\sum_{j}\\dfrac{1}\n                {\\tilde{r}_{ijn}}|y_{ijn}|^{2}}}\n                & (n'=n)\n            \\end{cases}.\n\n        :math:`\\tilde{r}_{ijn}` is defined as\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            \\tilde{r}_{ijn}\n            = \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n        \"\"\"\n        p = self.domain\n        nu = self.dof\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n        Y2 = np.abs(Y) ** 2\n        nu_nu2 = nu / (nu + 2)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            ZTV2p = ZTV ** (2 / p)\n            R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n            TV2p = TV ** (2 / p)\n            R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2\n\n        varphi = 1 / R_tilde\n\n        self.output = update_by_iss1(Y, varphi, flooring_fn=flooring_fn)\n\n    def update_spatial_model_iss2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using pairwise iterative source steering.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Compute :math:`\\boldsymbol{G}_{in}^{(n_{1},n_{2})}`\n        and :math:`\\boldsymbol{f}_{in}^{(n_{1},n_{2})}` for :math:`n_{1}\\neq n_{2}`:\n\n        .. 
math::\n            \\begin{array}{rclc}\n                \\boldsymbol{G}_{in}^{(n_{1},n_{2})}\n                &=& {\\displaystyle\\frac{1}{J}\\sum_{j}}\\dfrac{1}{\\tilde{r}_{ijn}}\n                \\boldsymbol{y}_{ij}^{(n_{1},n_{2})}{\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}}^{\\mathsf{H}}\n                &(n=1,\\ldots,N), \\\\\n                \\boldsymbol{f}_{in}^{(n_{1},n_{2})}\n                &=& {\\displaystyle\\frac{1}{J}\\sum_{j}}\n                \\dfrac{1}{\\tilde{r}_{ijn}}y_{ijn}^{*}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}\n                &(n\\neq n_{1},n_{2}),\n            \\end{array}\n\n        where\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2}\n\n        if ``partitioning=True``.\n        Otherwise,\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{\\nu}{\\nu+2}\\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}\n            + \\frac{2}{\\nu+2}|y_{ijn}|^{2}.\n\n        Using :math:`\\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and\n        :math:`\\boldsymbol{f}_{in}^{(n_{1},n_{2})}`, we compute\n\n        .. math::\n            \\begin{array}{rclc}\n                \\boldsymbol{p}_{in}\n                &=& \\dfrac{\\boldsymbol{h}_{in}}\n                {\\sqrt{\\boldsymbol{h}_{in}^{\\mathsf{H}}\\boldsymbol{G}_{in}^{(n_{1},n_{2})}\n                \\boldsymbol{h}_{in}}} & (n=n_{1},n_{2}), \\\\\n                \\boldsymbol{q}_{in}\n                &=& -{\\boldsymbol{G}_{in}^{(n_{1},n_{2})}}^{-1}\\boldsymbol{f}_{in}^{(n_{1},n_{2})}\n                & (n\\neq n_{1},n_{2}),\n            \\end{array}\n\n        where :math:`\\boldsymbol{h}_{in}` (:math:`n=n_{1},n_{2}`) is\n        a generalized eigenvector obtained from\n\n        .. 
math::\n            \\boldsymbol{G}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{i}\n            = \\lambda_{i}\\boldsymbol{G}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{i}.\n\n        Separated signal :math:`y_{ijn}` is updated as follows:\n\n        .. math::\n            y_{ijn}\n            &\\leftarrow\\begin{cases}\n            &\\boldsymbol{p}_{in}^{\\mathsf{H}}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}\n            & (n=n_{1},n_{2}) \\\\\n            &\\boldsymbol{q}_{in}^{\\mathsf{H}}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})} + y_{ijn}\n            & (n\\neq n_{1},n_{2})\n            \\end{cases}.\n        \"\"\"\n        p = self.domain\n        nu = self.dof\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n        Y2 = np.abs(Y) ** 2\n        nu_nu2 = nu / (nu + 2)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            ZTV2p = ZTV ** (2 / p)\n            R_tilde = nu_nu2 * ZTV2p + (1 - nu_nu2) * Y2\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n            TV2p = TV ** (2 / p)\n            R_tilde = nu_nu2 * TV2p + (1 - nu_nu2) * Y2\n\n        varphi = 1 / R_tilde\n\n        self.output = update_by_iss2(\n            Y, varphi, flooring_fn=flooring_fn, pair_selector=self.pair_selector\n        )\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. math::\n            \\mathcal{L}\n            = \\frac{1}{J}\\sum_{i,j,n}\n            \\left\\{\\left(1+\\frac{\\nu}{2}\\right)\\log\\left(1+\\frac{2}{\\nu}\n            \\frac{|y_{ijn}|^{2}}{r_{ijn}}\\right)\n            + \\log r_{ijn}\\right\\}\n            -2\\sum_{i}\\log\\left|\\det\\boldsymbol{W}_{i}\\right|,\n\n        where\n\n        .. 
math::\n            r_{ijn}\n            = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}},\n\n        if ``partitioning=True``, otherwise\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        nu = self.dof\n        p = self.domain\n\n        if self.demix_filter is None:\n            X, Y = self.input, self.output\n            Y2 = np.abs(Y) ** 2\n            X, Y = X.transpose(1, 0, 2), Y.transpose(1, 0, 2)\n            X_Hermite = X.transpose(0, 2, 1).conj()\n            XX_Hermite = X @ X_Hermite\n            W = Y @ X_Hermite @ np.linalg.inv(XX_Hermite)\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n            Y2 = np.abs(Y) ** 2\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            Y2ZTV2p = Y2 / (ZTV ** (2 / p))\n            loss = (1 + nu / 2) * np.log(1 + (2 / nu) * Y2ZTV2p) + (2 / p) * np.log(ZTV)\n        else:\n            T, V = self.basis, self.activation\n            TV = self.reconstruct_nmf(T, V)\n            Y2TV2p = Y2 / (TV ** (2 / p))\n            loss = (1 + nu / 2) * np.log(1 + (2 / nu) * Y2TV2p) + (2 / p) * np.log(TV)\n\n        logdet = self.compute_logdet(W)  # (n_bins,)\n\n        loss = np.sum(loss.mean(axis=-1), axis=0) - 2 * logdet\n        loss = loss.sum(axis=0).item()\n\n        return loss\n\n    def apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated spectrograms.\"\"\"\n        if self.demix_filter is None:\n            assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n            X, Y = self.input, self.output\n            Y_scaled = projection_back(Y, reference=X, reference_id=self.reference_id)\n\n            self.output = 
Y_scaled\n        else:\n            super().apply_projection_back()\n\n    def apply_minimal_distortion_principle(self) -> None:\n        r\"\"\"Apply minimal distortion principle to estimated spectrograms.\"\"\"\n        if self.demix_filter is None:\n            X, Y = self.input, self.output\n            Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id)\n\n            self.output = Y_scaled\n        else:\n            super().apply_minimal_distortion_principle()\n\n\nclass GGDILRMA(ILRMABase):\n    r\"\"\"Independent low-rank matrix analysis (ILRMA) on a generalized Gaussian distribution.\n\n    We assume :math:`y_{ijn}` follows a generalized Gaussian distribution.\n\n    .. math::\n        p(y_{ijn})\n        = \\frac{\\beta}{2\\pi r_{ijn}\\Gamma\\left(\\frac{2}{\\beta}\\right)}\n        \\exp\\left\\{-\\left(\\frac{|y_{ijn}|^{2}}{r_{ijn}}\\right)^{\\frac{\\beta}{2}}\\right\\},\n\n    where\n\n    .. math::\n        r_{ijn}\n        = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}},\n\n    if ``partitioning=True``. Otherwise,\n\n    .. math::\n        r_{ijn}\n        = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n\n    :math:`\\beta` is a shape parameter of a generalized Gaussian distribution.\n\n    Args:\n        n_basis (int):\n            Number of NMF bases.\n        beta (float):\n            Shape parameter of the generalized Gaussian distribution.\n            Must satisfy ``0 < beta < 2``.\n        spatial_algorithm (str):\n            Algorithm for demixing filter updates.\n            Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, or ``ISS2``.\n            Default: ``IP``.\n        source_algorithm (str):\n            Algorithm for source model updates.\n            Only ``MM`` is supported. Default: ``MM``.\n        domain (float):\n            Domain parameter. Must satisfy ``0 < domain <= 2``. Default: ``2``.\n        partitioning (bool):\n            Whether to use partitioning function. 
Default: ``False``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        pair_selector (callable, optional):\n            Selector to choose updating pair in ``IP2`` and ``ISS2``.\n            If ``None`` is given, ``sequential_pair_selector`` is used.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        normalization (bool or str, optional):\n            Normalization of demixing filters and NMF parameters.\n            Choose ``power`` or ``projection_back``.\n            Default: ``power``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n        rng (numpy.random.Generator, optional):\n            Random number generator. This is mainly used to randomly initialize NMF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n\n    Examples:\n        Update demixing filters by IP:\n\n        .. 
code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = GGDILRMA(\n            ...     n_basis=2,\n            ...     beta=1.99,\n            ...     spatial_algorithm=\"IP\",\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IP2:\n\n        .. code-block:: python\n\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = GGDILRMA(\n            ...     n_basis=2,\n            ...     beta=1.99,\n            ...     spatial_algorithm=\"IP2\",\n            ...     pair_selector=sequential_pair_selector,\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = GGDILRMA(\n            ...     n_basis=2,\n            ...     beta=1.99,\n            ...     spatial_algorithm=\"ISS\",\n            ...     
rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS2:\n\n        .. code-block:: python\n\n            >>> import functools\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> ilrma = GGDILRMA(\n            ...     n_basis=2,\n            ...     beta=1.99,\n            ...     spatial_algorithm=\"ISS2\",\n            ...     pair_selector=functools.partial(sequential_pair_selector, step=2),\n            ...     rng=np.random.default_rng(42),\n            ... )\n            >>> spectrogram_est = ilrma(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        beta: float,\n        spatial_algorithm: str = \"IP\",\n        source_algorithm: str = \"MM\",\n        domain: float = 2,\n        partitioning: bool = False,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n        callbacks: Optional[\n            Union[Callable[[\"GGDILRMA\"], None], List[Callable[[\"GGDILRMA\"], None]]]\n        ] = None,\n        normalization: Optional[Union[bool, str]] = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        
super().__init__(\n            n_basis=n_basis,\n            partitioning=partitioning,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n            rng=rng,\n        )\n\n        assert 0 < beta < 2, \"Shape parameter {} shoule be chosen from (0, 2).\".format(beta)\n        assert spatial_algorithm in spatial_algorithms, \"Not support {}.\".format(spatial_algorithms)\n        assert source_algorithm == \"MM\", \"Not support {}.\".format(source_algorithm)\n        assert 0 < domain <= 2, \"domain parameter should be chosen from [0, 2].\"\n\n        if spatial_algorithm == \"IPA\":\n            raise ValueError(\"IPA is not supported for GGD-ILRMA.\")\n\n        self.beta = beta\n        self.spatial_algorithm = spatial_algorithm\n        self.source_algorithm = source_algorithm\n        self.domain = domain\n        self.normalization = normalization\n\n        if pair_selector is None:\n            if spatial_algorithm in [\"IP2\", \"ISS2\"]:\n                self.pair_selector = sequential_pair_selector\n        else:\n            self.pair_selector = pair_selector\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            
The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(flooring_fn=self.flooring_fn, **kwargs)\n\n        # Call __call__ of ILRMABase's parent, i.e. __call__ of IterativeMethodBase\n        super(ILRMABase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        if self.demix_filter is None:\n            pass\n        else:\n            self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"GGDILRMA(\"\n        s += \"n_basis={n_basis}\"\n        s += \", beta={beta}\"\n        s += \", spatial_algorithm={spatial_algorithm}\"\n        s += \", source_algorithm={source_algorithm}\"\n        s += \", domain={domain}\"\n        s += \", partitioning={partitioning}\"\n        s += \", normalization={normalization}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        **kwargs,\n    ) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            kwargs:\n                Keyword arguments to set as attributes of ILRMA.\n   
     \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        super()._reset(flooring_fn=flooring_fn, **kwargs)\n\n        if self.spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n            self.demix_filter = None\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF parameters and demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_source_model(flooring_fn=flooring_fn)\n        self.update_spatial_model(flooring_fn=flooring_fn)\n\n        if self.normalization:\n            self.normalize(flooring_fn=flooring_fn)\n\n    def update_source_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases, activations, and latent variables.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if 
self.source_algorithm == \"MM\":\n            self.update_source_model_mm(flooring_fn=flooring_fn)\n        else:\n            raise ValueError(\n                \"{}-algorithm-based source model updates are not supported.\".format(\n                    self.source_algorithm\n                )\n            )\n\n    def update_source_model_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases, activations, and latent variables by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.partitioning:\n            self.update_latent_mm()\n\n        self.update_basis_mm(flooring_fn=flooring_fn)\n        self.update_activation_mm(flooring_fn=flooring_fn)\n\n    def update_latent_mm(self) -> None:\n        r\"\"\"Update latent variables in NMF by MM algorithm.\n\n        Update :math:`z_{nk}` as follows:\n\n        .. 
math::\n            z_{nk}\n            &\\leftarrow\\left[\n            \\frac{\\beta}{2}\n            \\frac{\\displaystyle\\sum_{i,j}\\frac{t_{ik}v_{kj}}\n            {(\\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\\frac{\\beta+p}{p}}}|y_{ijn}|^{\\beta}}\n            {\\displaystyle\\sum_{i,j}\\frac{t_{ik}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]^{\\frac{p}{\\beta+p}}z_{nk}, \\\\\n            z_{nk}\n            &\\leftarrow\\frac{z_{nk}}{\\displaystyle\\sum_{n'}z_{n'k}}.\n        \"\"\"\n        p = self.domain\n        beta = self.beta\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Yb = np.abs(Y) ** beta\n        p_bp = p / (beta + p)\n        bp_p = (beta + p) / p\n\n        Z = self.latent\n        T, V = self.basis, self.activation\n\n        TV = T[:, :, np.newaxis] * V[np.newaxis, :, :]\n        ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n        ZTVbpp = ZTV**bp_p\n        TV_RZTV = TV[np.newaxis, :, :, :] / ZTVbpp[:, :, np.newaxis, :]\n        num = (beta / 2) * np.sum(TV_RZTV * Yb[:, :, np.newaxis, :], axis=(1, 3))\n\n        TV_ZTV = TV[np.newaxis, :, :, :] / ZTV[:, :, np.newaxis, :]\n        denom = np.sum(TV_ZTV, axis=(1, 3))\n\n        Z = ((num / denom) ** p_bp) * Z\n        Z = Z / Z.sum(axis=0)\n\n        self.latent = Z\n\n    def update_basis_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If 
``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`t_{ikn}` as follows:\n\n        .. math::\n            t_{ik}\n            \\leftarrow\\left[\n            \\frac{\\beta}{2}\n            \\frac{\\displaystyle\\sum_{j,n}\\frac{z_{nk}v_{kj}}\n            {(\\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\\frac{\\beta+p}{p}}}|y_{ijn}|^{\\beta}}\n            {\\displaystyle\\sum_{j,n}\\frac{z_{nk}v_{kj}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]^{\\frac{p}{\\beta+p}}t_{ik},\n\n        if ``partitioning=True``. Otherwise\n\n        .. math::\n            t_{ikn}\n            \\leftarrow\\left[\n            \\frac{\\beta}{2}\n            \\frac{\\displaystyle\\sum_{j}\\frac{v_{kjn}}\n            {(\\sum_{k'}t_{ik'n}v_{k'jn})^{\\frac{\\beta+p}{p}}}|y_{ijn}|^{\\beta}}\n            {\\displaystyle\\sum_{j}\\frac{v_{kjn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\n            \\right]^{\\frac{p}{\\beta+p}}t_{ikn}.\n        \"\"\"\n        p = self.domain\n        beta = self.beta\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Yb = np.abs(Y) ** beta\n        p_bp = p / (beta + p)\n        bp_p = (beta + p) / p\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZV = Z[:, :, np.newaxis] * V[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            ZTVbpp = ZTV**bp_p\n            ZV_ZTVbpp = ZV[:, np.newaxis, :, :] / ZTVbpp[:, :, np.newaxis, :]\n            num = (beta / 2) * np.sum(ZV_ZTVbpp * Yb[:, :, np.newaxis, :], axis=(0, 3))\n\n            ZV_ZTV = ZV[:, np.newaxis, :, :] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZV_ZTV, axis=(0, 3))\n        else:\n            T, V = self.basis, 
self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            TVbpp = TV**bp_p\n            V_TVbpp = V[:, np.newaxis, :, :] / TVbpp[:, :, np.newaxis, :]\n            num = (beta / 2) * np.sum(V_TVbpp * Yb[:, :, np.newaxis, :], axis=3)\n\n            V_TV = V[:, np.newaxis, :, :] / TV[:, :, np.newaxis, :]\n            denom = np.sum(V_TV, axis=3)\n\n        T = ((num / denom) ** p_bp) * T\n        T = flooring_fn(T)\n\n        self.basis = T\n\n    def update_activation_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF activations by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`v_{kjn}` as follows:\n\n        .. math::\n            v_{kj}\n            \\leftarrow\\left[\n            \\frac{\\beta}{2}\n            \\frac{\\displaystyle\\sum_{i,n}\\frac{z_{nk}t_{ik}}\n            {(\\sum_{k'}z_{nk'}t_{ik'}v_{k'j})^{\\frac{\\beta+p}{p}}}|y_{ijn}|^{\\beta}}\n            {\\displaystyle\\sum_{i,n}\\frac{z_{nk}t_{ik}}{\\sum_{k'}z_{nk'}t_{ik'}v_{k'j}}}\n            \\right]^{\\frac{p}{\\beta+p}}v_{kj},\n\n        if ``partitioning=True``. Otherwise\n\n        .. 
math::\n            v_{kjn}\n            \\leftarrow\\left[\n            \\frac{\\beta}{2}\n            \\frac{\\displaystyle\\sum_{i}\\frac{t_{ikn}}\n            {(\\sum_{k'}t_{ik'n}v_{k'jn})^{\\frac{\\beta+p}{p}}}|y_{ijn}|^{\\beta}}\n            {\\displaystyle\\sum_{i}\\frac{t_{ikn}}{\\sum_{k'}t_{ik'n}v_{k'jn}}}\n            \\right]^{\\frac{p}{\\beta+p}}v_{kjn}.\n        \"\"\"\n        p = self.domain\n        beta = self.beta\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        Yb = np.abs(Y) ** beta\n        p_bp = p / (beta + p)\n        bp_p = (beta + p) / p\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZT = Z[:, np.newaxis, :] * T[np.newaxis, :, :]\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n\n            ZTVbpp = ZTV**bp_p\n            ZT_ZTVbpp = ZT[:, :, :, np.newaxis] / ZTVbpp[:, :, np.newaxis, :]\n            num = (beta / 2) * np.sum(ZT_ZTVbpp * Yb[:, :, np.newaxis, :], axis=(0, 1))\n\n            ZT_ZTV = ZT[:, :, :, np.newaxis] / ZTV[:, :, np.newaxis, :]\n            denom = np.sum(ZT_ZTV, axis=(0, 1))\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n\n            TVbpp = TV**bp_p\n            T_TVbpp = T[:, :, :, np.newaxis] / TVbpp[:, :, np.newaxis, :]\n            num = (beta / 2) * np.sum(T_TVbpp * Yb[:, :, np.newaxis, :], axis=1)\n\n            T_TV = T[:, :, :, np.newaxis] / TV[:, :, np.newaxis, :]\n            denom = np.sum(T_TV, axis=1)\n\n        V = ((num / denom) ** p_bp) * V\n        V = flooring_fn(V)\n\n        self.activation = V\n\n    def update_spatial_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> 
None:\n        r\"\"\"Update demixing filters once.\n\n        - If ``spatial_algorithm`` is ``IP`` or ``IP1``, ``update_spatial_model_ip1`` is called.\n        - If ``spatial_algorithm`` is ``ISS`` or ``ISS1``, ``update_spatial_model_iss1`` is called.\n        - If ``spatial_algorithm`` is ``IP2``, ``update_spatial_model_ip2`` is called.\n        - If ``spatial_algorithm`` is ``ISS2``, ``update_spatial_model_iss2`` is called.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.spatial_algorithm in [\"IP\", \"IP1\"]:\n            self.update_spatial_model_ip1(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"IP2\"]:\n            self.update_spatial_model_ip2(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"ISS\", \"ISS1\"]:\n            self.update_spatial_model_iss1(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"ISS2\"]:\n            self.update_spatial_model_iss2(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.spatial_algorithm))\n\n    def update_spatial_model_ip1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return 
the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Demixing filters are updated sequentially for :math:`n=1,\\ldots,N` as follows:\n\n        .. math::\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\left(\\boldsymbol{W}_{i}\\boldsymbol{U}_{in}\\right)^{-1} \\\n            \\boldsymbol{e}_{n}, \\\\\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\frac{\\boldsymbol{w}_{in}}\n            {\\sqrt{\\boldsymbol{w}_{in}^{\\mathsf{H}}\\boldsymbol{U}_{in}\\boldsymbol{w}_{in}}},\n\n        where\n\n        .. math::\n            \\boldsymbol{U}_{in}\n            \\leftarrow\\frac{1}{J}\\sum_{j}\n            \\frac{\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}}{\\tilde{r}_{ijn}}.\n\n        :math:`\\tilde{r}_{ijn}` is computed as\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{2|y_{ijn}|^{2-\\beta}}{\\beta}\n            \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{\\beta}{p}},\n\n        if ``partitioning=True``. Otherwise,\n\n        .. 
math::\n            \\tilde{r}_{ijn}\n            = \\frac{2|y_{ijn}|^{2-\\beta}}{\\beta}\n            \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{\\beta}{p}}.\n        \"\"\"\n        p = self.domain\n        beta = self.beta\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        Y2b = np.abs(Y) ** (2 - beta)\n        Y2b = flooring_fn(Y2b)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            ZTVbp = ZTV ** (beta / p)\n            R_tilde = (2 / beta) * Y2b * ZTVbp\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n            TVbp = TV ** (beta / p)\n            R_tilde = (2 / beta) * Y2b * TVbp\n\n        varphi = 1 / R_tilde\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        varphi = varphi.transpose(1, 0, 2)\n        varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n        U = np.mean(varphi_XX, axis=-1)\n\n        self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn)\n\n    def update_spatial_model_ip2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using pairwise iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` 
is used.\n                Default: ``self``.\n\n        For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\\neq n_{2}`),\n        compute weighted covariance matrix as follows:\n\n        .. math::\n            \\boldsymbol{U}_{in}\n            = \\frac{1}{J}\\sum_{j}\n            \\frac{1}{\\tilde{r}_{ijn}}\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}},\n\n        :math:`\\tilde{r}_{ijn}` is computed by\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{2|y_{ijn}|^{2-\\beta}}{\\beta}\n            \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{\\beta}{p}},\n\n        if ``partitioning=True``. \\\n        Otherwise,\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{2|y_{ijn}|^{2-\\beta}}{\\beta}\n            \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{\\beta}{p}}.\n\n        Using :math:`\\boldsymbol{U}_{in_{1}}` and\n        :math:`\\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors.\n\n        .. math::\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i}\n            = \\lambda_{i}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i},\n\n        where\n\n        .. 
math::\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{1}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ), \\\\\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{2}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ).\n\n        After that, we standardize two eigenvectors :math:`\\boldsymbol{h}_{in_{1}}`\n        and :math:`\\boldsymbol{h}_{in_{2}}`.\n\n        .. math::\n            \\boldsymbol{h}_{in_{1}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{1}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{1}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{1}}}}, \\\\\n            \\boldsymbol{h}_{in_{2}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{2}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{2}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{2}}}}.\n\n        Then, update :math:`\\boldsymbol{w}_{in_{1}}` and :math:`\\boldsymbol{w}_{in_{2}}`\n        simultaneously.\n\n        .. 
math::\n            \\boldsymbol{w}_{in_{1}}\n            &\\leftarrow \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{1}} \\\\\n            \\boldsymbol{w}_{in_{2}}\n            &\\leftarrow \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{2}}\n\n        At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{2}`\n        for :math:`n_{1}\\neq n_{2}`.\n        \"\"\"\n        p = self.domain\n        beta = self.beta\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        Y2b = np.abs(Y) ** (2 - beta)\n        Y2b = flooring_fn(Y2b)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            ZTVbp = ZTV ** (beta / p)\n            R_tilde = (2 / beta) * Y2b * ZTVbp\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n            TVbp = TV ** (beta / p)\n            R_tilde = (2 / beta) * Y2b * TVbp\n\n        varphi = 1 / R_tilde\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        varphi = varphi.transpose(1, 0, 2)\n        varphi_XX = varphi[:, :, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n        U = np.mean(varphi_XX, axis=-1)\n\n        self.demix_filter = update_by_ip2(\n            W, U, flooring_fn=flooring_fn, pair_selector=self.pair_selector\n        )\n\n    def update_spatial_model_iss1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using iterative source steering.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical 
stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`y_{ijn}` as follows:\n\n        .. math::\n            \\boldsymbol{y}_{ij}\n            & \\leftarrow\\boldsymbol{y}_{ij} - \\boldsymbol{d}_{in}y_{ijn} \\\\\n            d_{inn'}\n            &= \\begin{cases}\n                \\dfrac{\\displaystyle\\sum_{j}\\dfrac{1}{\\tilde{r}_{ijn}}\n                y_{ijn'}y_{ijn}^{*}}{\\displaystyle\\sum_{j}\\dfrac{1}\n                {\\tilde{r}_{ijn}}|y_{ijn}|^{2}}\n                & (n'\\neq n) \\\\\n                1 - \\dfrac{1}{\\sqrt{\\displaystyle\\dfrac{1}{J}\\sum_{j}\\dfrac{1}\n                {\\tilde{r}_{ijn}}|y_{ijn}|^{2}}} & (n'=n)\n            \\end{cases},\n\n        where :math:`\\tilde{r}_{ijn}` is computed as\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{2|y_{ijn}|^{2-\\beta}}{\\beta}\n            \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{\\beta}{p}},\n\n        if ``partitioning=True``. Otherwise,\n\n        .. 
math::\n            \\tilde{r}_{ijn}\n            = \\frac{2|y_{ijn}|^{2-\\beta}}{\\beta}\n            \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{\\beta}{p}}.\n        \"\"\"\n        p = self.domain\n        beta = self.beta\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n        Y2b = np.abs(Y) ** (2 - beta)\n        Y2b = flooring_fn(Y2b)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            ZTVbp = ZTV ** (beta / p)\n            R_bar = Y2b * ZTVbp\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n            TVbp = TV ** (beta / p)\n            R_bar = Y2b * TVbp\n\n        varphi = beta / (2 * R_bar)\n\n        self.output = update_by_iss1(Y, varphi, flooring_fn=flooring_fn)\n\n    def update_spatial_model_iss2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using pairwise iterative source steering.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Compute :math:`\\boldsymbol{G}_{in}^{(n_{1},n_{2})}`\n        and :math:`\\boldsymbol{f}_{in}^{(n_{1},n_{2})}` for :math:`n_{1}\\neq n_{2}`:\n\n        .. 
math::\n            \\begin{array}{rclc}\n                \\boldsymbol{G}_{in}^{(n_{1},n_{2})}\n                &=& {\\displaystyle\\frac{1}{J}\\sum_{j}}\\dfrac{1}{\\tilde{r}_{ijn}}\n                \\boldsymbol{y}_{ij}^{(n_{1},n_{2})}{\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}}^{\\mathsf{H}}\n                &(n=1,\\ldots,N), \\\\\n                \\boldsymbol{f}_{in}^{(n_{1},n_{2})}\n                &=& {\\displaystyle\\frac{1}{J}\\sum_{j}}\n                \\dfrac{1}{\\tilde{r}_{ijn}}y_{ijn}^{*}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}\n                &(n\\neq n_{1},n_{2}),\n            \\end{array}\n\n        where\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{2}{\\beta}|y_{ijn}|^{2-\\beta}\n            \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{\\beta}{p}},\n\n        if ``partitioning=True``.\n        Otherwise,\n\n        .. math::\n            \\tilde{r}_{ijn}\n            = \\frac{2}{\\beta}|y_{ijn}|^{2-\\beta}\n            \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{\\beta}{p}}.\n\n        Using :math:`\\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and\n        :math:`\\boldsymbol{f}_{in}^{(n_{1},n_{2})}`, we compute\n\n        .. math::\n            \\begin{array}{rclc}\n                \\boldsymbol{p}_{in}\n                &=& \\dfrac{\\boldsymbol{h}_{in}}\n                {\\sqrt{\\boldsymbol{h}_{in}^{\\mathsf{H}}\\boldsymbol{G}_{in}^{(n_{1},n_{2})}\n                \\boldsymbol{h}_{in}}} & (n=n_{1},n_{2}), \\\\\n                \\boldsymbol{q}_{in}\n                &=& -{\\boldsymbol{G}_{in}^{(n_{1},n_{2})}}^{-1}\\boldsymbol{f}_{in}^{(n_{1},n_{2})}\n                & (n\\neq n_{1},n_{2}),\n            \\end{array}\n\n        where :math:`\\boldsymbol{h}_{in}` (:math:`n=n_{1},n_{2}`) is\n        a generalized eigenvector obtained from\n\n        .. 
math::\n            \\boldsymbol{G}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{i}\n            = \\lambda_{i}\\boldsymbol{G}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{i}.\n\n        Separated signal :math:`y_{ijn}` is updated as follows:\n\n        .. math::\n            y_{ijn}\n            &\\leftarrow\\begin{cases}\n            &\\boldsymbol{p}_{in}^{\\mathsf{H}}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}\n            & (n=n_{1},n_{2}) \\\\\n            &\\boldsymbol{q}_{in}^{\\mathsf{H}}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})} + y_{ijn}\n            & (n\\neq n_{1},n_{2})\n            \\end{cases}.\n        \"\"\"\n        p = self.domain\n        beta = self.beta\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n        Y2b = np.abs(Y) ** (2 - beta)\n        Y2b = flooring_fn(Y2b)\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            ZTVbp = ZTV ** (beta / p)\n            R_tilde = (2 / beta) * Y2b * ZTVbp\n        else:\n            T, V = self.basis, self.activation\n\n            TV = self.reconstruct_nmf(T, V)\n            TVbp = TV ** (beta / p)\n            R_tilde = (2 / beta) * Y2b * TVbp\n\n        varphi = 1 / R_tilde\n\n        self.output = update_by_iss2(\n            Y, varphi, flooring_fn=flooring_fn, pair_selector=self.pair_selector\n        )\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. math::\n            \\mathcal{L}\n            = \\frac{1}{J}\\sum_{i,j,n}\n            \\left\\{\\left(\\frac{|y_{ijn}|^{2}}{r_{ijn}}\\right)^{\\frac{\\beta}{2}}\n            + \\log r_{ijn}\\right\\}\n            - 2\\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|,\n\n        where\n\n        .. 
math::\n            r_{ijn}\n            = \\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)^{\\frac{2}{p}},\n\n        if ``partitioning=True``. Otherwise\n\n        .. math::\n            r_{ijn}\n            = \\left(\\sum_{k}t_{ikn}v_{kjn}\\right)^{\\frac{2}{p}}.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        beta = self.beta\n        p = self.domain\n\n        if self.demix_filter is None:\n            X, Y = self.input, self.output\n            Yb = np.abs(Y) ** beta\n            X, Y = X.transpose(1, 0, 2), Y.transpose(1, 0, 2)\n            X_Hermite = X.transpose(0, 2, 1).conj()\n            XX_Hermite = X @ X_Hermite\n            W = Y @ X_Hermite @ np.linalg.inv(XX_Hermite)\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n            Yb = np.abs(Y) ** beta\n\n        if self.partitioning:\n            Z = self.latent\n            T, V = self.basis, self.activation\n            ZTV = self.reconstruct_nmf(T, V, latent=Z)\n            R = ZTV ** (beta / p)\n            loss = Yb / R + (2 / p) * np.log(ZTV)\n        else:\n            T, V = self.basis, self.activation\n            TV = self.reconstruct_nmf(T, V)\n            R = TV ** (beta / p)\n            loss = Yb / R + (2 / p) * np.log(TV)\n\n        logdet = self.compute_logdet(W)  # (n_bins,)\n\n        loss = np.sum(loss.mean(axis=-1), axis=0) - 2 * logdet\n        loss = loss.sum(axis=0).item()\n\n        return loss\n\n    def apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated spectrograms.\"\"\"\n        if self.demix_filter is None:\n            assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n            X, Y = self.input, self.output\n            Y_scaled = projection_back(Y, reference=X, reference_id=self.reference_id)\n\n            self.output = Y_scaled\n        else:\n            super().apply_projection_back()\n\n    def 
apply_minimal_distortion_principle(self) -> None:\n        r\"\"\"Apply minimal distortion principle to estimated spectrograms.\"\"\"\n        if self.demix_filter is None:\n            X, Y = self.input, self.output\n            Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id)\n\n            self.output = Y_scaled\n        else:\n            super().apply_minimal_distortion_principle()\n"
  },
  {
    "path": "ssspy/bss/ipsdta.py",
    "content": "import functools\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport numpy as np\n\nfrom ..algorithm import (\n    MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS,\n    PROJECTION_BACK_KEYWORDS,\n    minimal_distortion_principle,\n    projection_back,\n)\nfrom ..linalg.mean import gmeanmh\nfrom ..linalg.quadratic import quadratic\nfrom ..linalg.sqrtm import invsqrtmh, sqrtmh\nfrom ..special.flooring import identity, max_flooring\nfrom ..special.psd import to_psd\nfrom ..utils.flooring import choose_flooring_fn\nfrom ._update_spatial_model import update_by_block_decomposition_vcd\nfrom .base import IterativeMethodBase\n\nspatial_algorithms = [\"FPI\", \"VCD\"]\nsource_algorithms = [\"EM\", \"MM\"]\nEPS = 1e-10\n\n\nclass IPSDTABase(IterativeMethodBase):\n    r\"\"\"Base class of independent positive semidefinite tensor analysis (IPSDTA).\n\n    Args:\n        n_basis (int):\n            Number of PSDTF bases.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. 
You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n        rng (numpy.random.Generator, optional):\n            Random number generator. This is mainly used to randomly initialize PSDTF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"IPSDTABase\"], None], List[Callable[[\"IPSDTABase\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        self.source_normalization: Optional[Union[bool, str]]\n\n        super().__init__(callbacks=callbacks, record_loss=record_loss)\n\n        self.n_basis = n_basis\n\n        if flooring_fn is None:\n            self.flooring_fn = identity\n        else:\n            self.flooring_fn = flooring_fn\n\n        self.input = None\n        self.scale_restoration = scale_restoration\n\n        if reference_id is None and scale_restoration:\n            raise ValueError(\"Specify 'reference_id' if scale_restoration=True.\")\n        else:\n            self.reference_id = reference_id\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        self.rng = rng\n\n    def __call__(self, input: np.ndarray, n_iter: int = 100, **kwargs) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n           
 input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        super().__call__(n_iter=n_iter)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"IPSDTA(\"\n        s += \"n_basis={n_basis}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        **kwargs,\n    ) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            kwargs:\n                Keyword arguments to set as attributes of IPSDTA.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, 
method=self)\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_channels, n_bins, n_frames = X.shape\n        n_sources = n_channels  # n_channels == n_sources\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_bins, self.n_frames = n_bins, n_frames\n\n        if not hasattr(self, \"demix_filter\"):\n            W = np.eye(n_sources, n_channels, dtype=np.complex128)\n            W = np.tile(W, reps=(n_bins, 1, 1))\n        else:\n            if self.demix_filter is None:\n                W = None\n            else:\n                # To avoid overwriting ``demix_filter`` given by keyword arguments.\n                W = self.demix_filter.copy()\n\n        self.demix_filter = W\n        self.output = self.separate(X, demix_filter=W)\n\n        self._init_psdtf(flooring_fn=flooring_fn, rng=self.rng)\n\n    def _init_psdtf(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        r\"\"\"Initialize PSDTF.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            rng (numpy.random.Generator, optional):\n                Random number generator. 
If ``None`` is given,\n                ``np.random.default_rng()`` is used.\n                Default: ``None``.\n        \"\"\"\n        n_basis = self.n_basis\n        n_sources = self.n_sources\n        n_bins, n_frames = self.n_bins, self.n_frames\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        if not hasattr(self, \"basis\"):\n            # should be positive semi-definite\n            eye = np.eye(n_bins, dtype=np.complex128)\n            rand = rng.random((n_sources, n_basis, n_bins))\n            T = rand[..., np.newaxis] * eye\n        else:\n            # To avoid overwriting.\n            T = self.basis.copy()\n\n        if not hasattr(self, \"activation\"):\n            V = rng.random((n_sources, n_basis, n_frames))\n            V = flooring_fn(V)\n        else:\n            # To avoid overwriting.\n            V = self.activation.copy()\n\n        self.basis, self.activation = T, V\n\n        if self.source_normalization:\n            self.normalize_psdtf()\n\n    def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Separate ``input`` using ``demix_filter``.\n\n        .. 
math::\n            \\boldsymbol{y}_{ij}\n            = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij}\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            demix_filter (numpy.ndarray):\n                The demixing filters to separate ``input``.\n                The shape is (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        X, W = input, demix_filter\n        Y = W @ X.transpose(1, 0, 2)\n        output = Y.transpose(1, 0, 2)\n\n        return output\n\n    def reconstruct_psdtf(\n        self,\n        basis: np.ndarray,\n        activation: np.ndarray,\n        axis1: int = -2,\n        axis2: int = -1,\n    ) -> np.ndarray:\n        r\"\"\"Reconstruct PSDTF.\n\n        Args:\n            basis (numpy.ndarray):\n                Basis matrix.\n                The shape is (n_sources, n_basis, n_bins, n_bins) if ``axis1=-2`` and ``axis2=-1``.\n                Otherwise, (n_sources, n_bins, n_bins, n_basis).\n            activation (numpy.ndarray):\n                Activation matrix.\n                The shape is (n_sources, n_basis, n_frames).\n            axis1 (int):\n                First axis of covariance matrix. Default: ``-2``.\n            axis2 (int):\n                Second axis of covariance matrix. 
Default: ``-1``.\n\n        Returns:\n            numpy.ndarray of reconstructed PSDTF.\n            The shape is (n_sources, n_frames, n_bins, n_bins).\n        \"\"\"\n        T, V = basis, activation\n        n_dims = T.ndim\n\n        axis1 = n_dims + axis1 if axis1 < 0 else axis1\n        axis2 = n_dims + axis2 if axis2 < 0 else axis2\n\n        assert (axis1 == 1 and axis2 == 2) or (axis1 == 2 and axis2 == 3)\n\n        if axis1 == 1 and axis2 == 2:\n            T = T.transpose(0, 3, 1, 2)\n\n        R = np.sum(T[:, :, np.newaxis, :, :] * V[:, :, :, np.newaxis, np.newaxis], axis=1)\n        R = to_psd(R, axis1=2, axis2=3)\n\n        return R\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once.\"\"\"\n        raise NotImplementedError(\"Implement 'update_once' method.\")\n\n    def normalize_psdtf(self) -> None:\n        r\"\"\"Normalize PSDTF parameters.\"\"\"\n        source_normalization = self.source_normalization\n        T, V = self.basis, self.activation\n\n        assert source_normalization, \"Set source_normalization.\"\n\n        trace = np.trace(T, axis1=-2, axis2=-1).real\n        T = T / trace[:, :, np.newaxis, np.newaxis]\n        V = V * trace[:, :, np.newaxis]\n\n        self.basis, self.activation = T, V\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        raise NotImplementedError(\"Implement 'compute_loss' method.\")\n\n    def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Compute log-determinant of demixing filter\n\n        Args:\n            demix_filter (numpy.ndarray):\n                Demixing filters with shape of (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of computed log-determinant values.\n        \"\"\"\n        _, logdet = np.linalg.slogdet(demix_filter)  # (n_bins,)\n\n        return logdet\n\n    def restore_scale(self) 
-> None:\n        r\"\"\"Restore scale ambiguity.\n\n        If ``self.scale_restoration=\"projection_back\"``, we use projection back technique.\n        \"\"\"\n        scale_restoration = self.scale_restoration\n\n        assert scale_restoration, \"Set self.scale_restoration=True.\"\n\n        if type(scale_restoration) is bool:\n            scale_restoration = PROJECTION_BACK_KEYWORDS[0]\n\n        if scale_restoration in PROJECTION_BACK_KEYWORDS:\n            self.apply_projection_back()\n        elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS:\n            self.apply_minimal_distortion_principle()\n        else:\n            raise ValueError(\"{} is not supported for scale restoration.\".format(scale_restoration))\n\n    def apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        W_scaled = projection_back(W, reference_id=self.reference_id)\n        Y_scaled = self.separate(X, demix_filter=W_scaled)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n    def apply_minimal_distortion_principle(self) -> None:\n        r\"\"\"Apply minimal distortion principle to estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n        Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id)\n        X = X.transpose(1, 0, 2)\n        Y = Y_scaled.transpose(1, 0, 2)\n        X_Hermite = X.transpose(0, 2, 1).conj()\n        W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n\nclass BlockDecompositionIPSDTABase(IPSDTABase):\n    r\"\"\"Base class of independent positive semidefinite tensor analysis (IPSDTA) 
\\\n    using block decomposition of bases.\n\n    Args:\n        n_basis (int):\n            Number of PSDTF bases.\n        n_blocks (int):\n            Number of sub-blocks.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n        rng (numpy.random.Generator, optional):\n            Random number generator. 
This is mainly used to randomly initialize PSDTF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        n_blocks: int,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[\n                Callable[[\"BlockDecompositionIPSDTABase\"], None],\n                List[Callable[[\"BlockDecompositionIPSDTABase\"], None]],\n            ]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(\n            n_basis=n_basis,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n            rng=rng,\n        )\n\n        self.n_blocks = n_blocks\n\n    def __repr__(self) -> str:\n        s = \"IPSDTA(\"\n        s += \"n_basis={n_basis}\"\n        s += \", n_blocks={n_blocks}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        **kwargs,\n    ) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly 
set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            kwargs:\n                Keyword arguments to set as attributes of IPSDTA.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_channels, n_bins, n_frames = X.shape\n        n_sources = n_channels  # n_channels == n_sources\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_bins, self.n_frames = n_bins, n_frames\n\n        if not hasattr(self, \"demix_filter\"):\n            W = np.eye(n_sources, n_channels, dtype=np.complex128)\n            W = np.tile(W, reps=(n_bins, 1, 1))\n        else:\n            if self.demix_filter is None:\n                W = None\n            else:\n                # To avoid overwriting ``demix_filter`` given by keyword arguments.\n                W = self.demix_filter.copy()\n\n        self.demix_filter = W\n        self.output = self.separate(X, demix_filter=W)\n\n        self._init_block_decomposition_psdtf(flooring_fn=flooring_fn, rng=self.rng)\n\n    def _init_block_decomposition_psdtf(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        r\"\"\"Initialize PSDTF using block decomposition of bases.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n          
      If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            rng (numpy.random.Generator, optional):\n                Random number generator. If ``None`` is given,\n                ``np.random.default_rng()`` is used.\n                Default: ``None``.\n        \"\"\"\n        n_basis = self.n_basis\n        n_sources = self.n_sources\n        n_bins, n_frames = self.n_bins, self.n_frames\n        n_blocks = self.n_blocks\n        n_remains = self.n_remains\n\n        n_neighbors = n_bins // n_blocks\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        if not hasattr(self, \"basis\"):\n            # should be positive semi-definite\n            eye = np.eye(n_neighbors, dtype=np.complex128)\n            rand = rng.random((n_sources, n_basis, n_blocks - n_remains, n_neighbors))\n            T = rand[..., np.newaxis] * eye\n\n            if n_remains > 0:\n                eye = np.eye(n_neighbors + 1, dtype=np.complex128)\n                rand = rng.random((n_sources, n_basis, n_remains, n_neighbors + 1))\n                T_high = rand[..., np.newaxis] * eye\n\n                T = T, T_high\n        else:\n            # To avoid overwriting.\n            if n_remains > 0:\n                T_low, T_high = self.basis\n                T = T_low.copy(), T_high.copy()\n            else:\n                T = self.basis.copy()\n\n        if not hasattr(self, \"activation\"):\n            V = rng.random((n_sources, n_basis, n_frames))\n            V = flooring_fn(V)\n        else:\n            # To avoid overwriting.\n            V = self.activation.copy()\n\n        self.basis, self.activation = T, V\n\n        if self.source_normalization:\n            self.normalize_block_decomposition_psdtf()\n\n    @property\n    def n_remains(self) -> int:\n        if not hasattr(self, \"n_bins\"):\n            raise 
AttributeError(\"Since n_bins is not defined, n_remains cannot be computed.\")\n\n        return self.n_bins % self.n_blocks\n\n    def reconstruct_block_decomposition_psdtf(\n        self, basis: np.ndarray, activation: np.ndarray, axis1: int = -2, axis2: int = -1\n    ) -> np.ndarray:\n        r\"\"\"Reconstruct PSDTF using block decomposition of bases.\n\n        Args:\n            basis (numpy.ndarray or tuple of numpy.ndarray):\n                Basis matrix.\n                The shape is (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors)\n                if ``axis1=-2`` and ``axis2=-1`` (default).\n                If ``axis1=-3`` and ``axis2=-2``,\n                (n_sources, n_blocks, n_neighbors, n_neighbors, n_basis).\n            activation (numpy.ndarray):\n                Activation matrix.\n                The shape is (n_sources, n_basis, n_frames).\n            axis1 (int):\n                First axis of covariance matrix. Default: ``-2``.\n            axis2 (int):\n                Second axis of covariance matrix. Default: ``-1``.\n\n        Returns:\n            numpy.ndarray of reconstructed PSDTF.\n            The shape is (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).\n            If ``basis`` is a tuple ``(T_low, T_high)``, a tuple ``(R_low, R_high)``\n            of the corresponding reconstructions is returned.\n        \"\"\"\n\n        def _reconstruct(\n            basis: np.ndarray, activation: np.ndarray, axis1: int = -2, axis2: int = -1\n        ) -> np.ndarray:\n            r\"\"\"Reconstruct PSDTF using block decomposition of bases.\n\n            Args:\n                basis (numpy.ndarray):\n                    Basis matrix.\n                    The shape is (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors)\n                    if ``axis1=-2`` and ``axis2=-1`` (default).\n                    If ``axis1=-3`` and ``axis2=-2``,\n                    (n_sources, n_blocks, n_neighbors, n_neighbors, n_basis).\n                activation (numpy.ndarray):\n                    Activation matrix.\n                    The shape is (n_sources, n_basis, n_frames).\n                axis1 (int):\n                    First axis of covariance matrix. 
Default: ``-2``.\n                axis2 (int):\n                    Second axis of covariance matrix. Default: ``-1``.\n\n            Returns:\n                numpy.ndarray of reconstructed PSDTF.\n                The shape is (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).\n            \"\"\"\n            na = np.newaxis\n            T, V = basis, activation\n            n_dims = T.ndim\n\n            axis1 = n_dims + axis1 if axis1 < 0 else axis1\n            axis2 = n_dims + axis2 if axis2 < 0 else axis2\n\n            assert (axis1 == 2 and axis2 == 3) or (axis1 == 3 and axis2 == 4)\n\n            if axis1 == 2 and axis2 == 3:\n                T = T.transpose(0, 4, 1, 2, 3)\n\n            R = np.sum(\n                T[:, :, na, :, :, :] * V[:, :, :, na, na, na],\n                axis=1,\n            )\n            R = to_psd(R, axis1=3, axis2=4)\n\n            return R\n\n        if type(basis) is tuple:\n            assert self.n_remains > 0, \"n_remains is expected to be positive.\"\n\n            T_low, T_high = basis\n            V = activation\n            R_low = _reconstruct(T_low, V, axis1=axis1, axis2=axis2)\n            R_high = _reconstruct(T_high, V, axis1=axis1, axis2=axis2)\n            R = R_low, R_high\n        else:\n            T = basis\n            V = activation\n            R = _reconstruct(T, V, axis1=axis1, axis2=axis2)\n\n        return R\n\n    def normalize_block_decomposition_psdtf(self, axis1: int = -2, axis2: int = -1) -> None:\n        r\"\"\"Normalize PSDTF parameters using block decomposition of bases.\n\n        Args:\n            axis1 (int):\n                First axis of covariance matrix. Default: ``-2``.\n            axis2 (int):\n                Second axis of covariance matrix. 
Default: ``-1``.\n        \"\"\"\n        source_normalization = self.source_normalization\n        n_remains = self.n_remains\n        na = np.newaxis\n        T, V = self.basis, self.activation\n\n        assert source_normalization, \"Set source_normalization.\"\n\n        if n_remains > 0:\n            T_low, T_high = T\n            trace_low = np.trace(T_low, axis1=axis1, axis2=axis2).real\n            trace_high = np.trace(T_high, axis1=axis1, axis2=axis2).real\n            trace = np.sum(trace_low, axis=-1) + np.sum(trace_high, axis=-1)\n            T_low = T_low / trace[:, :, na, na, na]\n            T_high = T_high / trace[:, :, na, na, na]\n            T = T_low, T_high\n        else:\n            trace = np.trace(T, axis1=axis1, axis2=axis2).real\n            trace = np.sum(trace, axis=-1)\n            T = T / trace[:, :, na, na, na]\n\n        V = V * trace[:, :, na]\n\n        self.basis, self.activation = T, V\n\n\nclass GaussIPSDTA(BlockDecompositionIPSDTABase):\n    r\"\"\"Independent positive semidefinite tensor analysis (IPSDTA) \\\n    on Gaussian distribution.\n\n    Args:\n        n_basis (int):\n            Number of PSDTF bases.\n        n_blocks (int):\n            Number of sub-blocks.\n        source_algorithm (str):\n            Algorithm for PSDTF updates.\n            Choose ``EM``, or ``MM``. Default: ``MM``.\n        spatial_algorithm (str):\n            Algorithm for demixing filter updates.\n            Choose ``FPI``, or ``VCD``. Default: ``VCD``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. 
Each function is called before separation and at each iteration.\n            Default: ``None``.\n        source_normalization (bool):\n            If ``source_normalization=True``, normalize PSDTF parameters.\n            Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n        rng (numpy.random.Generator, optioinal):\n            Random number generator. This is mainly used to randomly initialize PSDTF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        n_blocks: int,\n        source_algorithm: str = \"MM\",\n        spatial_algorithm: str = \"VCD\",\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[\n                Callable[[\"GaussIPSDTA\"], None],\n                List[Callable[[\"GaussIPSDTA\"], None]],\n            ]\n        ] = None,\n        source_normalization: Optional[Union[bool, str]] = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(\n            n_basis,\n            n_blocks,\n            flooring_fn,\n            callbacks,\n            scale_restoration,\n            record_loss,\n   
         reference_id,\n            rng,\n        )\n\n        assert source_algorithm in source_algorithms, \"Not support {}.\".format(source_algorithms)\n        assert spatial_algorithm in spatial_algorithms, \"Not support {}.\".format(spatial_algorithms)\n\n        self.source_algorithm = source_algorithm\n        self.spatial_algorithm = spatial_algorithm\n        self.source_normalization = source_normalization\n\n    def __repr__(self) -> str:\n        s = \"GaussIPSDTA(\"\n        s += \"n_basis={n_basis}\"\n        s += \", n_blocks={n_blocks}\"\n        s += \", source_algorithm={source_algorithm}\"\n        s += \", spatial_algorithm={spatial_algorithm}\"\n        s += \", source_normalization={source_normalization}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of IPSDTA.\n        \"\"\"\n        super()._reset(**kwargs)\n\n        if self.spatial_algorithm == \"FPI\":\n            if not hasattr(self, \"fixed_point\"):\n                n_sources = self.n_sources\n                n_bins = self.n_bins\n\n                self.fixed_point = np.ones((n_sources, n_bins), dtype=np.complex128)\n            else:\n                self.fixed_point = self.fixed_point.copy()\n\n            raise NotImplementedError(\"IPSDTA with fixed-point iteration is not supported.\")\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update PSDTF parameters and demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n         
       A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_source_model(flooring_fn=flooring_fn)\n        self.update_spatial_model(flooring_fn=flooring_fn)\n\n    def update_source_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update PSDTF basis matrices and activations.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.source_algorithm == \"MM\":\n            self.update_source_model_mm(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.source_algorithm))\n\n        if self.source_normalization:\n            self.normalize_block_decomposition_psdtf()\n\n    def update_source_model_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update PSDTF basis matrices and activations by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring 
function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_basis_mm(flooring_fn=flooring_fn)\n        self.update_activation_mm()\n\n    def update_basis_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update PSDTF basis matrices by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        n_sources = self.n_sources\n        n_frames = self.n_frames\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        def _update_basis_mm(\n            basis: np.ndarray, activation: np.ndarray, separated: np.ndarray = None\n        ) -> np.ndarray:\n            r\"\"\"\n            Args:\n                basis: (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors)\n                activation: (n_sources, n_basis, n_frames)\n                separated: (n_sources, n_blocks, n_neighbors, n_frames)\n\n            Returns:\n                numpy.ndarray of updated basis matrix.\n            \"\"\"\n            T, V = basis, activation\n            Y = separated\n            na = np.newaxis\n\n            R = 
self.reconstruct_block_decomposition_psdtf(T, V)\n            R_inverse = np.linalg.inv(R)\n            Y = Y.transpose(0, 3, 1, 2)\n            YY_Hermite = Y[:, :, :, :, na] @ Y[:, :, :, na, :].conj()\n            RYYR = R_inverse @ YY_Hermite @ R_inverse\n\n            P = np.mean(\n                V[:, :, :, na, na, na] * R_inverse[:, na, :, :, :, :],\n                axis=2,\n            )\n            Q = np.mean(\n                V[:, :, :, na, na, na] * RYYR[:, na, :, :, :, :],\n                axis=2,\n            )\n            TQT = T @ Q @ T\n\n            P = to_psd(P, flooring_fn=flooring_fn)\n            TQT = to_psd(TQT, flooring_fn=flooring_fn)\n\n            # geometric mean of P^(-1) and TQT\n            T = gmeanmh(P, TQT, type=2)\n            T = to_psd(T, flooring_fn=flooring_fn)\n\n            return T\n\n        n_bins = self.n_bins\n        n_blocks = self.n_blocks\n        n_remains = self.n_remains\n        n_neighbors = n_bins // n_blocks\n\n        X, W = self.input, self.demix_filter\n        T, V = self.basis, self.activation\n        Y = self.separate(X, demix_filter=W)\n\n        if n_remains > 0:\n            T_low, T_high = T\n            Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1)\n            Y_low = Y_low.reshape(n_sources, n_blocks - n_remains, n_neighbors, n_frames)\n            Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames)\n\n            T_low = _update_basis_mm(T_low, V, separated=Y_low)\n            T_high = _update_basis_mm(T_high, V, separated=Y_high)\n            T = T_low, T_high\n        else:\n            Y = Y.reshape(n_sources, n_blocks, n_neighbors, n_frames)\n            T = _update_basis_mm(T, V, separated=Y)\n\n        self.basis = T\n\n    def update_activation_mm(self) -> None:\n        r\"\"\"Update PSDTF activations by MM algorithm.\"\"\"\n\n        def _compute_traces(\n            basis: np.ndarray, activation: np.ndarray, separated: np.ndarray = 
None\n        ) -> Tuple[np.ndarray, np.ndarray]:\n            r\"\"\"\n            Args:\n                basis: (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors)\n                activation: (n_sources, n_basis, n_frames)\n                separated: (n_sources, n_blocks, n_neighbors, n_frames)\n\n            Returns:\n                Tuple of numerator and denominator.\n                Type of each item is ``numpy.ndarray``.\n            \"\"\"\n            T, V = basis, activation\n            Y = separated\n            na = np.newaxis\n\n            R = self.reconstruct_block_decomposition_psdtf(T, V)\n            R_inverse = np.linalg.inv(R)\n            Y = Y.transpose(0, 3, 1, 2)\n            YY_Hermite = Y[:, :, :, :, na] @ Y[:, :, :, na, :].conj()\n            RYYR = R_inverse @ YY_Hermite @ R_inverse\n\n            num = np.trace(RYYR[:, na, :] @ T[:, :, na], axis1=-2, axis2=-1)\n            denom = np.trace(R_inverse[:, na, :] @ T[:, :, na], axis1=-2, axis2=-1)\n            num = np.real(num).sum(axis=-1)\n            denom = np.real(denom).sum(axis=-1)\n\n            return num, denom\n\n        n_sources = self.n_sources\n        n_bins, n_frames = self.n_bins, self.n_frames\n        n_blocks = self.n_blocks\n        n_remains = self.n_remains\n        n_neighbors = n_bins // n_blocks\n\n        X, W = self.input, self.demix_filter\n        T, V = self.basis, self.activation\n        Y = self.separate(X, demix_filter=W)\n\n        if n_remains > 0:\n            T_low, T_high = T\n            Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1)\n            Y_low = Y_low.reshape(n_sources, n_blocks - n_remains, n_neighbors, n_frames)\n            Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames)\n\n            num_low, denom_low = _compute_traces(T_low, V, separated=Y_low)\n            num_high, denom_high = _compute_traces(T_high, V, separated=Y_high)\n\n            num = num_low + num_high\n         
   denom = denom_low + denom_high\n        else:\n            Y = Y.reshape(n_sources, n_blocks, n_neighbors, n_frames)\n            num, denom = _compute_traces(T, V, separated=Y)\n\n        self.activation = V * np.sqrt(num / denom)\n\n    def update_spatial_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.spatial_algorithm == \"VCD\":\n            self.update_spatial_model_vcd(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.spatial_algorithm))\n\n    def update_spatial_model_vcd(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once by VCD.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        def _update(input: 
np.ndarray, demix_filter: np.ndarray, covariance: np.ndarray):\n            r\"\"\"\n            Args:\n                input (np.ndarray):\n                    Mixture spectrogram.\n                    The shape is (n_channels, n_blocks, n_neighbors, n_frames).\n                demix_filter (np.ndarray):\n                    Demixing filter split by frequnecy bands.\n                    The shape is (n_blocks, n_neighbors, n_sources, n_channels).\n                covariance (np.ndarray):\n                    Rconstructed PSDTF.\n                    The shape is (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).\n\n            Returns:\n                np.ndarray of demixing filters after update.\n            \"\"\"\n            X, W = input, demix_filter\n            R = covariance\n\n            XX = X[:, na, :, :, na] * X[na, :, :, na, :].conj()\n            XX = XX.transpose(2, 3, 4, 0, 1, 5)\n\n            R_inverse = np.linalg.inv(R)\n            R_inverse = R_inverse.transpose(2, 4, 3, 0, 1)\n\n            RXX = np.mean(R_inverse[:, :, :, :, na, na] * XX[:, :, :, na, :, :], axis=-1)\n\n            W = update_by_block_decomposition_vcd(\n                W, weighted_covariance=RXX, singular_fn=lambda x: np.abs(x) < flooring_fn(0)\n            )\n\n            return W\n\n        n_sources, n_channels = self.n_sources, self.n_channels\n        n_bins, n_frames = self.n_bins, self.n_frames\n        n_blocks = self.n_blocks\n        n_remains = self.n_remains\n        na = np.newaxis\n\n        n_neighbors = n_bins // n_blocks\n\n        X, W = self.input, self.demix_filter\n        T, V = self.basis, self.activation\n\n        R = self.reconstruct_block_decomposition_psdtf(T, V)\n\n        if n_remains > 0:\n            X_low, X_high = np.split(X, [(n_blocks - n_remains) * n_neighbors], axis=1)\n            W_low, W_high = np.split(W, [(n_blocks - n_remains) * n_neighbors], axis=0)\n            R_low, R_high = R\n\n            # Lower frequency\n         
   X_low = X_low.reshape(n_channels, n_blocks - n_remains, n_neighbors, n_frames)\n            W_low = W_low.reshape(n_blocks - n_remains, n_neighbors, n_sources, n_channels)\n            W_low = _update(X_low, demix_filter=W_low, covariance=R_low)\n\n            # Higher frequency\n            X_high = X_high.reshape(n_channels, n_remains, n_neighbors + 1, n_frames)\n            W_high = W_high.reshape(n_remains, n_neighbors + 1, n_sources, n_channels)\n            W_high = _update(X_high, demix_filter=W_high, covariance=R_high)\n\n            W_low = W_low.reshape((n_blocks - n_remains) * n_neighbors, n_sources, n_channels)\n            W_high = W_high.reshape(n_remains * (n_neighbors + 1), n_sources, n_channels)\n            W = np.concatenate([W_low, W_high], axis=0)\n        else:\n            X = X.reshape(n_channels, n_blocks, n_neighbors, n_frames)\n            W = W.reshape(n_blocks, n_neighbors, n_sources, n_channels)\n            W = _update(X, demix_filter=W, covariance=R)\n            W = W.reshape(n_blocks * n_neighbors, n_sources, n_channels)\n\n        self.demix_filter = W\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n\n        def _compute_block_decomposition_loss(\n            separated: np.ndarray, demix_filter: np.ndarray, covariance: np.ndarray\n        ) -> float:\n            r\"\"\"\n            Args:\n                separated (np.ndarray):\n                    Separated signal with shape of (n_sources, n_frames, n_blocks, n_neighbors).\n                demix_filter (np.ndarray):\n                    Demixing filters with shape of (n_blocks, n_neighbors, n_sources, n_channels).\n                covariance:\n                    Covariance matrix with shape of\n                    (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).\n            \"\"\"\n            Y, W = separated, demix_filter\n            R = covariance\n\n  
          n_sources, n_frames, n_blocks, n_neighbors = Y.shape\n\n            Y = Y.reshape(n_sources, n_frames, n_blocks, n_neighbors, 1)\n            R_inverse = np.linalg.inv(R)\n            Y_Hermite = np.swapaxes(Y, 3, 4).conj()\n            YRY = np.sum(Y_Hermite @ R_inverse @ Y, axis=(0, 2, 3, 4))\n            YRY = np.real(YRY)\n            YRY = np.maximum(YRY, 0)\n            _, logdetR = np.linalg.slogdet(R)\n            logdetR = logdetR.sum(axis=(0, 2))\n            logdetW = self.compute_logdet(W)\n\n            loss = np.mean(YRY + logdetR, axis=0) - 2 * logdetW.sum(axis=(0, 1))\n            loss = loss.item()\n\n            return loss\n\n        n_sources, n_channels = self.n_sources, self.n_channels\n        n_bins, n_frames = self.n_bins, self.n_frames\n        n_blocks = self.n_blocks\n        n_remains = self.n_remains\n\n        n_neighbors = n_bins // n_blocks\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n        Y = Y.transpose(0, 2, 1)\n        T, V = self.basis, self.activation\n\n        R = self.reconstruct_block_decomposition_psdtf(T, V)\n\n        if n_remains > 0:\n            Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=2)\n            W_low, W_high = np.split(W, [(n_blocks - n_remains) * n_neighbors], axis=0)\n            R_low, R_high = R\n\n            Y_low = Y_low.reshape(n_sources, n_frames, (n_blocks - n_remains), n_neighbors)\n            Y_high = Y_high.reshape(n_sources, n_frames, n_remains, n_neighbors + 1)\n            W_low = W_low.reshape((n_blocks - n_remains), n_neighbors, n_sources, n_channels)\n            W_high = W_high.reshape(n_remains, n_neighbors + 1, n_sources, n_channels)\n\n            loss_low = _compute_block_decomposition_loss(\n                Y_low, demix_filter=W_low, covariance=R_low\n            )\n            loss_high = _compute_block_decomposition_loss(\n                Y_high, demix_filter=W_high, covariance=R_high\n    
        )\n\n            loss = loss_low + loss_high\n        else:\n            Y = Y.reshape(n_sources, n_frames, n_blocks, n_neighbors)\n            W = W.reshape(n_blocks, n_neighbors, n_sources, n_channels)\n\n            loss = _compute_block_decomposition_loss(Y, demix_filter=W, covariance=R)\n\n        return loss\n\n\nclass TIPSDTA(BlockDecompositionIPSDTABase):\n    r\"\"\"Independent positive semidefinite tensor analysis (IPSDTA) \\\n    on Student's t distribution.\n\n    Args:\n        n_basis (int):\n            Number of PSDTF bases.\n        n_blocks (int):\n            Number of sub-blocks.\n        dof (float):\n            Degree of freedom parameter.\n        source_algorithm (str):\n            Algorithm for PSDTF updates.\n            Only ``MM`` is supported. Default: ``MM``.\n        spatial_algorithm (str):\n            Algorithm for demixing filter updates.\n            Only ``VCD`` is supported. Default: ``VCD``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        source_normalization (bool):\n            If ``source_normalization=True``, normalize PSDTF parameters.\n            Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. 
You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n        rng (numpy.random.Generator, optional):\n            Random number generator. This is mainly used to randomly initialize PSDTF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        n_blocks: int,\n        dof: float,\n        source_algorithm: str = \"MM\",\n        spatial_algorithm: str = \"VCD\",\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[\n                Callable[[\"GaussIPSDTA\"], None],\n                List[Callable[[\"GaussIPSDTA\"], None]],\n            ]\n        ] = None,\n        source_normalization: Optional[Union[bool, str]] = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(\n            n_basis,\n            n_blocks,\n            flooring_fn,\n            callbacks,\n            scale_restoration,\n            record_loss,\n            reference_id,\n            rng,\n        )\n\n        assert source_algorithm in source_algorithms, \"Not support {}.\".format(source_algorithm)\n        assert spatial_algorithm in spatial_algorithms, \"Not support {}.\".format(spatial_algorithm)\n\n        self.dof = dof\n        self.source_algorithm = source_algorithm\n        self.source_normalization = source_normalization\n        self.spatial_algorithm = 
spatial_algorithm\n\n    def __repr__(self) -> str:\n        s = \"TIPSDTA(\"\n        s += \"n_basis={n_basis}\"\n        s += \", n_blocks={n_blocks}\"\n        s += \", dof={dof}\"\n        s += \", source_algorithm={source_algorithm}\"\n        s += \", spatial_algorithm={spatial_algorithm}\"\n        s += \", source_normalization={source_normalization}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update PSDTF parameters and demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_source_model(flooring_fn=flooring_fn)\n        self.update_spatial_model(flooring_fn=flooring_fn)\n\n    def update_source_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update PSDTF basis matrices and activations.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set 
``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.source_algorithm == \"MM\":\n            self.update_source_model_mm(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.source_algorithm))\n\n        if self.source_normalization:\n            self.normalize_block_decomposition_psdtf()\n\n    def update_source_model_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update PSDTF basis matrices and activations by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_basis_mm(flooring_fn=flooring_fn)\n        self.update_activation_mm()\n\n    def update_basis_mm(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update PSDTF basis matrices by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the 
identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        n_sources = self.n_sources\n        n_frames = self.n_frames\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        def _quadratic(Y: np.ndarray, R: np.ndarray) -> np.ndarray:\n            r\"\"\"\n            Args:\n                Y (np.ndarray):\n                    Separated spectrograms with shape of\n                    (n_sources, n_blocks, n_neighbors, n_frames).\n                R (np.ndarray):\n                    Covariance matrix with shape of\n                    (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).\n\n            Returns:\n                Quadratic forms with shape of (n_sources, n_frames).\n            \"\"\"\n            Y = Y.transpose(0, 3, 1, 2)\n            R_inverse = np.linalg.inv(R)\n\n            YRY = quadratic(Y, R_inverse)\n            YRY = np.real(YRY)\n            YRY = np.maximum(YRY, 0)\n            YRY = YRY.sum(axis=-1)\n\n            return YRY\n\n        def _update_basis_mm(\n            basis: np.ndarray,\n            activation: np.ndarray,\n            separated: np.ndarray = None,\n            weight: np.ndarray = None,\n        ) -> np.ndarray:\n            r\"\"\"\n            Args:\n                basis: (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors)\n                activation: (n_sources, n_basis, n_frames)\n                separated: (n_sources, n_blocks, n_neighbors, n_frames)\n                weight: (n_sources, n_frames)\n\n            Returns:\n                numpy.ndarray of updated basis matrix.\n            \"\"\"\n            T, V = basis, activation\n            Y = separated\n            pi = weight\n            na = np.newaxis\n\n            R = self.reconstruct_block_decomposition_psdtf(T, V)\n            R_inverse = np.linalg.inv(R)\n            Y = Y.transpose(0, 3, 1, 
2)\n            YY_Hermite = Y[:, :, :, :, na] @ Y[:, :, :, na, :].conj()\n            RYYR = R_inverse @ YY_Hermite @ R_inverse\n            piRYYR = pi[:, :, na, na, na] * RYYR\n\n            P = np.mean(\n                V[:, :, :, na, na, na] * R_inverse[:, na, :, :, :, :],\n                axis=2,\n            )\n            Q = np.mean(\n                V[:, :, :, na, na, na] * piRYYR[:, na, :, :, :, :],\n                axis=2,\n            )\n            Q = to_psd(Q, flooring_fn=flooring_fn)\n            Q_sqrt = sqrtmh(Q)\n\n            QTPTQ = Q_sqrt @ T @ P @ T @ Q_sqrt\n            QTPTQ = to_psd(QTPTQ, flooring_fn=flooring_fn)\n            T = T @ Q_sqrt @ invsqrtmh(QTPTQ, flooring_fn=flooring_fn) @ Q_sqrt @ T\n            T = to_psd(T, flooring_fn=flooring_fn)\n\n            return T\n\n        n_bins = self.n_bins\n        n_blocks = self.n_blocks\n        n_remains = self.n_remains\n        n_neighbors = n_bins // n_blocks\n\n        nu = self.dof\n\n        X, W = self.input, self.demix_filter\n        T, V = self.basis, self.activation\n\n        Y = self.separate(X, demix_filter=W)\n        R = self.reconstruct_block_decomposition_psdtf(T, V)\n\n        if n_remains > 0:\n            T_low, T_high = T\n            Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1)\n            Y_low = Y_low.reshape(n_sources, n_blocks - n_remains, n_neighbors, n_frames)\n            Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames)\n            R_low, R_high = R\n\n            YRY_low = _quadratic(Y_low, R_low)\n            YRY_high = _quadratic(Y_high, R_high)\n\n            YRY = YRY_low + YRY_high\n            pi = (nu + 2 * n_bins) / (nu + 2 * YRY)\n\n            T_low = _update_basis_mm(T_low, V, separated=Y_low, weight=pi)\n            T_high = _update_basis_mm(T_high, V, separated=Y_high, weight=pi)\n            T = T_low, T_high\n        else:\n            Y = Y.reshape(n_sources, n_blocks, n_neighbors, 
n_frames)\n            YRY = _quadratic(Y, R)\n            pi = (nu + 2 * n_bins) / (nu + 2 * YRY)\n\n            T = _update_basis_mm(T, V, separated=Y, weight=pi)\n\n        self.basis = T\n\n    def update_activation_mm(self) -> None:\n        r\"\"\"Update PSDTF activations by MM algorithm.\"\"\"\n\n        def _quadratic(Y: np.ndarray, R: np.ndarray) -> np.ndarray:\n            r\"\"\"\n            Args:\n                Y (np.ndarray):\n                    Separated spectrams with shape of\n                    (n_sources, n_blocks, n_neighbors, n_frames).\n                R (np.ndarray):\n                    Covariance matrix with shape of\n                    (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).\n\n            Returns:\n                Quadratic forms with shape of (n_sources, n_frames).\n            \"\"\"\n            Y = Y.transpose(0, 3, 1, 2)\n            R_inverse = np.linalg.inv(R)\n\n            YRY = quadratic(Y, R_inverse)\n            YRY = np.real(YRY)\n            YRY = np.maximum(YRY, 0)\n            YRY = YRY.sum(axis=-1)\n\n            return YRY\n\n        def _compute_traces(\n            basis: np.ndarray,\n            activation: np.ndarray,\n            separated: np.ndarray = None,\n            weight: np.ndarray = None,\n        ) -> Tuple[np.ndarray, np.ndarray]:\n            r\"\"\"\n            Args:\n                basis: (n_sources, n_basis, n_blocks, n_neighbors, n_neighbors)\n                activation: (n_sources, n_basis, n_frames)\n                separated: (n_sources, n_blocks, n_neighbors, n_frames)\n\n            Returns:\n                Tuple of numerator and denominator.\n                Type of each item is ``numpy.ndarray``.\n            \"\"\"\n            T, V = basis, activation\n            Y = separated.transpose(0, 3, 1, 2)\n            pi = weight\n            na = np.newaxis\n\n            R = self.reconstruct_block_decomposition_psdtf(T, V)\n            R_inverse = np.linalg.inv(R)\n   
         YY_Hermite = Y[:, :, :, :, na] @ Y[:, :, :, na, :].conj()\n            RYYR = R_inverse @ YY_Hermite @ R_inverse\n\n            piRYYR = pi[:, :, na, na, na] * RYYR\n\n            num = np.trace(piRYYR[:, na, :] @ T[:, :, na], axis1=-2, axis2=-1)\n            denom = np.trace(R_inverse[:, na, :] @ T[:, :, na], axis1=-2, axis2=-1)\n            num = np.real(num).sum(axis=-1)\n            denom = np.real(denom).sum(axis=-1)\n\n            return num, denom\n\n        n_sources = self.n_sources\n        n_bins, n_frames = self.n_bins, self.n_frames\n        n_blocks = self.n_blocks\n        n_remains = self.n_remains\n        n_neighbors = n_bins // n_blocks\n\n        nu = self.dof\n\n        X, W = self.input, self.demix_filter\n        T, V = self.basis, self.activation\n\n        Y = self.separate(X, demix_filter=W)\n        R = self.reconstruct_block_decomposition_psdtf(T, V)\n\n        if n_remains > 0:\n            T_low, T_high = T\n            Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1)\n            Y_low = Y_low.reshape(n_sources, n_blocks - n_remains, n_neighbors, n_frames)\n            Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames)\n            R_low, R_high = R\n\n            YRY_low = _quadratic(Y_low, R_low)\n            YRY_high = _quadratic(Y_high, R_high)\n\n            YRY = YRY_low + YRY_high\n            pi = (nu + 2 * n_bins) / (nu + 2 * YRY)\n\n            num_low, denom_low = _compute_traces(T_low, V, separated=Y_low, weight=pi)\n            num_high, denom_high = _compute_traces(T_high, V, separated=Y_high, weight=pi)\n\n            num = num_low + num_high\n            denom = denom_low + denom_high\n        else:\n            Y = Y.reshape(n_sources, n_blocks, n_neighbors, n_frames)\n            YRY = _quadratic(Y, R)\n            pi = (nu + 2 * n_bins) / (nu + 2 * YRY)\n\n            num, denom = _compute_traces(T, V, separated=Y, weight=pi)\n\n        self.activation = V * 
np.sqrt(num / denom)\n\n    def update_spatial_model(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.spatial_algorithm == \"VCD\":\n            self.update_spatial_model_vcd(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.spatial_algorithm))\n\n    def update_spatial_model_vcd(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once by VCD.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        def _quadratic(Y: np.ndarray, R: np.ndarray) -> np.ndarray:\n            r\"\"\"\n            Args:\n                Y (np.ndarray):\n                    Separated spectrograms with shape of\n                    (n_sources, n_blocks, 
n_neighbors, n_frames).\n                R (np.ndarray):\n                    Covariance matrix with shape of\n                    (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).\n\n            Returns:\n                Quadratic forms with shape of (n_sources, n_frames).\n            \"\"\"\n            Y = Y.transpose(2, 3, 0, 1)\n            R_inverse = np.linalg.inv(R)\n\n            YRY = quadratic(Y, R_inverse)\n            YRY = np.real(YRY)\n            YRY = np.maximum(YRY, 0)\n            YRY = YRY.sum(axis=-1)\n\n            return YRY\n\n        def _update(\n            input: np.ndarray,\n            demix_filter: np.ndarray,\n            covariance: np.ndarray,\n            weight: np.ndarray = None,\n        ):\n            X, W = input, demix_filter\n            R = covariance\n            pi = weight\n\n            na = np.newaxis\n\n            XX = X[:, na, :, :, na] * X[na, :, :, na, :].conj()\n            XX = XX.transpose(2, 3, 4, 0, 1, 5)\n\n            R_inverse = np.linalg.inv(R)\n            R_inverse = R_inverse.transpose(2, 4, 3, 0, 1)\n            pi_R_inverse = pi * R_inverse\n\n            RXX = np.mean(pi_R_inverse[:, :, :, :, na, na] * XX[:, :, :, na, :, :], axis=-1)\n\n            W = update_by_block_decomposition_vcd(\n                W, weighted_covariance=RXX, singular_fn=lambda x: np.abs(x) < flooring_fn(0)\n            )\n\n            return W\n\n        n_sources, n_channels = self.n_sources, self.n_channels\n        n_bins, n_frames = self.n_bins, self.n_frames\n        n_blocks = self.n_blocks\n        n_remains = self.n_remains\n        n_neighbors = n_bins // n_blocks\n\n        nu = self.dof\n\n        X, W = self.input, self.demix_filter\n        T, V = self.basis, self.activation\n\n        R = self.reconstruct_block_decomposition_psdtf(T, V)\n\n        if n_remains > 0:\n            X_low, X_high = np.split(X, [(n_blocks - n_remains) * n_neighbors], axis=1)\n            W_low, W_high = np.split(W, 
[(n_blocks - n_remains) * n_neighbors], axis=0)\n            R_low, R_high = R\n\n            # Lower frequency\n            X_low = X_low.reshape(n_channels, n_blocks - n_remains, n_neighbors, n_frames)\n            W_low = W_low.reshape(n_blocks - n_remains, n_neighbors, n_sources, n_channels)\n            Y_low = W_low @ X_low.transpose(1, 2, 0, 3)\n\n            # Higher frequency\n            X_high = X_high.reshape(n_channels, n_remains, n_neighbors + 1, n_frames)\n            W_high = W_high.reshape(n_remains, n_neighbors + 1, n_sources, n_channels)\n            Y_high = W_high @ X_high.transpose(1, 2, 0, 3)\n\n            YRY_low = _quadratic(Y_low, R_low)\n            YRY_high = _quadratic(Y_high, R_high)\n\n            YRY = YRY_low + YRY_high\n            pi = (nu + 2 * n_bins) / (nu + 2 * YRY)\n\n            W_low = _update(X_low, demix_filter=W_low, covariance=R_low, weight=pi)\n            W_high = _update(X_high, demix_filter=W_high, covariance=R_high, weight=pi)\n\n            W_low = W_low.reshape((n_blocks - n_remains) * n_neighbors, n_sources, n_channels)\n            W_high = W_high.reshape(n_remains * (n_neighbors + 1), n_sources, n_channels)\n            W = np.concatenate([W_low, W_high], axis=0)\n        else:\n            X = X.reshape(n_channels, n_blocks, n_neighbors, n_frames)\n            W = W.reshape(n_blocks, n_neighbors, n_sources, n_channels)\n            Y = W @ X.transpose(1, 2, 0, 3)\n\n            YRY = _quadratic(Y, R)\n            pi = (nu + 2 * n_bins) / (nu + 2 * YRY)\n\n            W = _update(X, demix_filter=W, covariance=R, weight=pi)\n            W = W.reshape(n_blocks * n_neighbors, n_sources, n_channels)\n\n        self.demix_filter = W\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n\n        def _quadratic(Y: np.ndarray, R: np.ndarray) -> np.ndarray:\n            r\"\"\"\n            Args:\n                Y 
(np.ndarray):\n                    Separated spectrograms with shape of\n                    (n_sources, n_blocks, n_neighbors, n_frames).\n                R (np.ndarray):\n                    Covariance matrix with shape of\n                    (n_sources, n_frames, n_blocks, n_neighbors, n_neighbors).\n\n            Returns:\n                Quadratic forms with shape of (n_sources, n_frames).\n            \"\"\"\n            Y = Y.transpose(0, 3, 1, 2)\n            R_inverse = np.linalg.inv(R)\n\n            YRY = quadratic(Y, R_inverse)\n            YRY = np.real(YRY)\n            YRY = np.maximum(YRY, 0)\n            YRY = YRY.sum(axis=-1)\n\n            return YRY\n\n        n_sources, n_channels = self.n_sources, self.n_channels\n        n_bins, n_frames = self.n_bins, self.n_frames\n\n        nu = self.dof\n\n        n_blocks = self.n_blocks\n        n_remains = self.n_remains\n\n        n_neighbors = n_bins // n_blocks\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n        T, V = self.basis, self.activation\n\n        R = self.reconstruct_block_decomposition_psdtf(T, V)\n\n        if n_remains > 0:\n            Y_low, Y_high = np.split(Y, [(n_blocks - n_remains) * n_neighbors], axis=1)\n            W_low, W_high = np.split(W, [(n_blocks - n_remains) * n_neighbors], axis=0)\n            R_low, R_high = R\n\n            Y_low = Y_low.reshape(n_sources, (n_blocks - n_remains), n_neighbors, n_frames)\n            Y_high = Y_high.reshape(n_sources, n_remains, n_neighbors + 1, n_frames)\n            W_low = W_low.reshape((n_blocks - n_remains), n_neighbors, n_sources, n_channels)\n            W_high = W_high.reshape(n_remains, n_neighbors + 1, n_sources, n_channels)\n\n            YRY_low = _quadratic(Y_low, R_low)\n            YRY_high = _quadratic(Y_high, R_high)\n\n            YRY = YRY_low + YRY_high\n\n            loss = np.sum(((nu + 2 * n_bins) / 2) * np.log(1 + (2 / nu) * YRY), axis=0)\n\n            _, 
logdetR_low = np.linalg.slogdet(R_low)\n            logdetR_low = logdetR_low.sum(axis=(0, 2))\n            _, logdetR_high = np.linalg.slogdet(R_high)\n            logdetR_high = logdetR_high.sum(axis=(0, 2))\n            logdetR = logdetR_low + logdetR_high\n\n            logdetW_low = self.compute_logdet(W_low)\n            logdetW_high = self.compute_logdet(W_high)\n\n            logdetW = logdetW_low.sum(axis=(0, 1)) + logdetW_high.sum(axis=(0, 1))\n        else:\n            Y = Y.reshape(n_sources, n_blocks, n_neighbors, n_frames)\n            W = W.reshape(n_blocks, n_neighbors, n_sources, n_channels)\n\n            YRY = _quadratic(Y, R)\n\n            loss = np.sum(((nu + 2 * n_bins) / 2) * np.log(1 + (2 / nu) * YRY), axis=0)\n\n            _, logdetR = np.linalg.slogdet(R)\n            logdetR = logdetR.sum(axis=(0, 2))\n\n            logdetW = self.compute_logdet(W)\n            logdetW = logdetW.sum(axis=(0, 1))\n\n        loss = np.mean(loss + logdetR, axis=0) - 2 * logdetW\n        loss = loss.item()\n\n        return loss\n"
  },
  {
    "path": "ssspy/bss/iva.py",
    "content": "import functools\nfrom typing import Callable, Iterable, List, Optional, Tuple, Union\n\nimport numpy as np\n\nfrom ..algorithm import (\n    MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS,\n    PROJECTION_BACK_KEYWORDS,\n    minimal_distortion_principle,\n    projection_back,\n)\nfrom ..linalg import eigh, prox\nfrom ..special.flooring import identity, max_flooring\nfrom ..transform import whiten\nfrom ..utils.flooring import choose_flooring_fn\nfrom ..utils.select_pair import sequential_pair_selector\nfrom ._update_spatial_model import (\n    update_by_ip1,\n    update_by_ip2_one_pair,\n    update_by_ipa,\n    update_by_iss1,\n    update_by_iss2,\n)\nfrom .admmbss import ADMMBSS\nfrom .base import IterativeMethodBase\nfrom .pdsbss import PDSBSS\n\n__all__ = [\n    \"GradIVA\",\n    \"NaturalGradIVA\",\n    \"FastIVA\",\n    \"FasterIVA\",\n    \"AuxIVA\",\n    \"PDSIVA\",\n    \"ADMMIVA\",\n    \"GradLaplaceIVA\",\n    \"GradGaussIVA\",\n    \"NaturalGradLaplaceIVA\",\n    \"NaturalGradGaussIVA\",\n    \"AuxLaplaceIVA\",\n    \"AuxGaussIVA\",\n]\n\nspatial_algorithms = [\"IP\", \"IP1\", \"IP2\", \"ISS\", \"ISS1\", \"ISS2\", \"IPA\"]\nEPS = 1e-10\n\n\nclass IVABase(IterativeMethodBase):\n    r\"\"\"Base class of independent vector analysis (IVA).\n\n    Args:\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. 
Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n    \"\"\"\n\n    def __init__(\n        self,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"IVABase\"], None], List[Callable[[\"IVABase\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(callbacks=callbacks, record_loss=record_loss)\n\n        if flooring_fn is None:\n            self.flooring_fn = identity\n        else:\n            self.flooring_fn = flooring_fn\n\n        self.input = None\n        self.scale_restoration = scale_restoration\n\n        if reference_id is None and scale_restoration:\n            raise ValueError(\"Specify 'reference_id' if scale_restoration=True.\")\n        else:\n            self.reference_id = reference_id\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                Mixture signal in frequency-domain.\n                The shape is 
(n_channels, n_bins, n_frames).\n            n_iter (int):\n                Number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        raise NotImplementedError(\"Implement '__call__' method.\")\n\n    def __repr__(self) -> str:\n        s = \"IVA(\"\n        s += \"scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of IVA.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_channels, n_bins, n_frames = X.shape\n        n_sources = n_channels  # n_channels == n_sources\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_bins, self.n_frames = n_bins, n_frames\n\n        if not hasattr(self, \"demix_filter\"):\n            W = np.eye(n_sources, n_channels, dtype=np.complex128)\n            W = np.tile(W, reps=(n_bins, 1, 1))\n        else:\n            if self.demix_filter is None:\n                W = None\n            else:\n                # To avoid overwriting ``demix_filter`` given by keyword arguments.\n                W = self.demix_filter.copy()\n\n        self.demix_filter = W\n        
self.output = self.separate(X, demix_filter=W)\n\n    def separate(self, input: np.ndarray, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Separate ``input`` using ``demixing_filter``.\n\n        .. math::\n            \\boldsymbol{y}_{ij}\n            = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij}\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            demix_filter (numpy.ndarray):\n                The demixing filters to separate ``input``.\n                The shape is (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        X, W = input, demix_filter\n        Y = W @ X.transpose(1, 0, 2)\n        output = Y.transpose(1, 0, 2)\n\n        return output\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once.\"\"\"\n        raise NotImplementedError(\"Implement 'update_once' method.\")\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. 
math::\n            \\mathcal{L} \\\n            &= \\frac{1}{J}\\sum_{j,n}G(\\vec{\\boldsymbol{y}}_{jn}) \\\n            - 2\\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|, \\\\\n            G(\\vec{\\boldsymbol{y}}_{jn}) \\\n            &= - \\log p(\\vec{\\boldsymbol{y}}_{jn})\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)  # (n_sources, n_bins, n_frames)\n        logdet = self.compute_logdet(W)  # (n_bins,)\n        G = self.contrast_fn(Y)  # (n_sources, n_frames)\n        loss = np.sum(np.mean(G, axis=1), axis=0) - 2 * np.sum(logdet, axis=0)\n        loss = loss.item()\n\n        return loss\n\n    def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Compute log-determinant of demixing filter.\n\n        Args:\n            demix_filter (numpy.ndarray):\n                Demixing filters with shape of (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of computed log-determinant values.\n        \"\"\"\n        _, logdet = np.linalg.slogdet(demix_filter)  # (n_bins,)\n\n        return logdet\n\n    def restore_scale(self) -> None:\n        r\"\"\"Restore scale ambiguity.\n\n        If ``self.scale_restoration=projection_back``, we use projection back technique.\n        If ``self.scale_restoration=minimal_distortion_principle``,\n        we use minimal distortion principle.\n        \"\"\"\n        scale_restoration = self.scale_restoration\n\n        assert scale_restoration, \"Set self.scale_restoration=True.\"\n\n        if type(scale_restoration) is bool:\n            scale_restoration = PROJECTION_BACK_KEYWORDS[0]\n\n        if scale_restoration in PROJECTION_BACK_KEYWORDS:\n            self.apply_projection_back()\n        elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS:\n            self.apply_minimal_distortion_principle()\n        else:\n            raise ValueError(\"{} is not 
supported for scale restoration.\".format(scale_restoration))\n\n    def apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        W_scaled = projection_back(W, reference_id=self.reference_id)\n        Y_scaled = self.separate(X, demix_filter=W_scaled)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n    def apply_minimal_distortion_principle(self) -> None:\n        r\"\"\"Apply minimal distortion principle to estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n        Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id)\n        X = X.transpose(1, 0, 2)\n        Y = Y_scaled.transpose(1, 0, 2)\n        X_Hermite = X.transpose(0, 2, 1).conj()\n        W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n\nclass GradIVABase(IVABase):\n    r\"\"\"Base class of independent vector analysis (IVA) using gradient descent.\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. 
Default: ``1e-1``.\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(\\vec{\\boldsymbol{y}}_{jn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_frames).\n        score_fn (callable):\n            A score function which corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``False``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. 
Default: ``0``.\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"GradIVABase\"], None], List[Callable[[\"GradIVABase\"], None]]]\n        ] = None,\n        is_holonomic: bool = False,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n        self.step_size = step_size\n\n        if contrast_fn is None:\n            raise ValueError(\"Specify contrast function.\")\n        else:\n            self.contrast_fn = contrast_fn\n\n        if score_fn is None:\n            raise ValueError(\"Specify score function.\")\n        else:\n            self.score_fn = score_fn\n\n        self.is_holonomic = is_holonomic\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain. \\\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates. 
\\\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray:\n                The separated signal in frequency-domain. \\\n                The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of IVABase's parent, i.e. __call__ of IterativeMethodBase\n        super(IVABase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"GradIVA(\"\n        s += \"step_size={step_size}\"\n        s += \", is_holonomic={is_holonomic}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n\nclass FastIVABase(IVABase):\n    r\"\"\"Base class of fast independent vector analysis (FastIVA).\n\n    Args:\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. 
Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n    \"\"\"\n\n    def __init__(\n        self,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"IVABase\"], None], List[Callable[[\"IVABase\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def __repr__(self) -> str:\n        s = \"FastIVA(\"\n        s += \"scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        super()._reset(**kwargs)\n\n        X, W = self.input, self.demix_filter\n\n        Z = whiten(X)\n\n        Y = self.separate(Z, demix_filter=W, use_whitening=False)\n\n        self.whitened_input = Z\n        self.output = 
Y\n\n    def separate(\n        self, input: np.ndarray, demix_filter: np.ndarray, use_whitening: bool = True\n    ) -> np.ndarray:\n        r\"\"\"Separate ``input`` using ``demix_filter``.\n\n        .. math::\n            \\boldsymbol{y}_{ij}\n            = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij}\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            demix_filter (numpy.ndarray):\n                The demixing filters to separate ``input``.\n                The shape is (n_bins, n_sources, n_channels).\n            use_whitening (bool):\n                If ``use_whitening=True``, whitening (sphering) is applied to ``input``.\n                Default: ``True``.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        if use_whitening:\n            whitened_input = whiten(input)\n        else:\n            whitened_input = input\n\n        output = super().separate(whitened_input, demix_filter=demix_filter)\n\n        return output\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. 
math::\n            \\mathcal{L} \\\n            &= \\frac{1}{J}\\sum_{j,n}G(\\vec{\\boldsymbol{y}}_{jn}), \\\\\n            G(\\vec{\\boldsymbol{y}}_{jn}) \\\n            &= - \\log p(\\vec{\\boldsymbol{y}}_{jn})\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        Z, W = self.whitened_input, self.demix_filter\n        Y = self.separate(Z, demix_filter=W, use_whitening=False)  # (n_sources, n_bins, n_frames)\n\n        G = self.contrast_fn(Y)  # (n_sources, n_frames)\n        loss = np.sum(np.mean(G, axis=1), axis=0).item()\n\n        return loss\n\n    def apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        reference_id = self.reference_id\n\n        X, Z = self.input, self.whitened_input\n        W = self.demix_filter\n\n        Y = self.separate(Z, demix_filter=W, use_whitening=False)\n        Y_scaled = projection_back(Y, reference=X, reference_id=reference_id)\n\n        Z = Z.transpose(1, 0, 2)\n        Z_Hermite = Z.transpose(0, 2, 1).conj()\n        ZZ_Hermite = Z @ Z_Hermite\n        W_scaled = Y_scaled.transpose(1, 0, 2) @ Z_Hermite @ np.linalg.inv(ZZ_Hermite)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n\nclass AuxIVABase(IVABase):\n    r\"\"\"Base class of auxiliary-function-based independent vector analysis (IVA).\n\n    Args:\n        contrast_fn (callable):\n            A contrast function corresponds to :math:`-\\log p(\\vec{\\boldsymbol{y}}_{jn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_frames).\n        d_contrast_fn (callable):\n            A derivative of the contrast function.\n            This function is expected to receive (n_channels, n_frames)\n            and return (n_channels, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function 
for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. 
Default: ``0``.\n    \"\"\"\n\n    def __init__(\n        self,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"AuxIVABase\"], None], List[Callable[[\"AuxIVABase\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n        self.contrast_fn = contrast_fn\n        self.d_contrast_fn = d_contrast_fn\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        return super().__call__(input, n_iter=n_iter, initial_call=initial_call, **kwargs)\n\n    def __repr__(self) -> str:\n        s = \"AuxIVA(\"\n        s += \"scale_restoration={scale_restoration}\"\n        s += \", 
record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n\nclass GradIVA(GradIVABase):\n    r\"\"\"Independent vector analysis (IVA) [#kim2006independent]_ using gradient descent.\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(\\vec{\\boldsymbol{y}}_{jn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_frames).\n        score_fn (callable):\n            A score function which corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def score_fn(y):\n            ...     norm = np.linalg.norm(y, axis=1, keepdims=True)\n            ...     return y / np.maximum(norm, 1e-10)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = GradIVA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=True,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def score_fn(y):\n            ...     norm = np.linalg.norm(y, axis=1, keepdims=True)\n            ...     return y / np.maximum(norm, 1e-10)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = GradIVA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     
is_holonomic=False,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n    .. [#kim2006independent]\n        T. Kim, H. T. Attias, S.-Y. Lee, and T.-W. Lee,\n        \"Blind source separation exploiting higher-order frequency dependencies,\"\n        in *IEEE Trans. ASLP*, vol. 15, no. 1, pp. 70-79, 2007.\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"GradIVA\"], None], List[Callable[[\"GradIVA\"], None]]]\n        ] = None,\n        is_holonomic: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            is_holonomic=is_holonomic,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are updated as follows:\n\n        .. 
math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i} - \\eta\\left(\\frac{1}{J}\\sum_{j} \\\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\\boldsymbol{y}_{ij}^{\\mathsf{H}} \\\n            -\\boldsymbol{I}\\right)\\boldsymbol{W}_{i}^{-\\mathsf{H}},\n\n        where\n\n        .. math::\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\n            &= \\left(\\phi_{i}(\\vec{\\boldsymbol{y}}_{j1}),\\ldots,\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jn}),\\ldots,\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jN}))\\\n            \\right)^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jn})\n            &= \\frac{\\partial G(\\vec{\\boldsymbol{y}}_{jn})}{\\partial y_{ijn}^{*}}, \\\\\n            G(\\vec{\\boldsymbol{y}}_{jn})\n            &= -\\log p(\\vec{\\boldsymbol{y}}_{jn}).\n\n        Otherwise (``is_holonomic=False``),\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i}\n            - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{J}\\sum_{j}\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\\boldsymbol{y}_{ij}^{\\mathsf{H}}\\right)\n            \\boldsymbol{W}_{i}^{-\\mathsf{H}}.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        Phi = self.score_fn(Y)\n        Y_conj = Y.conj()\n        PhiY = np.mean(Phi[:, np.newaxis, :, :] * Y_conj[np.newaxis, :, :, :], axis=-1)\n        PhiY = PhiY.transpose(2, 0, 1)  # (n_bins, n_sources, n_sources)\n        W_inv = np.linalg.inv(W)\n        W_inv_Hermite = W_inv.transpose(0, 2, 1).conj()\n        eye = np.eye(self.n_sources)\n\n        if self.is_holonomic:\n            delta = (PhiY - eye) @ W_inv_Hermite\n        else:\n            delta = ((1 - eye) * PhiY) @ W_inv_Hermite\n\n        W = W - self.step_size * delta\n\n        Y = self.separate(X, demix_filter=W)\n\n        
self.demix_filter = W\n        self.output = Y\n\n\nclass NaturalGradIVA(GradIVABase):\n    r\"\"\"Independent vector analysis (IVA) using natural gradient descent.\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(\\vec{\\boldsymbol{y}}_{jn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_frames).\n        score_fn (callable):\n            A score function which corresponds to the partial derivative of the contrast function.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_bins, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def score_fn(y):\n            ...     norm = np.linalg.norm(y, axis=1, keepdims=True)\n            ...     return y / np.maximum(norm, 1e-10)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = NaturalGradIVA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     is_holonomic=True,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=500)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def score_fn(y):\n            ...     norm = np.linalg.norm(y, axis=1, keepdims=True)\n            ...     return y / np.maximum(norm, 1e-10)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = NaturalGradIVA(\n            ...     contrast_fn=contrast_fn,\n            ...     score_fn=score_fn,\n            ...     
is_holonomic=False,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=500)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        score_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"NaturalGradIVA\"], None], List[Callable[[\"NaturalGradIVA\"], None]]]\n        ] = None,\n        is_holonomic: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            is_holonomic=is_holonomic,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the natural gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are updated as follows:\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i} - \\eta\\left(\\frac{1}{J}\\sum_{j} \\\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\\boldsymbol{y}_{ij}^{\\mathsf{H}} \\\n            -\\boldsymbol{I}\\right)\\boldsymbol{W}_{i},\n\n        where\n\n        .. 
math::\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\n            &= \\left(\\phi_{i}(\\vec{\\boldsymbol{y}}_{j1}),\\ldots,\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jn}),\\ldots,\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jN}))\\\n            \\right)^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jn})\n            &= \\frac{\\partial G(\\vec{\\boldsymbol{y}}_{jn})}{\\partial y_{ijn}^{*}}, \\\\\n            G(\\vec{\\boldsymbol{y}}_{jn})\n            &= -\\log p(\\vec{\\boldsymbol{y}}_{jn}).\n\n        Otherwise (``is_holonomic=False``),\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i}\n            - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{J}\\sum_{j}\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\\boldsymbol{y}_{ij}^{\\mathsf{H}}\\right)\n            \\boldsymbol{W}_{i}.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        Phi = self.score_fn(Y)\n        Y_conj = Y.conj()\n        PhiY = np.mean(Phi[:, np.newaxis, :, :] * Y_conj[np.newaxis, :, :, :], axis=-1)\n        PhiY = PhiY.transpose(2, 0, 1)  # (n_bins, n_sources, n_sources)\n        eye = np.eye(self.n_sources)\n\n        if self.is_holonomic:\n            delta = (PhiY - eye) @ W\n        else:\n            delta = ((1 - eye) * PhiY) @ W\n\n        W = W - self.step_size * delta\n\n        Y = self.separate(X, demix_filter=W)\n\n        self.demix_filter = W\n        self.output = Y\n\n\nclass FastIVA(FastIVABase):\n    r\"\"\"Fast independent vector analysis (Fast IVA) [#lee2007fast]_.\n\n    Args:\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(\\vec{\\boldsymbol{y}}_{jn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_frames).\n        d_contrast_fn (callable):\n       
     A derivative of the contrast function.\n            This function is expected to receive (n_channels, n_frames)\n            and return (n_channels, n_frames).\n        dd_contrast_fn (callable):\n            Second order derivative of the contrast function.\n            This function is expected to receive (n_channels, n_frames)\n            and return (n_channels, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        .. code-block:: python\n\n            >>> from ssspy.transform import whiten\n            >>> from ssspy.algorithm import projection_back\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def d_contrast_fn(y):\n            ...     return 2 * np.ones_like(y)\n\n            >>> def dd_contrast_fn(y):\n            ...     
return 2 * np.zeros_like(y)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = FastIVA(\n            ...     contrast_fn=contrast_fn,\n            ...     d_contrast_fn=d_contrast_fn,\n            ...     dd_contrast_fn=dd_contrast_fn,\n            ...     scale_restoration=False,\n            ... )\n            >>> spectrogram_mix_whitened = whiten(spectrogram_mix)\n            >>> spectrogram_est = iva(spectrogram_mix_whitened, n_iter=100)\n            >>> spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n    .. [#lee2007fast] I. Lee et al.,\n        \"Fast fixed-point independent vector analysis algorithms \\\n        for convolutive blind source separation,\" *Signal Processing*,\n        vol. 87, no. 8, pp. 
1859-1871, 2007.\n    \"\"\"\n\n    def __init__(\n        self,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        dd_contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"FastIVA\"], None], List[Callable[[\"FastIVA\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        if contrast_fn is None:\n            raise ValueError(\"Specify contrast function.\")\n        else:\n            self.contrast_fn = contrast_fn\n\n        if d_contrast_fn is None:\n            raise ValueError(\"Specify derivative of contrast function.\")\n        else:\n            self.d_contrast_fn = d_contrast_fn\n\n        if dd_contrast_fn is None:\n            raise ValueError(\"Specify second order derivative of contrast function.\")\n        else:\n            self.dd_contrast_fn = dd_contrast_fn\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, 
perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of IVABase's parent, i.e. __call__ of IterativeMethodBase\n        super(IVABase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(\n            self.whitened_input, demix_filter=self.demix_filter, use_whitening=False\n        )\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"FastIVA(\"\n        s += \"scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Demixing filters are updated as follows:\n\n        .. 
math::\n            \\boldsymbol{w}_{in}\n            \\leftarrow&\\frac{1}{J}\\sum_{j}\n            \\frac{G'_{\\mathbb{R}}(\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2})}\n            {2\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}}\n            \\left(\\boldsymbol{w}_{in}-y_{ijn}^{*}\\boldsymbol{x}_{ij}\\right) \\notag \\\\\n            &-\\frac{1}{J}\\sum_{j}\\frac{|y_{ijn}|^{2}}{2\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}}\\left(\n            \\frac{G'_{\\mathbb{R}}(\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2})}\n            {\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}}\n            - G''_{\\mathbb{R}}(\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2})\n            \\right)\\boldsymbol{w}_{in} \\\\\n            \\boldsymbol{W}_{i}\n            \\leftarrow&\\left(\\boldsymbol{W}_{i}\\boldsymbol{W}_{i}^{\\mathsf{H}}\\right)^{-\\frac{1}{2}}\n            \\boldsymbol{W}_{i}.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Z, W = self.whitened_input, self.demix_filter\n        Y = self.separate(Z, demix_filter=W, use_whitening=False)\n\n        norm = np.linalg.norm(Y, axis=1)\n        varphi = self.d_contrast_fn(norm) / flooring_fn(2 * norm)  # (n_sources, n_frames)\n\n        Y_conj = Y.conj()\n        YZ = Y_conj[:, np.newaxis, :, :] * Z\n        W_Hermite = W.transpose(1, 2, 0).conj()\n        W_YZ = W_Hermite[:, :, :, np.newaxis] - YZ\n        W_YZ = np.mean(varphi[:, np.newaxis, np.newaxis, :] * W_YZ, axis=-1)\n\n        Y_GG = (2 * varphi - self.dd_contrast_fn(norm)) / flooring_fn(2 * norm)\n        YY_GG = Y_GG[:, np.newaxis, :] * (np.abs(Y) ** 2)\n        YY_GGW = np.mean(W_Hermite[:, :, :, np.newaxis] * YY_GG[:, np.newaxis, :, :], axis=-1)\n\n        # Update\n        W_Hermite = W_YZ - YY_GGW\n        W = W_Hermite.transpose(2, 0, 1).conj()\n\n        u, _, v_Hermite = np.linalg.svd(W)\n        W = u @ v_Hermite\n\n        self.demix_filter = W\n\n\nclass FasterIVA(FastIVABase):\n    r\"\"\"Faster independent vector analysis (Faster IVA) 
[#brendel2021fasteriva]_.\n\n    Args:\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(\\vec{\\boldsymbol{y}}_{jn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_frames).\n        d_contrast_fn (callable):\n            A derivative of the contrast function.\n            This function is expected to receive (n_channels, n_frames)\n            and return (n_channels, n_frames).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        .. code-block:: python\n\n            >>> from ssspy.transform import whiten\n            >>> from ssspy.algorithm import projection_back\n\n            >>> def contrast_fn(y):\n            ...     
return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def d_contrast_fn(y):\n            ...     return 2 * np.ones_like(y)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = FasterIVA(\n            ...     contrast_fn=contrast_fn,\n            ...     d_contrast_fn=d_contrast_fn,\n            ...     scale_restoration=False,\n            ... )\n            >>> spectrogram_mix_whitened = whiten(spectrogram_mix)\n            >>> spectrogram_est = iva(spectrogram_mix_whitened, n_iter=100)\n            >>> spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n    .. [#brendel2021fasteriva] A. Brendel and W. Kellermann,\n        \"Faster IVA: Update rules for independent vector analysis based on negentropy \\\n        and the majorize-minimize principle,\"\n        in *Proc. WASPAA*, pp. 
131-135, 2021.\n    \"\"\"\n\n    def __init__(\n        self,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"FasterIVA\"], None], List[Callable[[\"FasterIVA\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n        if contrast_fn is None:\n            raise ValueError(\"Specify contrast function.\")\n        else:\n            self.contrast_fn = contrast_fn\n\n        if d_contrast_fn is None:\n            raise ValueError(\"Specify derivative of contrast function.\")\n        else:\n            self.d_contrast_fn = d_contrast_fn\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n 
       self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of IVABase's parent, i.e. __call__ of IterativeMethodBase\n        super(IVABase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(\n            self.whitened_input, demix_filter=self.demix_filter, use_whitening=False\n        )\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"FasterIVA(\"\n        s += \"scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        In FasterIVA, we compute the eigenvector of :math:`\\boldsymbol{U}_{in}`\n        which corresponds to the largest eigenvalue by solving\n\n        .. math::\n            \\boldsymbol{U}_{in}\\boldsymbol{w}_{in}\n            = \\lambda_{in}\\boldsymbol{w}_{in}.\n\n        Then,\n\n        .. 
math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\left(\\boldsymbol{W}_{i}\\boldsymbol{W}_{i}^{\\mathsf{H}}\\right)^{-\\frac{1}{2}}\n            \\boldsymbol{W}_{i}.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Z, W = self.whitened_input, self.demix_filter\n        Y = self.separate(Z, demix_filter=W, use_whitening=False)\n\n        ZZ_Hermite = Z[:, np.newaxis, :, :] * Z[np.newaxis, :, :, :].conj()\n        ZZ_Hermite = ZZ_Hermite.transpose(2, 0, 1, 3)  # (n_bins, n_channels, n_channels, n_frames)\n        norm = np.linalg.norm(Y, axis=1)\n        varphi = self.d_contrast_fn(norm) / flooring_fn(2 * norm)  # (n_sources, n_frames)\n        varphi_ZZ = varphi[:, np.newaxis, np.newaxis, :] * ZZ_Hermite[:, np.newaxis, :, :, :]\n        U = np.mean(varphi_ZZ, axis=-1)  # (n_bins, n_sources, n_channels, n_channels)\n\n        _, w = eigh(U)  # (n_bins, n_sources, n_channels, n_channels)\n        W = w[..., -1].conj()  # eigenvector that corresponds to largest eigenvalue\n        u, _, v_Hermite = np.linalg.svd(W)\n        W = u @ v_Hermite\n\n        self.demix_filter = W\n\n\nclass AuxIVA(AuxIVABase):\n    r\"\"\"Auxiliary-function-based independent vector analysis (IVA) [#ono2011stable]_.\n\n    Args:\n        spatial_algorithm (str):\n            Algorithm for demixing filter updates.\n            Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, ``ISS2``, or ``IPA``.\n            Default: ``IP``.\n        contrast_fn (callable):\n            A contrast function which corresponds to :math:`-\\log p(\\vec{\\boldsymbol{y}}_{jn})`.\n            This function is expected to receive (n_channels, n_bins, n_frames)\n            and return (n_channels, n_frames).\n        d_contrast_fn (callable):\n            A derivative of the contrast function.\n            This function is expected to receive (n_channels, n_frames)\n            and return (n_channels, n_frames).\n        flooring_fn (callable, 
optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n        pair_selector (callable, optional):\n            Selector to choose the pair to update in ``IP2`` and ``ISS2``.\n            If ``None`` is given, ``sequential_pair_selector`` is used.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the demixing filter update if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n        lqpqm_normalization (bool):\n            This keyword argument can be specified when ``spatial_algorithm='IPA'``.\n            If ``True``, normalization by trace is applied to positive semi-definite matrix\n            in LQPQM. Default: ``True``.\n        newton_iter (int):\n            This keyword argument can be specified when ``spatial_algorithm='IPA'``.\n            Number of iterations in Newton method. Default: ``1``.\n\n    Examples:\n        Update demixing filters by IP:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def d_contrast_fn(y):\n            ...     
return 2 * np.ones_like(y)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxIVA(\n            ...     spatial_algorithm=\"IP\",\n            ...     contrast_fn=contrast_fn,\n            ...     d_contrast_fn=d_contrast_fn,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IP2:\n\n        .. code-block:: python\n\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def d_contrast_fn(y):\n            ...     return 2 * np.ones_like(y)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxIVA(\n            ...     spatial_algorithm=\"IP2\",\n            ...     contrast_fn=contrast_fn,\n            ...     d_contrast_fn=d_contrast_fn,\n            ...     pair_selector=sequential_pair_selector,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def d_contrast_fn(y):\n            ...     
return 2 * np.ones_like(y)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxIVA(\n            ...     spatial_algorithm=\"ISS\",\n            ...     contrast_fn=contrast_fn,\n            ...     d_contrast_fn=d_contrast_fn,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS2:\n\n        .. code-block:: python\n\n            >>> import functools\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def d_contrast_fn(y):\n            ...     return 2 * np.ones_like(y)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxIVA(\n            ...     spatial_algorithm=\"ISS2\",\n            ...     contrast_fn=contrast_fn,\n            ...     d_contrast_fn=d_contrast_fn,\n            ...     pair_selector=functools.partial(sequential_pair_selector, step=2),\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IPA:\n\n        .. code-block:: python\n\n            >>> def contrast_fn(y):\n            ...     return 2 * np.linalg.norm(y, axis=1)\n\n            >>> def d_contrast_fn(y):\n            ...     
return 2 * np.ones_like(y)\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxIVA(\n            ...     spatial_algorithm=\"IPA\",\n            ...     contrast_fn=contrast_fn,\n            ...     d_contrast_fn=d_contrast_fn,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n    .. [#ono2011stable]\n        N. Ono,\n        \"Stable and fast update rules for independent vector analysis based on \\\n        auxiliary function technique,\"\n        in *Proc. WASPAA*, 2011, p.189-192.\n    \"\"\"\n\n    _ipa_default_kwargs = {\"lqpqm_normalization\": True, \"newton_iter\": 1}\n    _default_kwargs = _ipa_default_kwargs\n\n    def __init__(\n        self,\n        spatial_algorithm: str = \"IP\",\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n        callbacks: Optional[\n            Union[Callable[[\"AuxIVA\"], None], List[Callable[[\"AuxIVA\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            contrast_fn=contrast_fn,\n            d_contrast_fn=d_contrast_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            
reference_id=reference_id,\n        )\n\n        assert spatial_algorithm in spatial_algorithms, \"Not support {}.\".format(spatial_algorithm)\n\n        self.spatial_algorithm = spatial_algorithm\n\n        if pair_selector is None:\n            if spatial_algorithm in [\"IP2\", \"ISS2\"]:\n                self.pair_selector = sequential_pair_selector\n        else:\n            self.pair_selector = pair_selector\n\n        if spatial_algorithm == \"IPA\":\n            valid_keys = set(self.__class__._ipa_default_kwargs.keys())\n        else:\n            valid_keys = set()\n\n        invalid_keys = set(kwargs) - valid_keys\n\n        assert invalid_keys == set(), \"Invalid keywords {} are given.\".format(invalid_keys)\n\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n\n        # set default values if necessary\n        for key in valid_keys:\n            if not hasattr(self, key):\n                value = self.__class__._default_kwargs[key]\n                setattr(self, key, value)\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of IVABase's parent, i.e. 
__call__ of IterativeMethodBase\n        super(IVABase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        if self.demix_filter is None:\n            pass\n        else:\n            self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"AuxIVA(\"\n        s += \"spatial_algorithm={spatial_algorithm}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of IVA.\n        \"\"\"\n        super()._reset(**kwargs)\n\n        if self.spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\", \"IPA\"]:\n            self.demix_filter = None\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once.\n\n        - If ``self.spatial_algorithm`` is ``IP`` or ``IP1``, ``update_once_ip1`` is called.\n        - If ``self.spatial_algorithm`` is ``IP2``, ``update_once_ip2`` is called.\n        - If ``self.spatial_algorithm`` is ``ISS`` or ``ISS1``, ``update_once_iss1`` is called.\n        - If ``self.spatial_algorithm`` is ``ISS2``, ``update_once_iss2`` is called.\n        - If ``self.spatial_algorithm`` is ``IPA``, ``update_once_ipa`` is called.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n               
 If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.spatial_algorithm in [\"IP\", \"IP1\"]:\n            self.update_once_ip1(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"IP2\"]:\n            self.update_once_ip2(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"ISS\", \"ISS1\"]:\n            self.update_once_iss1(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"ISS2\"]:\n            self.update_once_iss2(flooring_fn=flooring_fn)\n        elif self.spatial_algorithm in [\"IPA\"]:\n            self.update_once_ipa(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.spatial_algorithm))\n\n    def update_once_ip1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Compute auxiliary variables:\n\n        .. math::\n            \\bar{r}_{jn}\n            \\leftarrow\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}\n\n        Then, demixing filters are updated sequentially for :math:`n=1,\\ldots,N` as follows:\n\n        .. 
math::\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\left(\\boldsymbol{W}_{in}^{\\mathsf{H}}\\boldsymbol{U}_{in}\\right)^{-1} \\\n            \\boldsymbol{e}_{n}, \\\\\n            \\boldsymbol{w}_{in}\n            &\\leftarrow\\frac{\\boldsymbol{w}_{in}}\n            {\\sqrt{\\boldsymbol{w}_{in}^{\\mathsf{H}}\\boldsymbol{U}_{in}\\boldsymbol{w}_{in}}}, \\\\\n\n        where\n\n        .. math::\n            \\boldsymbol{U}_{in}\n            &= \\frac{1}{J}\\sum_{j}\n            \\varphi(\\bar{r}_{jn})\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}, \\\\\n            \\varphi(\\bar{r}_{jn})\n            &= \\frac{G'_{\\mathbb{R}}(\\bar{r}_{jn})}{2\\bar{r}_{jn}}, \\\\\n            G(\\vec{\\boldsymbol{y}}_{jn})\n            &= -\\log p(\\vec{\\boldsymbol{y}}_{jn}), \\\\\n            G_{\\mathbb{R}}(\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2})\n            &= G(\\vec{\\boldsymbol{y}}_{jn}).\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)  # (n_bins, n_channels, n_channels, n_frames)\n        norm = np.linalg.norm(Y, axis=1)\n        denom = flooring_fn(2 * norm)\n        weight = self.d_contrast_fn(norm) / denom  # (n_sources, n_frames)\n        GXX = weight[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n        U = np.mean(GXX, axis=-1)  # (n_bins, n_sources, n_channels, n_channels)\n\n        self.demix_filter = update_by_ip1(W, U, flooring_fn=flooring_fn)\n\n    def update_once_ip2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using pairwise iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                
A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\\neq n_{2}`),\n        compute auxiliary variables:\n\n        .. math::\n            \\bar{r}_{jn_{1}}\n            &\\leftarrow\\|\\vec{\\boldsymbol{y}}_{jn_{1}}\\|_{2} \\\\\n            \\bar{r}_{jn_{2}}\n            &\\leftarrow\\|\\vec{\\boldsymbol{y}}_{jn_{2}}\\|_{2}\n\n        Then, for :math:`n=n_{1},n_{2}`, compute weighted covariance matrix as follows:\n\n        .. math::\n            \\boldsymbol{U}_{in_{1}}\n            &= \\frac{1}{J}\\sum_{j}\n            \\varphi(\\bar{r}_{jn_{1}})\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}, \\\\\n            \\boldsymbol{U}_{in_{2}}\n            &= \\frac{1}{J}\\sum_{j}\n            \\varphi(\\bar{r}_{jn_{2}})\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}},\n\n        where\n\n        .. math::\n            \\varphi(\\bar{r}_{jn})\n            = \\frac{G'_{\\mathbb{R}}(\\bar{r}_{jn})}{2\\bar{r}_{jn}}.\n\n        Using :math:`\\boldsymbol{U}_{in_{1}}` and\n        :math:`\\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors.\n\n        .. math::\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i}\n            = \\lambda_{i}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i},\n\n        where\n\n        .. 
math::\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{1}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ), \\\\\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{2}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ).\n\n        After that, we standardize two eigenvectors :math:`\\boldsymbol{h}_{in_{1}}`\n        and :math:`\\boldsymbol{h}_{in_{2}}`.\n\n        .. math::\n            \\boldsymbol{h}_{in_{1}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{1}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{1}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{1}}}}, \\\\\n            \\boldsymbol{h}_{in_{2}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{2}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{2}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{2}}}}.\n\n        Then, update :math:`\\boldsymbol{w}_{in_{1}}` and :math:`\\boldsymbol{w}_{in_{2}}`\n        simultaneously.\n\n        .. 
math::\n            \\boldsymbol{w}_{in_{1}}\n            &\\leftarrow \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{1}} \\\\\n            \\boldsymbol{w}_{in_{2}}\n            &\\leftarrow \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{2}}.\n\n        At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{2}`\n        for :math:`n_{1}\\neq n_{2}`.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        n_sources = self.n_sources\n        X, W = self.input, self.demix_filter\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        for m, n in self.pair_selector(n_sources):\n            W_mn = W[:, (m, n), :]\n            Y_mn = self.separate(X, demix_filter=W_mn)\n\n            norm = np.linalg.norm(Y_mn, axis=1)\n            weight = self.d_contrast_fn(norm) / flooring_fn(2 * norm)\n            GXX_mn = weight[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n            U_mn = np.mean(GXX_mn, axis=-1)\n\n            W[:, (m, n), :] = update_by_ip2_one_pair(\n                W,\n                U_mn,\n                pair=(m, n),\n                flooring_fn=flooring_fn,\n            )\n\n        self.demix_filter = W\n\n    def update_once_iss1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using \\\n        iterative source steering [#scheibler2020fast]_.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, 
``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        First, update auxiliary variables\n\n        .. math::\n            \\bar{r}_{jn}\n            \\leftarrow\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}.\n\n        Then, update :math:`y_{ijn}` as follows:\n\n        .. math::\n            \\boldsymbol{y}_{ij}\n            & \\leftarrow\\boldsymbol{y}_{ij} - \\boldsymbol{d}_{in}y_{ijn}, \\\\\n            d_{inn'}\n            &= \\begin{cases}\n                \\dfrac{\\sum_{j}\\dfrac{G'_{\\mathbb{R}}(\\bar{r}_{jn'})}{2\\bar{r}_{jn'}}\n                y_{ijn'}y_{ijn}^{*}}{\\sum_{j}\\dfrac{G'_{\\mathbb{R}}(\\bar{r}_{jn'})}\n                {2\\bar{r}_{jn'}}|y_{ijn}|^{2}}\n                & (n'\\neq n) \\\\\n                1 - \\dfrac{1}{\\sqrt{\\dfrac{1}{J}\\sum_{j}\\dfrac{G'_{\\mathbb{R}}(\\bar{r}_{jn'})}\n                {2\\bar{r}_{jn'}}\n                |y_{ijn}|^{2}}} & (n'=n)\n            \\end{cases}.\n\n        .. [#scheibler2020fast] R. Scheibler and N. Ono,\n            \"Fast and stable blind source separation with rank-1 updates,\"\n            in *Proc. ICASSP*, 2020, pp. 
236-240.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n        r = np.linalg.norm(Y, axis=1)\n        denom = flooring_fn(2 * r)\n        varphi = self.d_contrast_fn(r) / denom  # (n_sources, n_frames)\n\n        self.output = update_by_iss1(Y, varphi[:, np.newaxis, :], flooring_fn=flooring_fn)\n\n    def update_once_iss2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using \\\n        pairwise iterative source steering [#ikeshita2022iss2]_.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        First, we compute auxiliary variables:\n\n        .. math::\n            \\bar{r}_{jn}\n            \\leftarrow\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2},\n\n        where\n\n        .. math::\n            G(\\vec{\\boldsymbol{y}}_{jn})\n            &= -\\log p(\\vec{\\boldsymbol{y}}_{jn}), \\\\\n            G_{\\mathbb{R}}(\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2})\n            &= G(\\vec{\\boldsymbol{y}}_{jn}).\n\n        Then, we compute :math:`\\boldsymbol{G}_{in}^{(n_{1},n_{2})}` \\\n        and :math:`\\boldsymbol{f}_{in}^{(n_{1},n_{2})}` for :math:`n_{1}\\neq n_{2}`:\n\n        .. 
math::\n            \\begin{array}{rclc}\n                \\boldsymbol{G}_{in}^{(n_{1},n_{2})}\n                &=& {\\displaystyle\\frac{1}{J}\\sum_{j}}\\varphi(\\bar{r}_{jn})\n                \\boldsymbol{y}_{ij}^{(n_{1},n_{2})}{\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}}^{\\mathsf{H}}\n                &(n=1,\\ldots,N), \\\\\n                \\boldsymbol{f}_{in}^{(n_{1},n_{2})}\n                &=& {\\displaystyle\\frac{1}{J}\\sum_{j}}\n                \\varphi(\\bar{r}_{jn})y_{ijn}^{*}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}\n                &(n\\neq n_{1},n_{2}), \\\\\n                \\varphi(\\bar{r}_{jn})\n                &=&\\dfrac{G'_{\\mathbb{R}}(\\bar{r}_{jn})}{2\\bar{r}_{jn}}.\n            \\end{array}\n\n        Using :math:`\\boldsymbol{G}_{in}^{(n_{1},n_{2})}` and \\\n        :math:`\\boldsymbol{f}_{in}^{(n_{1},n_{2})}`, we compute\n\n        .. math::\n            \\begin{array}{rclc}\n                \\boldsymbol{p}_{in}\n                &=& \\dfrac{\\boldsymbol{h}_{in}}\n                {\\sqrt{\\boldsymbol{h}_{in}^{\\mathsf{H}}\\boldsymbol{G}_{in}^{(n_{1},n_{2})}\n                \\boldsymbol{h}_{in}}} & (n=n_{1},n_{2}), \\\\\n                \\boldsymbol{q}_{in}\n                &=& -{\\boldsymbol{G}_{in}^{(n_{1},n_{2})}}^{-1}\\boldsymbol{f}_{in}^{(n_{1},n_{2})}\n                & (n\\neq n_{1},n_{2}),\n            \\end{array}\n\n        where :math:`\\boldsymbol{h}_{in}` (:math:`n=n_{1},n_{2}`) is \\\n        a generalized eigenvector obtained from\n\n        .. math::\n            \\boldsymbol{G}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{i}\n            = \\lambda_{i}\\boldsymbol{G}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{i}.\n\n        Separated signal :math:`y_{ijn}` is updated as follows:\n\n        .. 
math::\n            y_{ijn}\n            &\\leftarrow\\begin{cases}\n            &\\boldsymbol{p}_{in}^{\\mathsf{H}}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})}\n            & (n=n_{1},n_{2}) \\\\\n            &\\boldsymbol{q}_{in}^{\\mathsf{H}}\\boldsymbol{y}_{ij}^{(n_{1},n_{2})} + y_{ijn}\n            & (n\\neq n_{1},n_{2})\n            \\end{cases}.\n\n        .. [#ikeshita2022iss2]\n            R. Ikeshita and T. Nakatani,\n            \"ISS2: An extension of iterative source steering algorithm for \\\n            majorization-minimization-based independent vector analysis,\"\n            *arXiv:2202.00875*, 2022.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n\n        # Auxiliary variables\n        r = np.linalg.norm(Y, axis=1)\n        varphi = self.d_contrast_fn(r) / flooring_fn(2 * r)\n\n        self.output = update_by_iss2(\n            Y,\n            varphi[:, np.newaxis, :],\n            flooring_fn=flooring_fn,\n            pair_selector=self.pair_selector,\n        )\n\n    def update_once_ipa(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update estimated spectrograms once using \\\n        iterative projection with adjustment [#scheibler2021independent]_.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        First, we compute auxiliary variables:\n\n        .. math::\n            \\bar{r}_{jn}\n            \\leftarrow\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2},\n\n        where\n\n        .. 
math::\n            G(\\vec{\\boldsymbol{y}}_{jn})\n            &= -\\log p(\\vec{\\boldsymbol{y}}_{jn}), \\\\\n            G_{\\mathbb{R}}(\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2})\n            &= G(\\vec{\\boldsymbol{y}}_{jn}).\n\n        Then, by defining :math:`\\tilde{\\boldsymbol{U}}_{in'}`,\n        :math:`\\boldsymbol{A}_{in}\\in\\mathbb{R}^{(N-1)\\times(N-1)}`,\n        :math:`\\boldsymbol{b}_{in}\\in\\mathbb{C}^{N-1}`,\n        :math:`\\boldsymbol{C}_{in}\\in\\mathbb{C}^{(N-1)\\times(N-1)}`,\n        :math:`\\boldsymbol{d}_{in}\\in\\mathbb{C}^{N-1}`,\n        and :math:`z_{in}\\in\\mathbb{R}_{\\geq 0}` as follows:\n\n        .. math::\n\n            \\tilde{\\boldsymbol{U}}_{in'}\n            &= \\frac{1}{J}\\sum_{j}\\frac{G'_{\\mathbb{R}}(\\bar{r}_{jn'})}{2\\bar{r}_{jn'}}\n            \\boldsymbol{y}_{ij}\\boldsymbol{y}_{ij}^{\\mathsf{H}}, \\\\\n            \\boldsymbol{A}_{in}\n            &= \\mathrm{diag}(\\ldots,\n            \\boldsymbol{e}_{n}^{\\mathsf{T}}\\tilde{\\boldsymbol{U}}_{in'}\\boldsymbol{e}_{n}\n            ,\\ldots)~~(n'\\neq n), \\\\\n            \\boldsymbol{b}_{in}\n            &= (\\ldots,\n            \\boldsymbol{e}_{n}^{\\mathsf{T}}\\tilde{\\boldsymbol{U}}_{in'}\\boldsymbol{e}_{n'}\n            ,\\ldots)^{\\mathsf{T}}~~(n'\\neq n), \\\\\n            \\boldsymbol{C}_{in}\n            &= \\bar{\\boldsymbol{E}}_{n}^{\\mathsf{T}}(\\tilde{\\boldsymbol{U}}_{in}^{-1})^{*}\n            \\bar{\\boldsymbol{E}}_{n}, \\\\\n            \\boldsymbol{d}_{in}\n            &= \\bar{\\boldsymbol{E}}_{n}^{\\mathsf{T}}(\\tilde{\\boldsymbol{U}}_{in}^{-1})^{*}\n            \\boldsymbol{e}_{n}, \\\\\n            z_{in}\n            &= \\boldsymbol{e}_{n}^{\\mathsf{T}}\\tilde{\\boldsymbol{U}}_{in}^{-1}\\boldsymbol{e}_{n}\n            - \\boldsymbol{d}_{in}^{\\mathsf{H}}\\boldsymbol{C}_{in}^{-1}\\boldsymbol{d}_{in},\n\n        :math:`\\boldsymbol{y}_{ij}` is updated via log-quadratically penalized\n        quadratic minimization (LQPQM).\n\n        .. 
math::\n            \\check{\\boldsymbol{q}}_{in}\n            &\\leftarrow \\mathrm{LQPQM2}(\\boldsymbol{H}_{in},\\boldsymbol{v}_{in},z_{in}), \\\\\n            \\boldsymbol{q}_{in}\n            &\\leftarrow \\boldsymbol{G}_{in}^{-1}\\check{\\boldsymbol{q}}_{in}\n            - \\boldsymbol{A}_{in}^{-1}\\boldsymbol{b}_{in}, \\\\\n            \\tilde{\\boldsymbol{q}}_{in}\n            &\\leftarrow \\boldsymbol{e}_{n} - \\bar{\\boldsymbol{E}}_{n}\\boldsymbol{q}_{in}, \\\\\n            \\boldsymbol{p}_{in}\n            &\\leftarrow \\frac{\\tilde{\\boldsymbol{U}}_{in}^{-1}\\tilde{\\boldsymbol{q}}_{in}^{*}}\n            {\\sqrt{(\\tilde{\\boldsymbol{q}}_{in}^{*})^{\\mathsf{H}}\\tilde{\\boldsymbol{U}}_{in}^{-1}\n            \\tilde{\\boldsymbol{q}}_{in}^{*}}}, \\\\\n            \\boldsymbol{\\Upsilon}_{i}^{(n)}\n            &\\leftarrow \\boldsymbol{I}\n            + \\boldsymbol{e}_{n}(\\boldsymbol{p}_{in} - \\boldsymbol{e}_{n})^{\\mathsf{H}}\n            + \\bar{\\boldsymbol{E}}_{n}\\boldsymbol{q}_{in}^{*}\\boldsymbol{e}_{n}^{\\mathsf{T}}, \\\\\n            \\boldsymbol{y}_{ij}\n            &\\leftarrow \\boldsymbol{\\Upsilon}_{i}^{(n)}\\boldsymbol{y}_{ij},\n\n        .. [#scheibler2021independent]\n            R. Scheibler,\n            \"Independent vector analysis via log-quadratically penalized quadratic minimization,\"\n            *IEEE Trans. Signal Processing*, vol. 69, pp. 
2509-2524, 2021.\n\n        \"\"\"\n        self.lqpqm_normalization: bool\n        self.newton_iter: int\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        Y = self.output\n        r = np.linalg.norm(Y, axis=1)\n        denom = flooring_fn(2 * r)\n        varphi = self.d_contrast_fn(r) / denom\n\n        normalization = self.lqpqm_normalization\n        max_iter = self.newton_iter\n\n        self.output = update_by_ipa(\n            Y,\n            varphi[:, np.newaxis, :],\n            normalization=normalization,\n            flooring_fn=flooring_fn,\n            max_iter=max_iter,\n        )\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss.\"\"\"\n        if self.demix_filter is None:\n            X, Y = self.input, self.output\n            G = self.contrast_fn(Y)  # (n_sources, n_frames)\n            X, Y = X.transpose(1, 0, 2), Y.transpose(1, 0, 2)\n            X_Hermite = X.transpose(0, 2, 1).conj()\n            XX_Hermite = X @ X_Hermite  # (n_bins, n_channels, n_channels)\n            W = Y @ X_Hermite @ np.linalg.inv(XX_Hermite)\n            logdet = self.compute_logdet(W)  # (n_bins,)\n            loss = np.sum(np.mean(G, axis=1), axis=0) - 2 * np.sum(logdet, axis=0)\n            loss = loss.item()\n\n            return loss\n        else:\n            return super().compute_loss()\n\n    def apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated spectrograms.\"\"\"\n        if self.demix_filter is None:\n            assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n            X, Y = self.input, self.output\n            Y_scaled = projection_back(Y, reference=X, reference_id=self.reference_id)\n\n            self.output = Y_scaled\n        else:\n            super().apply_projection_back()\n\n    def apply_minimal_distortion_principle(self) -> None:\n        r\"\"\"Apply minimal distortion principle to estimated spectrograms.\"\"\"\n        
if self.demix_filter is None:\n            X, Y = self.input, self.output\n            Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id)\n\n            self.output = Y_scaled\n        else:\n            super().apply_minimal_distortion_principle()\n\n\nclass PDSIVA(PDSBSS):\n    def __init__(\n        self,\n        mu1: float = 1,\n        mu2: float = 1,\n        alpha: float = None,\n        relaxation: float = 1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"PDSIVA\"], None], List[Callable[[\"PDSIVA\"], None]]]\n        ] = None,\n        scale_restoration: bool = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        if contrast_fn is not None and prox_penalty is None:\n            raise ValueError(\"Set prox_penalty.\")\n        elif contrast_fn is None and prox_penalty is not None:\n            raise ValueError(\"Set contrast_fn.\")\n        elif contrast_fn is None and prox_penalty is None:\n\n            def _contrast_fn(y: np.ndarray) -> np.ndarray:\n                return np.linalg.norm(y, axis=1)\n\n            def _prox_penalty(x: np.ndarray, step_size: float = 1) -> np.ndarray:\n                return prox.l21(x, step_size=step_size, axis2=1)\n\n            contrast_fn = _contrast_fn\n            prox_penalty = _prox_penalty\n\n        def penalty_fn(y: np.ndarray) -> float:\n            r\"\"\"Sum of contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                Computed loss.\n            \"\"\"\n            G = contrast_fn(y)  # (n_sources, n_frames)\n            loss = np.sum(G, axis=(0, 1))\n            loss = loss.item()\n\n            return loss\n\n        super().__init__(\n            mu1=mu1,\n 
           mu2=mu2,\n            alpha=alpha,\n            relaxation=relaxation,\n            penalty_fn=penalty_fn,\n            prox_penalty=prox_penalty,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        self.contrast_fn = contrast_fn\n\n\nclass ADMMIVA(ADMMBSS):\n    def __init__(\n        self,\n        rho: float = 1,\n        alpha: float = None,\n        relaxation: float = 1,\n        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,\n        prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"ADMMIVA\"], None], List[Callable[[\"ADMMIVA\"], None]]]\n        ] = None,\n        scale_restoration: bool = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        if contrast_fn is not None and prox_penalty is None:\n            raise ValueError(\"Set prox_penalty.\")\n        elif contrast_fn is None and prox_penalty is not None:\n            raise ValueError(\"Set contrast_fn.\")\n        elif contrast_fn is None and prox_penalty is None:\n\n            def _contrast_fn(y: np.ndarray) -> np.ndarray:\n                return np.linalg.norm(y, axis=1)\n\n            def _prox_penalty(x: np.ndarray, step_size: float = 1) -> np.ndarray:\n                return prox.l21(x, step_size=step_size, axis2=1)\n\n            contrast_fn = _contrast_fn\n            prox_penalty = _prox_penalty\n\n        def penalty_fn(y: np.ndarray) -> float:\n            r\"\"\"Sum of contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                Computed loss.\n            \"\"\"\n            G = contrast_fn(y)  # (n_sources, n_frames)\n            loss = np.sum(G, axis=(0, 1))\n            loss = loss.item()\n\n           
 return loss\n\n        super().__init__(\n            rho=rho,\n            alpha=alpha,\n            relaxation=relaxation,\n            penalty_fn=penalty_fn,\n            prox_penalty=prox_penalty,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        self.contrast_fn = contrast_fn\n\n\nclass GradLaplaceIVA(GradIVA):\n    r\"\"\"Independent vector analysis (IVA) using the gradient descent on a Laplace distribution.\n\n    We assume :math:`\\vec{\\boldsymbol{y}}_{jn}` follows a Laplace distribution.\n\n    .. math::\n        p(\\vec{\\boldsymbol{y}}_{jn})\\propto\\exp(-\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2})\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = GradLaplaceIVA(is_holonomic=True)\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = GradLaplaceIVA(is_holonomic=False)\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"GradLaplaceIVA\"], None], List[Callable[[\"GradLaplaceIVA\"], None]]]\n        ] = None,\n        is_holonomic: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        def contrast_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of the shape is (n_sources, n_frames).\n            \"\"\"\n            return 2 * np.linalg.norm(y, axis=1)\n\n        def score_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Score function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of the shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            norm = np.linalg.norm(y, axis=1, keepdims=True)\n            norm = self.flooring_fn(norm)\n            return y / norm\n\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            is_holonomic=is_holonomic,\n            scale_restoration=scale_restoration,\n            
record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are updated as follows:\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i} - \\eta\\left(\\frac{1}{J}\\sum_{j} \\\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\\boldsymbol{y}_{ij}^{\\mathsf{H}} \\\n            -\\boldsymbol{I}\\right)\\boldsymbol{W}_{i}^{-\\mathsf{H}},\n\n        where\n\n        .. math::\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\n            &= \\left(\\phi_{i}(\\vec{\\boldsymbol{y}}_{j1}),\\ldots,\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jn}),\\ldots,\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jN}))\\\n            \\right)^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jn})\n            &= \\frac{y_{ijn}}{\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}}.\n\n        Otherwise (``is_holonomic=False``),\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i}\n            - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{J}\\sum_{j}\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\\boldsymbol{y}_{ij}^{\\mathsf{H}}\\right)\n            \\boldsymbol{W}_{i}^{-\\mathsf{H}}.\n        \"\"\"\n        return super().update_once()\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. 
math::\n            \\mathcal{L} \\\n            = \\frac{2}{J}\\sum_{j,n}\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2} \\\n            - 2\\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        return super().compute_loss()\n\n\nclass GradGaussIVA(GradIVA):\n    r\"\"\"Independent vector analysis (IVA) using the gradient descent on \\\n    a time-varying Gaussian distribution.\n\n    We assume :math:`\\vec{\\boldsymbol{y}}_{jn}` follows a time-varying Gaussian distribution.\n\n    .. math::\n        p(\\vec{\\boldsymbol{y}}_{jn})\n        \\propto\\frac{1}{\\alpha_{jn}^{I}}\n        \\exp\\left(-\\frac{\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}^{2}}{\\alpha_{jn}}\\right).\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = GradGaussIVA(is_holonomic=True)\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = GradGaussIVA(is_holonomic=False)\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=5000)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"GradGaussIVA\"], None], List[Callable[[\"GradGaussIVA\"], None]]]\n        ] = None,\n        is_holonomic: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        def contrast_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    Separated signal with shape of (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of computed contrast function.\n                The shape is (n_sources, n_frames).\n            \"\"\"\n            n_bins = self.n_bins\n            alpha = self.variance\n            norm = np.linalg.norm(y, axis=1)\n\n            return n_bins * np.log(alpha) + (norm**2) / alpha\n\n        def score_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Score function.\n\n            Args:\n                y (numpy.ndarray):\n                    Separated signal with shape of (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of computed score function.\n                The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            alpha = self.variance\n            return y / alpha[:, np.newaxis, :]\n\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n
            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            is_holonomic=is_holonomic,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        We also set variance of Gaussian distribution.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of IVA.\n        \"\"\"\n        super()._reset(**kwargs)\n\n        n_sources, n_frames = self.n_sources, self.n_frames\n\n        self.variance = np.ones((n_sources, n_frames))\n\n    def update_once(self) -> None:\n        r\"\"\"Update variance and demixing filters once.\"\"\"\n        self.update_source_model()\n\n        super().update_once()\n\n    def update_source_model(self) -> None:\n        r\"\"\"Update variance of Gaussian distribution.\"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        self.variance = np.mean(np.abs(Y) ** 2, axis=1)\n\n\nclass NaturalGradLaplaceIVA(NaturalGradIVA):\n    r\"\"\"Independent vector analysis (IVA) using the natural gradient descent \\\n    on a Laplace distribution.\n\n    We assume :math:`\\vec{\\boldsymbol{y}}_{jn}` follows a Laplace distribution.\n\n    .. math::\n        p(\\vec{\\boldsymbol{y}}_{jn})\\propto\\exp(-\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2})\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. 
Default: ``1e-1``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = NaturalGradLaplaceIVA(is_holonomic=True)\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=500)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = NaturalGradLaplaceIVA(is_holonomic=False)\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=500)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[\n                Callable[[\"NaturalGradLaplaceIVA\"], None],\n                List[Callable[[\"NaturalGradLaplaceIVA\"], None]],\n            ]\n        ] = None,\n        is_holonomic: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        def contrast_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of the shape is (n_sources, n_frames).\n            \"\"\"\n            return 2 * np.linalg.norm(y, axis=1)\n\n        def score_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"Score function.\n\n            Args:\n                y 
(numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of the shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            norm = np.linalg.norm(y, axis=1, keepdims=True)\n            norm = self.flooring_fn(norm)\n            return y / norm\n\n        super().__init__(\n            step_size=step_size,\n            contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            is_holonomic=is_holonomic,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters once using the natural gradient descent.\n\n        If ``is_holonomic=True``, demixing filters are updated as follows:\n\n        .. math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i} - \\eta\\left(\\frac{1}{J}\\sum_{j} \\\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\\boldsymbol{y}_{ij}^{\\mathsf{H}} \\\n            -\\boldsymbol{I}\\right)\\boldsymbol{W}_{i},\n\n        where\n\n        .. math::\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\n            &= \\left(\\phi_{i}(\\vec{\\boldsymbol{y}}_{j1}),\\ldots,\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jn}),\\ldots,\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jN})\\\n            \\right)^{\\mathsf{T}}\\in\\mathbb{C}^{N}, \\\\\n            \\phi_{i}(\\vec{\\boldsymbol{y}}_{jn})\n            &= \\frac{y_{ijn}}{\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}}.\n\n        Otherwise (``is_holonomic=False``),\n\n        .. 
math::\n            \\boldsymbol{W}_{i}\n            \\leftarrow\\boldsymbol{W}_{i}\n            - \\eta\\cdot\\mathrm{offdiag}\\left(\\frac{1}{J}\\sum_{j}\n            \\boldsymbol{\\phi}_{i}(\\vec{\\boldsymbol{Y}}_{j})\\boldsymbol{y}_{ij}^{\\mathsf{H}}\\right)\n            \\boldsymbol{W}_{i}.\n        \"\"\"\n        return super().update_once()\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        :math:`\\mathcal{L}` is given as follows:\n\n        .. math::\n            \\mathcal{L} \\\n            = \\frac{2}{J}\\sum_{j,n}\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2} \\\n            - 2\\sum_{i}\\log|\\det\\boldsymbol{W}_{i}|.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        return super().compute_loss()\n\n\nclass NaturalGradGaussIVA(NaturalGradIVA):\n    r\"\"\"Independent vector analysis (IVA) using the natural gradient descent \\\n    on a time-varying Gaussian distribution.\n\n    We assume :math:`\\vec{\\boldsymbol{y}}_{jn}` follows a time-varying Gaussian distribution.\n\n    .. math::\n        p(\\vec{\\boldsymbol{y}}_{jn})\n        \\propto\\frac{1}{\\alpha_{jn}^{I}}\n        \\exp\\left(-\\frac{\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}^{2}}{\\alpha_{jn}}\\right).\n\n    Args:\n        step_size (float):\n            A step size of the gradient descent. Default: ``1e-1``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. 
Each function is called before separation and at each iteration.\n            Default: ``None``.\n        is_holonomic (bool):\n            If ``is_holonomic=True``, Holonomic-type update is used.\n            Otherwise, Nonholonomic-type update is used. Default: ``True``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the gradient descent if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters using Holonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = NaturalGradGaussIVA(is_holonomic=True)\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=500)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters using Nonholonomic-type update:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = NaturalGradGaussIVA(is_holonomic=False)\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=500)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        step_size: float = 1e-1,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[\n                Callable[[\"NaturalGradGaussIVA\"], None],\n                List[Callable[[\"NaturalGradGaussIVA\"], None]],\n            ]\n        ] = None,\n        is_holonomic: bool = True,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n    ) -> None:\n        def contrast_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"\n            Args:\n                y (numpy.ndarray):\n                    Separated signal with shape of (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of computed contrast function.\n                The shape is (n_sources, n_frames).\n            \"\"\"\n            n_bins = self.n_bins\n            alpha = self.variance\n            norm = np.linalg.norm(y, axis=1)\n\n            return n_bins * np.log(alpha) + (norm**2) / alpha\n\n        def score_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"\n            Args:\n                y (numpy.ndarray):\n                    Separated signal.\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of computed score function.\n                The shape is (n_sources, n_bins, n_frames).\n            \"\"\"\n            alpha = self.variance\n            return y / alpha[:, np.newaxis, :]\n\n        super().__init__(\n            step_size=step_size,\n 
           contrast_fn=contrast_fn,\n            score_fn=score_fn,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            is_holonomic=is_holonomic,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        We also set variance of Gaussian distribution.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of IVA.\n        \"\"\"\n        super()._reset(**kwargs)\n\n        n_sources, n_frames = self.n_sources, self.n_frames\n\n        self.variance = np.ones((n_sources, n_frames))\n\n    def update_once(self) -> None:\n        r\"\"\"Update variance and demixing filters once.\"\"\"\n        self.update_source_model()\n\n        super().update_once()\n\n    def update_source_model(self) -> None:\n        r\"\"\"Update variance of Gaussian distribution.\"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n\n        self.variance = np.mean(np.abs(Y) ** 2, axis=1)\n\n\nclass AuxLaplaceIVA(AuxIVA):\n    r\"\"\"Auxiliary-function-based independent vector analysis (IVA) \\\n    on a Laplace distribution.\n\n    We assume :math:`\\vec{\\boldsymbol{y}}_{jn}` follows a Laplace distribution.\n\n    .. 
math::\n        p(\\vec{\\boldsymbol{y}}_{jn})\\propto\\exp(-\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2})\n\n    Args:\n        spatial_algorithm (str):\n            Algorithm for demixing filter updates.\n            Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, or ``ISS2``.\n            Default: ``IP``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n        pair_selector (callable, optional):\n            Selector to choose updating pair in ``IP2`` and ``ISS2``.\n            If ``None`` is given, ``sequential_pair_selector`` is used.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the demixing filter update if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters by IP:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxLaplaceIVA(spatial_algorithm=\"IP\")\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IP2:\n\n        .. code-block:: python\n\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxLaplaceIVA(\n            ...     spatial_algorithm=\"IP2\",\n            ...     pair_selector=sequential_pair_selector,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxLaplaceIVA(spatial_algorithm=\"ISS\")\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS2:\n\n        .. code-block:: python\n\n            >>> import functools\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxLaplaceIVA(\n            ...     spatial_algorithm=\"ISS2\",\n            ...     pair_selector=functools.partial(sequential_pair_selector, step=2),\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_algorithm: str = \"IP\",\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n        callbacks: Optional[\n            Union[Callable[[\"AuxLaplaceIVA\"], None], List[Callable[[\"AuxLaplaceIVA\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        **kwargs,\n    ) -> None:\n        def contrast_fn(y) -> np.ndarray:\n            r\"\"\"Contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of the shape is (n_sources, n_frames).\n            \"\"\"\n            return 2 * np.linalg.norm(y, axis=1)\n\n        def d_contrast_fn(y) -> np.ndarray:\n            r\"\"\"Derivative of contrast function.\n\n            Args:\n                y (numpy.ndarray):\n                    The shape is (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray of the shape is (n_sources, n_frames).\n            \"\"\"\n            return 2 * np.ones_like(y)\n\n        super().__init__(\n            spatial_algorithm=spatial_algorithm,\n            contrast_fn=contrast_fn,\n            d_contrast_fn=d_contrast_fn,\n            flooring_fn=flooring_fn,\n            
pair_selector=pair_selector,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n            **kwargs,\n        )\n\n\nclass AuxGaussIVA(AuxIVA):\n    r\"\"\"Auxiliary-function-based independent vector analysis (IVA) \\\n    on a time-varying Gaussian distribution [#ono2012auxiliary]_.\n\n    We assume :math:`\\vec{\\boldsymbol{y}}_{jn}` follows a time-varying Gaussian distribution.\n\n    .. math::\n        p(\\vec{\\boldsymbol{y}}_{jn})\n        \\propto\\frac{1}{\\alpha_{jn}^{I}}\n        \\exp\\left(-\\frac{\\|\\vec{\\boldsymbol{y}}_{jn}\\|_{2}^{2}}{\\alpha_{jn}}\\right).\n\n    Args:\n        spatial_algorithm (str):\n            Algorithm for demixing filter updates.\n            Choose ``IP``, ``IP1``, ``IP2``, ``ISS``, ``ISS1``, or ``ISS2``.\n            Default: ``IP``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n        pair_selector (callable, optional):\n            Selector to choose updating pair in ``IP2`` and ``ISS2``.\n            If ``None`` is given, ``sequential_pair_selector`` is used.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back``\n            or ``minimal_distortion_principle``. 
Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the demixing filter update if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back and minimal distortion principle. Default: ``0``.\n\n    Examples:\n        Update demixing filters by IP:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxGaussIVA(spatial_algorithm=\"IP\")\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by IP2:\n\n        .. code-block:: python\n\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxGaussIVA(\n            ...     spatial_algorithm=\"IP2\",\n            ...     pair_selector=sequential_pair_selector,\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS:\n\n        .. code-block:: python\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     
+ 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxGaussIVA(spatial_algorithm=\"ISS\")\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n        Update demixing filters by ISS2:\n\n        .. code-block:: python\n\n            >>> import functools\n            >>> from ssspy.utils.select_pair import sequential_pair_selector\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) \\\n            ...     + 1j * np.random.randn(n_channels, n_bins, n_frames)\n\n            >>> iva = AuxGaussIVA(\n            ...     spatial_algorithm=\"ISS2\",\n            ...     pair_selector=functools.partial(sequential_pair_selector, step=2),\n            ... )\n            >>> spectrogram_est = iva(spectrogram_mix, n_iter=100)\n            >>> print(spectrogram_mix.shape, spectrogram_est.shape)\n            (2, 2049, 128), (2, 2049, 128)\n\n    .. [#ono2012auxiliary]\n        N. Ono,\n        \"Auxiliary-function-based independent vector analysis with power of \\\n        vector-norm type weighting functions,\"\n        in *Proc. APSIPA ASC*, 2012, pp. 
1-4.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_algorithm: str = \"IP\",\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n        callbacks: Optional[\n            Union[Callable[[\"AuxGaussIVA\"], None], List[Callable[[\"AuxGaussIVA\"], None]]]\n        ] = None,\n        scale_restoration: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        **kwargs,\n    ) -> None:\n        def contrast_fn(y: np.ndarray) -> np.ndarray:\n            r\"\"\"\n            Args:\n                y (numpy.ndarray):\n                    Separated signal with shape of (n_sources, n_bins, n_frames).\n\n            Returns:\n                numpy.ndarray:\n                    Computed contrast function.\n                    The shape is (n_sources, n_frames).\n            \"\"\"\n            n_bins = self.n_bins\n            alpha = self.variance\n            norm = np.linalg.norm(y, axis=1)\n\n            return n_bins * np.log(alpha) + (norm**2) / alpha\n\n        def d_contrast_fn(y: np.ndarray, variance: np.ndarray = None) -> np.ndarray:\n            r\"\"\"\n            Args:\n                y (numpy.ndarray):\n                    Norm of separated signal.\n                    The shape is (n_sources, n_frames).\n\n            Returns:\n                numpy.ndarray of computed contrast function.\n                The shape is (n_sources, n_frames).\n            \"\"\"\n            if variance is None:\n                alpha = self.variance\n            else:\n                alpha = variance\n\n            return 2 * y / alpha\n\n        super().__init__(\n            spatial_algorithm=spatial_algorithm,\n            contrast_fn=contrast_fn,\n            d_contrast_fn=d_contrast_fn,\n            flooring_fn=flooring_fn,\n            
pair_selector=pair_selector,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n            **kwargs,\n        )\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        We also set variance of Gaussian distribution.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of IVA.\n        \"\"\"\n        super()._reset(**kwargs)\n\n        n_sources, n_frames = self.n_sources, self.n_frames\n\n        self.variance = np.ones((n_sources, n_frames))\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update variance and demixing filters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        self.update_source_model()\n\n        super().update_once(flooring_fn=flooring_fn)\n\n    def update_once_ip2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update demixing filters once using pairwise iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function 
(``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        For :math:`n_{1}` and :math:`n_{2}` (:math:`n_{1}\\neq n_{2}`),\n        compute auxiliary variables:\n\n        .. math::\n            \\bar{r}_{jn_{1}}\n            &\\leftarrow\\|\\vec{\\boldsymbol{y}}_{jn_{1}}\\|_{2} \\\\\n            \\bar{r}_{jn_{2}}\n            &\\leftarrow\\|\\vec{\\boldsymbol{y}}_{jn_{2}}\\|_{2}\n\n        Then, for :math:`n=n_{1},n_{2}`, compute weighted covariance matrix as follows:\n\n        .. math::\n            \\boldsymbol{U}_{in_{1}}\n            &= \\frac{1}{J}\\sum_{j}\n            \\varphi(\\bar{r}_{jn_{1}})\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}, \\\\\n            \\boldsymbol{U}_{in_{2}}\n            &= \\frac{1}{J}\\sum_{j}\n            \\varphi(\\bar{r}_{jn_{2}})\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}},\n\n        where\n\n        .. math::\n            \\varphi(\\bar{r}_{jn})\n            = \\frac{G'_{\\mathbb{R}}(\\bar{r}_{jn})}{2\\bar{r}_{jn}}.\n\n        Using :math:`\\boldsymbol{U}_{in_{1}}` and\n        :math:`\\boldsymbol{U}_{in_{2}}`, we compute generalized eigenvectors.\n\n        .. math::\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i}\n            = \\lambda_{i}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\\boldsymbol{h}_{i},\n\n        where\n\n        .. 
math::\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{1}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ), \\\\\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\n            &= (\\boldsymbol{W}_{i}\\boldsymbol{U}_{in_{2}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{n_{1}} & \\boldsymbol{e}_{n_{2}}\n            \\end{array}\n            ).\n\n        After that, we standardize two eigenvectors :math:`\\boldsymbol{h}_{in_{1}}`\n        and :math:`\\boldsymbol{h}_{in_{2}}`.\n\n        .. math::\n            \\boldsymbol{h}_{in_{1}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{1}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{1}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{1}}\n            \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{1}}}}, \\\\\n            \\boldsymbol{h}_{in_{2}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{in_{2}}}\n            {\\sqrt{\\boldsymbol{h}_{in_{2}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{in_{2}}\n            \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\right)\n            \\boldsymbol{h}_{in_{2}}}}.\n\n        Then, update :math:`\\boldsymbol{w}_{in_{1}}` and :math:`\\boldsymbol{w}_{in_{2}}`\n        simultaneously.\n\n        .. 
math::\n            \\boldsymbol{w}_{in_{1}}\n            &\\leftarrow \\boldsymbol{P}_{in_{1}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{1}} \\\\\n            \\boldsymbol{w}_{in_{2}}\n            &\\leftarrow \\boldsymbol{P}_{in_{2}}^{(n_{1},n_{2})}\\boldsymbol{h}_{in_{2}}.\n\n        At each iteration, we update pairs of :math:`n_{1}` and :math:`n_{2}`\n        for :math:`n_{1}\\neq n_{2}`.\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        n_sources = self.n_sources\n\n        X, W = self.input, self.demix_filter\n        R = self.variance\n\n        XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        for m, n in self.pair_selector(n_sources):\n            W_mn = W[:, (m, n), :]\n            Y_mn = self.separate(X, demix_filter=W_mn)\n            R_mn = R[(m, n), :]\n\n            norm = np.linalg.norm(Y_mn, axis=1)\n            weight_mn = self.d_contrast_fn(norm, variance=R_mn) / flooring_fn(2 * norm)\n            GXX_mn = weight_mn[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n            U_mn = np.mean(GXX_mn, axis=-1)\n\n            W[:, (m, n), :] = update_by_ip2_one_pair(\n                W,\n                U_mn,\n                pair=(m, n),\n                flooring_fn=flooring_fn,\n            )\n\n        self.demix_filter = W\n\n    def update_source_model(self) -> None:\n        r\"\"\"Update variance of Gaussian distribution.\"\"\"\n        if self.demix_filter is None:\n            Y = self.output\n        else:\n            X, W = self.input, self.demix_filter\n            Y = self.separate(X, demix_filter=W)\n\n        self.variance = np.mean(np.abs(Y) ** 2, axis=1)\n"
  },
  {
    "path": "ssspy/bss/mnmf.py",
    "content": "import functools\nfrom typing import Callable, Iterable, List, Optional, Tuple, Union\n\nimport numpy as np\n\nfrom ..linalg._solve import solve\nfrom ..linalg.mean import gmeanmh\nfrom ..special.flooring import identity, max_flooring\nfrom ..special.psd import to_psd\nfrom ..utils.flooring import choose_flooring_fn\nfrom ..utils.select_pair import sequential_pair_selector\nfrom ._update_spatial_model import update_by_ip1, update_by_ip2\nfrom .base import IterativeMethodBase\n\n__all__ = [\"GaussMNMF\", \"FastGaussMNMF\"]\n\ndiagonalizer_algorithms = [\"IP\", \"IP1\", \"IP2\"]\nEPS = 1e-10\n\n\nclass MNMFBase(IterativeMethodBase):\n    r\"\"\"Base class of multichannel nonnegative matrix factorization (MNMF).\n\n    Args:\n        n_basis (int):\n            Number of NMF bases.\n        n_sources (int, optional):\n            Number of sources to be separated.\n            If ``None`` is given, ``n_sources`` is determined by number of channels\n            in input spectrogram. Default: ``None``.\n        partitioning (bool):\n            Whether to use partioning function. Default: ``False``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel in multichannel Wiener filter. 
Default: ``0``.\n        rng (numpy.random.Generator, optional):\n            Random number generator. This is mainly used to randomly initialize NMF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        n_sources: Optional[int] = None,\n        partitioning: bool = False,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"MNMFBase\"], None], List[Callable[[\"MNMFBase\"], None]]]\n        ] = None,\n        normalization: bool = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(callbacks=callbacks, record_loss=record_loss)\n\n        self.n_basis = n_basis\n        self.n_sources = n_sources\n        self.partitioning = partitioning\n\n        if flooring_fn is None:\n            self.flooring_fn = identity\n        else:\n            self.flooring_fn = flooring_fn\n\n        self.normalization = normalization\n\n        self.input = None\n        self.reference_id = reference_id\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        self.rng = rng\n\n    def __call__(\n        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs\n    ) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                The number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if 
necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        super().__call__(n_iter=n_iter, initial_call=initial_call)\n\n        self.output = self.separate(self.input)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"MNMF(\"\n        s += \"n_basis={n_basis}\"\n\n        if self.n_sources is not None:\n            s += \", n_sources={n_sources}\"\n\n        if hasattr(self, \"n_channels\"):\n            s += \", n_channels={n_channels}\"\n\n        s += \", partitioning={partitioning}\"\n        s += \", normalization={normalization}\"\n        s += \", record_loss={record_loss}\"\n        s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of MNMF.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_sources = self.n_sources\n        n_channels, n_bins, n_frames = X.shape\n\n        if n_sources is None:\n            n_sources = n_channels\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_bins, self.n_frames = n_bins, n_frames\n\n        self._init_instant_covariance()\n        self._init_nmf(rng=self.rng)\n\n        self.output = self.separate(X)\n\n    def _init_instant_covariance(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Initialize instantaneous covariance of input.\n\n        Args:\n      
      flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X = self.input\n        XX = X[:, np.newaxis] * X[np.newaxis, :].conj()\n        XX = XX.transpose(2, 3, 0, 1)\n        self.instant_covariance = to_psd(XX, flooring_fn=flooring_fn)\n\n    def _init_nmf(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        r\"\"\"Initialize NMF.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            rng (numpy.random.Generator, optional):\n                Random number generator. 
If ``None`` is given,\n                ``np.random.default_rng()`` is used.\n                Default: ``None``.\n        \"\"\"\n        n_basis = self.n_basis\n        n_sources = self.n_sources\n        n_bins, n_frames = self.n_bins, self.n_frames\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        if self.partitioning:\n            if not hasattr(self, \"basis\"):\n                T = rng.random((n_bins, n_basis))\n                T = flooring_fn(T)\n            else:\n                # To avoid overwriting.\n                T = self.basis.copy()\n\n            if not hasattr(self, \"activation\"):\n                V = rng.random((n_basis, n_frames))\n                V = flooring_fn(V)\n            else:\n                # To avoid overwriting.\n                V = self.activation.copy()\n\n            if not hasattr(self, \"latent\"):\n                Z = rng.random((n_sources, n_basis))\n                Z = Z / Z.sum(axis=0)\n                Z = flooring_fn(Z)\n            else:\n                # To avoid overwriting.\n                Z = self.latent.copy()\n\n            self.basis, self.activation = T, V\n            self.latent = Z\n        else:\n            if not hasattr(self, \"basis\"):\n                T = rng.random((n_sources, n_bins, n_basis))\n                T = flooring_fn(T)\n            else:\n                # To avoid overwriting.\n                T = self.basis.copy()\n\n            if not hasattr(self, \"activation\"):\n                V = rng.random((n_sources, n_basis, n_frames))\n                V = flooring_fn(V)\n            else:\n                # To avoid overwriting.\n                V = self.activation.copy()\n\n            self.basis, self.activation = T, V\n\n    def separate(self, input: np.ndarray) -> np.ndarray:\n        raise NotImplementedError(\"Implement 'separate' method.\")\n\n    def reconstruct_nmf(\n        self,\n   
     basis: np.ndarray,\n        activation: np.ndarray,\n        latent: Optional[np.ndarray] = None,\n    ) -> np.ndarray:\n        r\"\"\"Reconstruct single-channel NMF.\n\n        Args:\n            basis (numpy.ndarray):\n                Basis matrix.\n                The shape is (n_bins, n_basis) if latent is given.\n                Otherwise, (n_sources, n_bins, n_basis).\n            activation (numpy.ndarray):\n                Activation matrix.\n                The shape is (n_basis, n_frames) if latent is given.\n                Otherwise, (n_sources, n_basis, n_frames).\n            latent (numpy.ndarray, optional):\n                Latent variable that determines number of bases per source.\n                The shape is (n_sources, n_basis).\n\n        Returns:\n            numpy.ndarray of reconstructed single-channel NMF.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        if latent is None:\n            T, V = basis, activation\n            Lamb = T @ V\n        else:\n            Z = latent\n            T, V = basis, activation\n            TV = T[:, :, np.newaxis] * V[np.newaxis, :, :]\n            Lamb = np.sum(Z[:, np.newaxis, :, np.newaxis] * TV[np.newaxis, :, :, :], axis=2)\n\n        return Lamb\n\n\nclass MNMF(MNMFBase):\n    def __init__(\n        self,\n        n_basis: int,\n        n_sources: Optional[int] = None,\n        partitioning: bool = False,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[Union[Callable[[\"MNMF\"], None], List[Callable[[\"MNMF\"], None]]]] = None,\n        normalization: bool = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(\n            n_basis,\n            n_sources=n_sources,\n            partitioning=partitioning,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n           
 normalization=normalization,\n            record_loss=record_loss,\n            reference_id=reference_id,\n            rng=rng,\n        )\n\n    def _init_nmf(self, rng: Optional[np.random.Generator] = None) -> None:\n        r\"\"\"Initialize NMF.\n\n        Args:\n            rng (numpy.random.Generator, optional):\n                Random number generator. If ``None`` is given,\n                ``np.random.default_rng()`` is used.\n                Default: ``None``.\n        \"\"\"\n        if rng is None:\n            rng = np.random.default_rng()\n\n        super()._init_nmf(rng=rng)\n\n        n_sources, n_channels = self.n_sources, self.n_channels\n        n_bins = self.n_bins\n\n        if not hasattr(self, \"spatial\"):\n            H = np.eye(n_channels, dtype=self.input.dtype)\n            trace = np.trace(H, axis1=-2, axis2=-1)\n            H = H / np.real(trace)\n            H = np.tile(H, reps=(n_sources, n_bins, 1, 1))\n        else:\n            # To avoid overwriting.\n            H = self.spatial.copy()\n\n        self.spatial = H\n\n    def reconstruct_mnmf(\n        self,\n        basis: np.ndarray,\n        activation: np.ndarray,\n        spatial: np.ndarray,\n        latent: Optional[np.ndarray] = None,\n    ) -> np.ndarray:\n        r\"\"\"Reconstruct multichannel NMF.\n\n        Args:\n            basis (numpy.ndarray):\n                Basis matrix with shape of (n_bins, n_basis) if latent is given.\n                Otherwise, (n_sources, n_bins, n_basis).\n            activation (numpy.ndarray):\n                Activation matrix with shape of (n_basis, n_frames) if latent is given.\n                Otherwise, (n_sources, n_basis, n_frames).\n            spatial (numpy.ndarray):\n                Spatial property with shape of (n_sources, n_bins, n_channels, n_channels).\n            latent (numpy.ndarray, optional):\n                Latent variables with shape of (n_sources, n_basis).\n\n        Returns:\n            numpy.ndarray of reconstructed multichannel NMF.\n            The shape is (n_bins, n_frames, n_channels, n_channels).\n        \"\"\"\n        T, V = basis, activation\n      
  H = spatial\n\n        if latent is None:\n            Lamb = self.reconstruct_nmf(T, V)\n        else:\n            Lamb = self.reconstruct_nmf(T, V, latent=latent)\n\n        R_n = Lamb[:, :, :, np.newaxis, np.newaxis] * H[:, :, np.newaxis, :, :]\n        R = np.sum(R_n, axis=0)\n\n        return R\n\n    def normalize(self, axis1=-2, axis2=-1) -> None:\n        r\"\"\"Ensure unit trace of spatial property of MNMF.\"\"\"\n        H = self.spatial\n        n_dims = H.ndim\n\n        axis1 = n_dims + axis1 if axis1 < 0 else axis1\n        axis2 = n_dims + axis2 if axis2 < 0 else axis2\n\n        assert axis1 == 2 and axis2 == 3\n\n        trace = np.trace(H, axis1=axis1, axis2=axis2)\n        trace = np.real(trace)\n        H = H / trace[..., np.newaxis, np.newaxis]\n\n        if self.partitioning:\n            # When self.partitioning=True,\n            # normalization may change value of cost function\n            pass\n        else:\n            T = self.basis\n            T = trace[:, :, np.newaxis] * T\n            self.basis = T\n\n        self.spatial = H\n\n\nclass FastMNMFBase(MNMFBase):\n    r\"\"\"Base class of fast multichannel nonnegative matrix factorization (Fast MNMF).\n\n    Args:\n        n_basis (int):\n            Number of NMF bases.\n        n_sources (int, optional):\n            Number of sources to be separated.\n            If ``None`` is given, ``n_sources`` is determined by number of channels\n            in input spectrogram. Default: ``None``.\n        partitioning (bool):\n            Whether to use partitioning function. 
Default: ``False``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        normalization (bool or str):\n            Normalization of diagonalizers and diagonal elements of spatial covariance matrices.\n            Only power-based normalization is supported. Default: ``True``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel in multichannel Wiener filter. Default: ``0``.\n        rng (numpy.random.Generator, optional):\n            Random number generator. 
This is mainly used to randomly initialize NMF and spatial parameters.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        n_sources: Optional[int] = None,\n        partitioning: bool = False,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"FastMNMFBase\"], None], List[Callable[[\"FastMNMFBase\"], None]]]\n        ] = None,\n        normalization: Union[bool, str] = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(\n            n_basis,\n            n_sources=n_sources,\n            partitioning=partitioning,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            normalization=normalization,\n            record_loss=record_loss,\n            reference_id=reference_id,\n            rng=rng,\n        )\n\n    def __repr__(self) -> str:\n        s = \"FastMNMF(\"\n        s += \"n_basis={n_basis}\"\n\n        if self.n_sources is not None:\n            s += \", n_sources={n_sources}\"\n\n        if hasattr(self, \"n_channels\"):\n            s += \", n_channels={n_channels}\"\n\n        s += \", partitioning={partitioning}\"\n        s += \", normalization={normalization}\"\n        s += \", record_loss={record_loss}\"\n        s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        **kwargs,\n    ) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical 
stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            kwargs:\n                Keyword arguments to set as attributes of MNMF.\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_sources = self.n_sources\n        n_channels, n_bins, n_frames = X.shape\n\n        if n_sources is None:\n            n_sources = n_channels\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_bins, self.n_frames = n_bins, n_frames\n\n        self._init_instant_covariance(flooring_fn=flooring_fn)\n        self._init_nmf(flooring_fn=flooring_fn, rng=self.rng)\n        self._init_diagonalizer(rng=self.rng)\n        self._init_spatial(flooring_fn=flooring_fn, rng=self.rng)\n\n        self.output = self.separate(X)\n\n    def _init_diagonalizer(self, rng: Optional[np.random.Generator] = None) -> None:\n        \"\"\"Initialize diagonalizer.\n\n        Args:\n            rng (numpy.random.Generator, optional):\n                Random number generator. 
If ``None`` is given,\n                ``np.random.default_rng()`` is used.\n                Default: ``None``.\n        \"\"\"\n        n_channels = self.n_channels\n        n_bins = self.n_bins\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        if not hasattr(self, \"diagonalizer\"):\n            Q = np.eye(n_channels, dtype=np.complex128)\n            Q = np.tile(Q, reps=(n_bins, 1, 1))\n        else:\n            # To avoid overwriting.\n            Q = self.diagonalizer.copy()\n\n        self.diagonalizer = Q\n\n    def _init_spatial(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        \"\"\"Initialize diagonal elements of spatial covariance matrices.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n            rng (numpy.random.Generator, optional):\n                Random number generator. 
If ``None`` is given,\n                ``np.random.default_rng()`` is used.\n                Default: ``None``.\n        \"\"\"\n        n_sources, n_channels = self.n_sources, self.n_channels\n        n_bins = self.n_bins\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if rng is None:\n            rng = np.random.default_rng()\n\n        if not hasattr(self, \"spatial\"):\n            D = rng.random((n_bins, n_sources, n_channels))\n            D = flooring_fn(D)\n        else:\n            D = self.spatial\n\n        self.spatial = D\n\n    def normalize(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Normalize diagonalizers and diagonal elements of spatial covariance matrices.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        normalization = self.normalization\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        assert normalization, \"Set normalization.\"\n\n        if type(normalization) is bool:\n            # when normalization is True\n            normalization = \"power\"\n\n        if normalization == \"power\":\n            self.normalize_by_power(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Normalization {} is not implemented.\".format(normalization))\n\n    def normalize_by_power(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Normalize diagonalizers 
and diagonal elements of spatial covariance matrices by power.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Diagonalizers are normalized by\n\n        .. math::\n            \\boldsymbol{q}_{im}\n            \\leftarrow\\frac{\\boldsymbol{q}_{im}}{\\psi_{im}},\n\n        where\n\n        .. math::\n            \\psi_{im}\n            = \\sqrt{\\frac{1}{IJ}\\sum_{i,j}|\\boldsymbol{q}_{im}^{\\mathsf{H}}\n            \\boldsymbol{x}_{ij}|^{2}}.\n\n        For diagonal elements of spatial covariance matrices,\n\n        .. math::\n            d_{inm}\n            \\leftarrow\\frac{d_{inm}}{\\psi_{im}^{2}}.\n        \"\"\"\n        X = self.input\n        Q, D = self.diagonalizer, self.spatial\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        QX = Q @ X.transpose(1, 0, 2)\n        QX2 = np.mean(np.abs(QX) ** 2, axis=(0, 2))\n        psi = np.sqrt(QX2)\n        psi = flooring_fn(psi)\n\n        Q = Q / psi[np.newaxis, :, np.newaxis]\n        D = D / (psi**2)\n\n        self.diagonalizer, self.spatial = Q, D\n\n\nclass GaussMNMF(MNMF):\n    def __init__(\n        self,\n        n_basis: int,\n        n_sources: Optional[int] = None,\n        partitioning: bool = False,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        callbacks: Optional[\n            Union[Callable[[\"GaussMNMF\"], None], List[Callable[[\"GaussMNMF\"], None]]]\n        ] = None,\n        normalization: Union[bool, str] = True,\n        record_loss: bool = True,\n   
     reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(\n            n_basis,\n            n_sources=n_sources,\n            partitioning=partitioning,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            normalization=normalization,\n            record_loss=record_loss,\n            reference_id=reference_id,\n            rng=rng,\n        )\n\n    def __repr__(self) -> str:\n        s = \"GaussMNMF(\"\n        s += \"n_basis={n_basis}\"\n\n        if self.n_sources is not None:\n            s += \", n_sources={n_sources}\"\n\n        if hasattr(self, \"n_channels\"):\n            s += \", n_channels={n_channels}\"\n\n        s += \", partitioning={partitioning}\"\n        s += \", normalization={normalization}\"\n        s += \", record_loss={record_loss}\"\n        s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def separate(self, input: np.ndarray) -> np.ndarray:\n        \"\"\"Separate ``input`` using multichannel Wiener filter.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        n_sources = self.n_sources\n        reference_id = self.reference_id\n\n        X = input\n        T, V = self.basis, self.activation\n        H = self.spatial\n\n        if self.partitioning:\n            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)\n        else:\n            Lamb = self.reconstruct_nmf(T, V)\n\n        R_n = Lamb[:, :, :, np.newaxis, np.newaxis] * H[:, :, np.newaxis, :, :]\n        R = np.sum(R_n, axis=0)\n        R = to_psd(R, flooring_fn=self.flooring_fn)\n        R = np.tile(R, reps=(n_sources, 1, 1, 1, 
1))\n        W_Hermite = solve(R, R_n)\n        W = W_Hermite.transpose(0, 1, 2, 4, 3).conj()\n        W_ref = W[:, :, :, reference_id, :]\n        W_ref = W_ref.transpose(0, 3, 1, 2)\n        Y = np.sum(W_ref * X, axis=1)\n\n        return Y\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        XX = self.instant_covariance\n        T, V = self.basis, self.activation\n        H = self.spatial\n\n        if self.partitioning:\n            R = self.reconstruct_mnmf(T, V, H, latent=self.latent)\n        else:\n            R = self.reconstruct_mnmf(T, V, H)\n\n        R = to_psd(R, flooring_fn=self.flooring_fn)\n        XXR_inv = solve(R, XX)  # Hermitian transpose of XX @ np.linalg.inv(R)\n        trace = np.trace(XXR_inv, axis1=-2, axis2=-1)\n        trace = np.real(trace)\n        logdet = self.compute_logdet(R)\n        loss = np.mean(trace + logdet, axis=-1)\n        loss = loss.sum(axis=0)\n        loss = loss.item()\n\n        return loss\n\n    def compute_logdet(self, reconstructed: np.ndarray) -> np.ndarray:\n        r\"\"\"Compute log-determinant.\n\n        Args:\n            reconstructed:\n                Reconstructed MNMF with shape of (\\*, n_channels, n_channels).\n\n        Returns:\n            numpy.ndarray of computed log-determinant values.\n            The shape is (\\*).\n        \"\"\"\n        _, logdet = np.linalg.slogdet(reconstructed)\n\n        return logdet\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update MNMF parameters once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n       
         the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_basis(flooring_fn=flooring_fn)\n        self.update_activation(flooring_fn=flooring_fn)\n        self.update_spatial(flooring_fn=flooring_fn)\n\n        if self.normalization:\n            # ensure unit trace of spatial property\n            # before updates of latent variables in MNMF\n            self.normalize(axis1=-2, axis2=-1)\n\n        if self.partitioning:\n            self.update_latent(flooring_fn=flooring_fn)\n\n    def update_basis(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        n_sources = self.n_sources\n        n_frames = self.n_frames\n        na = np.newaxis\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        def _compute_traces(\n            target: np.ndarray, reconstructed: np.ndarray, spatial: np.ndarray\n        ) -> np.ndarray:\n            RXX = solve(reconstructed, target)\n            R = np.tile(reconstructed, reps=(n_sources, 1, 1, 1, 1))\n            H = np.tile(spatial[:, :, na, :, :], reps=(1, 1, n_frames, 1, 1))\n            RH = solve(R, H)\n\n            trace_RXXRH = np.trace(RXX @ RH, axis1=-2, axis2=-1)\n            
trace_RXXRH = np.real(trace_RXXRH)\n            trace_RH = np.trace(RH, axis1=-2, axis2=-1)\n            trace_RH = np.real(trace_RH)\n\n            return trace_RXXRH, trace_RH\n\n        XX = self.instant_covariance\n        T, V = self.basis, self.activation\n        H = self.spatial\n\n        if self.partitioning:\n            Z = self.latent\n            R = self.reconstruct_mnmf(T, V, H, latent=Z)\n            R = to_psd(R, flooring_fn=flooring_fn)\n\n            trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H)\n\n            VRXXRH = np.sum(V[na, na, :] * trace_RXXRH[:, :, na], axis=-1)\n            VRH = np.sum(V[na, na, :] * trace_RH[:, :, na], axis=-1)\n\n            num = np.sum(Z[:, na, :] * VRXXRH, axis=0)\n            denom = np.sum(Z[:, na, :] * VRH, axis=0)\n        else:\n            R = self.reconstruct_mnmf(T, V, H)\n            R = to_psd(R, flooring_fn=flooring_fn)\n\n            trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H)\n\n            num = np.sum(V[:, na, :, :] * trace_RXXRH[:, :, na, :], axis=-1)\n            denom = np.sum(V[:, na, :, :] * trace_RH[:, :, na, :], axis=-1)\n\n        T = T * np.sqrt(num / denom)\n        T = flooring_fn(T)\n\n        self.basis = T\n\n    def update_activation(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF activations by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        n_sources = self.n_sources\n        n_frames = self.n_frames\n        
na = np.newaxis\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        def _compute_traces(\n            target: np.ndarray, reconstructed: np.ndarray, spatial: np.ndarray\n        ) -> np.ndarray:\n            RXX = solve(reconstructed, target)\n            R = np.tile(reconstructed, reps=(n_sources, 1, 1, 1, 1))\n            H = np.tile(spatial[:, :, na, :, :], reps=(1, 1, n_frames, 1, 1))\n            RH = solve(R, H)\n\n            trace_RXXRH = np.trace(RXX @ RH, axis1=-2, axis2=-1)\n            trace_RXXRH = np.real(trace_RXXRH)\n            trace_RH = np.trace(RH, axis1=-2, axis2=-1)\n            trace_RH = np.real(trace_RH)\n\n            return trace_RXXRH, trace_RH\n\n        XX = self.instant_covariance\n        T, V = self.basis, self.activation\n        H = self.spatial\n\n        if self.partitioning:\n            Z = self.latent\n            R = self.reconstruct_mnmf(T, V, H, latent=Z)\n            R = to_psd(R, flooring_fn=flooring_fn)\n\n            trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H)\n\n            TRXXRH = np.sum(T[na, :, :, na] * trace_RXXRH[:, :, na, :], axis=1)\n            TRH = np.sum(T[na, :, :, na] * trace_RH[:, :, na, :], axis=1)\n\n            num = np.sum(Z[:, :, na] * TRXXRH, axis=0)\n            denom = np.sum(Z[:, :, na] * TRH, axis=0)\n        else:\n            R = self.reconstruct_mnmf(T, V, H)\n            R = to_psd(R, flooring_fn=flooring_fn)\n\n            trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H)\n\n            num = np.sum(T[:, :, :, na] * trace_RXXRH[:, :, na, :], axis=1)\n            denom = np.sum(T[:, :, :, na] * trace_RH[:, :, na, :], axis=1)\n\n        V = V * np.sqrt(num / denom)\n        V = flooring_fn(V)\n\n        self.activation = V\n\n    def update_spatial(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update spatial properties in NMF by MM algorithm.\n\n    
    Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        na = np.newaxis\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        XX = self.instant_covariance\n        T, V = self.basis, self.activation\n        H = self.spatial\n\n        if self.partitioning:\n            Z = self.latent\n            Lamb = self.reconstruct_nmf(T, V, latent=Z)\n        else:\n            Lamb = self.reconstruct_nmf(T, V)\n\n        R_n = Lamb[:, :, :, na, na] * H[:, :, na, :, :]\n        R = np.sum(R_n, axis=0)\n        R = to_psd(R, flooring_fn=flooring_fn)\n        R_inverse = np.linalg.inv(R)\n        RXXR = R_inverse @ XX @ R_inverse\n\n        P = np.sum(Lamb[:, :, :, na, na] * R_inverse, axis=2)\n        Q = np.sum(Lamb[:, :, :, na, na] * RXXR, axis=2)\n        HQH = H @ Q @ H\n\n        P = to_psd(P, flooring_fn=flooring_fn)\n        HQH = to_psd(HQH, flooring_fn=flooring_fn)\n\n        # geometric mean of P^(-1) and HQH\n        H = gmeanmh(P, HQH, type=2)\n        H = to_psd(H, flooring_fn=flooring_fn)\n\n        self.spatial = H\n\n    def update_latent(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update latent variables in NMF by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                
the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        n_sources = self.n_sources\n        n_frames = self.n_frames\n        na = np.newaxis\n\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        def _compute_traces(\n            target: np.ndarray, reconstructed: np.ndarray, spatial: np.ndarray\n        ) -> np.ndarray:\n            RXX = solve(reconstructed, target)\n            R = np.tile(reconstructed, reps=(n_sources, 1, 1, 1, 1))\n            H = np.tile(spatial[:, :, na, :, :], reps=(1, 1, n_frames, 1, 1))\n            RH = solve(R, H)\n\n            trace_RXXRH = np.trace(RXX @ RH, axis1=-2, axis2=-1)\n            trace_RXXRH = np.real(trace_RXXRH)\n            trace_RH = np.trace(RH, axis1=-2, axis2=-1)\n            trace_RH = np.real(trace_RH)\n\n            return trace_RXXRH, trace_RH\n\n        XX = self.instant_covariance\n        T, V = self.basis, self.activation\n        H, Z = self.spatial, self.latent\n\n        R = self.reconstruct_mnmf(T, V, H, latent=Z)\n        R = to_psd(R, flooring_fn=flooring_fn)\n\n        trace_RXXRH, trace_RH = _compute_traces(XX, R, spatial=H)\n\n        VRXXRH = np.sum(V[na, na, :] * trace_RXXRH[:, :, na], axis=-1)\n        VRH = np.sum(V[na, na, :] * trace_RH[:, :, na], axis=-1)\n\n        num = np.sum(T * VRXXRH, axis=1)\n        denom = np.sum(T * VRH, axis=1)\n\n        Z = Z * np.sqrt(num / denom)\n        Z = Z / Z.sum(axis=0)\n\n        self.latent = Z\n\n\nclass FastGaussMNMF(FastMNMFBase):\n    r\"\"\"Fast multichannel nonnegative matrix factorization on Gaussian distribution \\\n    (Fast Gauss-MNMF).\n\n    Args:\n        n_basis (int):\n            Number of NMF bases.\n        n_sources (int, optional):\n            Number of sources to be separated.\n            If ``None`` is given, ``n_sources`` is determined by number of channels\n            
in input spectrogram. Default: ``None``.\n        diagonalizer_algorithm (str):\n            Algorithm for diagonalizers. Choose ``IP``, ``IP1``, or ``IP2``.\n            Default: ``IP``.\n        partitioning (bool):\n            Whether to use partitioning function. Only ``False`` is supported.\n            Default: ``False``.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        pair_selector (callable, optional):\n            Selector of index pairs used when ``diagonalizer_algorithm=IP2``.\n            If ``None`` is given, ``sequential_pair_selector`` is used for ``IP2``.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        normalization (bool):\n            Whether to call ``normalize`` at the end of ``update_once``. Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel in multichannel Wiener filter. Default: ``0``.\n        rng (numpy.random.Generator, optional):\n            Random number generator. 
This is mainly used to randomly initialize PSDTF.\n            If ``None`` is given, ``np.random.default_rng()`` is used.\n            Default: ``None``.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_basis: int,\n        n_sources: Optional[int] = None,\n        diagonalizer_algorithm: str = \"IP\",\n        partitioning: bool = False,\n        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n            max_flooring, eps=EPS\n        ),\n        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,\n        callbacks: Optional[\n            Union[Callable[[\"FastGaussMNMF\"], None], List[Callable[[\"FastGaussMNMF\"], None]]]\n        ] = None,\n        normalization: bool = True,\n        record_loss: bool = True,\n        reference_id: int = 0,\n        rng: Optional[np.random.Generator] = None,\n    ) -> None:\n        super().__init__(\n            n_basis,\n            n_sources=n_sources,\n            partitioning=partitioning,\n            flooring_fn=flooring_fn,\n            callbacks=callbacks,\n            normalization=normalization,\n            record_loss=record_loss,\n            reference_id=reference_id,\n            rng=rng,\n        )\n\n        assert diagonalizer_algorithm in diagonalizer_algorithms, \"Not support {}.\".format(\n            diagonalizer_algorithm\n        )\n        assert not partitioning, \"partitioning function is not supported.\"\n\n        self.diagonalizer_algorithm = diagonalizer_algorithm\n\n        if pair_selector is None:\n            if diagonalizer_algorithm == \"IP2\":\n                self.pair_selector = sequential_pair_selector\n        else:\n            self.pair_selector = pair_selector\n\n    def __repr__(self) -> str:\n        s = \"FastGaussMNMF(\"\n        s += \"n_basis={n_basis}\"\n\n        if self.n_sources is not None:\n            s += \", n_sources={n_sources}\"\n\n        if hasattr(self, \"n_channels\"):\n            s += \", 
n_channels={n_channels}\"\n\n        s += \", diagonalizer_algorithm={diagonalizer_algorithm}\"\n        s += \", partitioning={partitioning}\"\n        s += \", record_loss={record_loss}\"\n        s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def separate(self, input: np.ndarray) -> np.ndarray:\n        \"\"\"Separate ``input`` using multichannel Wiener filter.\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        na = np.newaxis\n        n_sources = self.n_sources\n        reference_id = self.reference_id\n\n        X = input\n        T, V = self.basis, self.activation\n        Q, D = self.diagonalizer, self.spatial\n\n        if self.partitioning:\n            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)\n        else:\n            Lamb = self.reconstruct_nmf(T, V)\n\n        D = D.transpose(1, 0, 2)\n\n        Q_inverse = np.linalg.inv(Q)\n        Q_inverse_Hermite = Q_inverse.transpose(0, 2, 1).conj()\n        QQ_Hermite = Q_inverse[:, :, :, na] * Q_inverse_Hermite[:, na, :, :]\n\n        LambD = Lamb[:, :, :, na] * D[:, :, na, :]\n\n        R_n = np.sum(LambD[:, :, :, na, :, na] * QQ_Hermite[:, na, :, :, :], axis=4)\n        R = np.sum(R_n, axis=0)\n        R = to_psd(R, flooring_fn=self.flooring_fn)\n        R = np.tile(R, reps=(n_sources, 1, 1, 1, 1))\n        W_Hermite = solve(R, R_n)\n        W = W_Hermite.transpose(0, 1, 2, 4, 3).conj()\n        W_ref = W[:, :, :, reference_id, :]\n        W_ref = W_ref.transpose(0, 3, 1, 2)\n        Y = np.sum(W_ref * X, axis=1)\n\n        return Y\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        
:math:`\\mathcal{L}` is defined as follows:\n\n        .. math::\n            \\mathcal{L}\n            &:=\\frac{1}{J}\\sum_{i,j}\\left\\{\n            \\mathrm{tr}\\left(\n            \\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}\\boldsymbol{R}_{ij}^{-1}\n            \\right)\n            + \\log\\det\\boldsymbol{R}_{ij}\n            \\right\\} \\\\\n            &=\\frac{1}{J}\\sum_{i,j,m}\\left\\{\n            \\frac{|\\boldsymbol{q}_{im}^{\\mathsf{H}}\\boldsymbol{x}_{ij}|^{2}}\n            {\\sum_{n}\\lambda_{ijn}d_{inm}}\n            + \\log\\sum_{n}\\lambda_{ijn}d_{inm}\\right\\}\n            - 2\\sum_{i}\\log|\\det\\boldsymbol{Q}_{i}|.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        X = self.input\n        T, V = self.basis, self.activation\n        Q, D = self.diagonalizer, self.spatial\n        na = np.newaxis\n\n        if self.partitioning:\n            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)\n        else:\n            Lamb = self.reconstruct_nmf(T, V)\n\n        D = D.transpose(1, 0, 2)\n        LambD = np.sum(Lamb[:, :, na, :] * D[:, :, :, na], axis=0)\n        QX = Q @ X.transpose(1, 0, 2)\n        QX2 = np.abs(QX) ** 2\n        logdetQ = self.compute_logdet(Q)\n        loss = np.sum(QX2 / LambD + np.log(LambD), axis=1)\n        loss = np.mean(loss, axis=-1) - 2 * logdetQ\n        loss = loss.sum(axis=0)\n        loss = loss.item()\n\n        return loss\n\n    def compute_logdet(self, diagonalizer: np.ndarray) -> np.ndarray:\n        r\"\"\"Compute log-determinant.\n\n        Args:\n            diagonalizer (numpy.ndarray):\n                Diagonalizer with shape of (\\*, n_channels, n_channels).\n\n        Returns:\n            numpy.ndarray of computed log-determinant values.\n            The shape is (\\*).\n        \"\"\"\n        _, logdet = np.linalg.slogdet(diagonalizer)\n\n        return logdet\n\n    def update_once(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], 
np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update MNMF parameters, diagonalizers, and diagonal elements of \\\n        spatial covariance matrices once.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        self.update_basis(flooring_fn=flooring_fn)\n        self.update_activation(flooring_fn=flooring_fn)\n        self.update_diagonalizer(flooring_fn=flooring_fn)\n        self.update_spatial()\n\n        if self.normalization:\n            self.normalize(flooring_fn=flooring_fn)\n\n    def update_basis(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF bases by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`t_{ikn}` as follows:\n\n        .. 
math::\n            t_{ikn}\n            \\leftarrow\\left[\n            \\frac{\\displaystyle\\sum_{j,m}\\frac{|\\boldsymbol{q}_{im}^{\\mathsf{H}}\\boldsymbol{x}_{ij}|^{2}d_{inm}v_{kjn}}\n            {\\left(\\sum_{k',n'}t_{ik'n'}v_{k'jn'}d_{in'm}\\right)^{2}}}\n            {\\displaystyle\\sum_{j,m}\\dfrac{d_{inm}v_{kjn}}{\\sum_{k',n'}t_{ik'n'}v_{k'jn'}d_{in'm}}}\n            \\right]^{\\frac{1}{2}}t_{ikn}.\n        \"\"\"\n        assert not self.partitioning, \"partitioning function is not supported.\"\n\n        na = np.newaxis\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X = self.input\n        T, V = self.basis, self.activation\n        Q, D = self.diagonalizer, self.spatial\n\n        if self.partitioning:\n            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)\n        else:\n            Lamb = self.reconstruct_nmf(T, V)\n\n        D = D.transpose(1, 0, 2)\n        LambD = Lamb[:, :, :, na] * D[:, :, na, :]\n        LambD = np.sum(LambD, axis=0)\n        QX = Q @ X.transpose(1, 0, 2)\n        QX = np.abs(QX)\n        QX = QX.transpose(0, 2, 1)\n        QXLambD = (QX / LambD) ** 2\n        DQXLambD = np.sum(D[:, :, na, :] * QXLambD, axis=-1)\n        DLambD = np.sum(D[:, :, na, :] / LambD, axis=-1)\n\n        num = np.sum(V[:, na, :] * DQXLambD[:, :, na], axis=-1)\n        denom = np.sum(V[:, na, :] * DLambD[:, :, na], axis=-1)\n\n        T = T * np.sqrt(num / denom)\n        T = flooring_fn(T)\n\n        self.basis = T\n\n    def update_activation(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update NMF activations by MM algorithm.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n           
     the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Update :math:`v_{kjn}` as follows:\n\n        .. math::\n            v_{kjn}\n            \\leftarrow\\left[\n            \\frac{\\displaystyle\\sum_{i,m}\\frac{|\\boldsymbol{q}_{im}^{\\mathsf{H}}\\boldsymbol{x}_{ij}|^{2}d_{inm}t_{ikn}}\n            {\\left(\\sum_{k',n'}t_{ik'n'}v_{k'jn'}d_{in'm}\\right)^{2}}}\n            {\\displaystyle\\sum_{i,m}\\dfrac{d_{inm}t_{ikn}}{\\sum_{k',n'}t_{ik'n'}v_{k'jn'}d_{in'm}}}\n            \\right]^{\\frac{1}{2}}v_{kjn}.\n        \"\"\"\n        assert not self.partitioning, \"partitioning function is not supported.\"\n\n        na = np.newaxis\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X = self.input\n        T, V = self.basis, self.activation\n        Q, D = self.diagonalizer, self.spatial\n\n        if self.partitioning:\n            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)\n        else:\n            Lamb = self.reconstruct_nmf(T, V)\n\n        D = D.transpose(1, 0, 2)\n        LambD = Lamb[:, :, :, na] * D[:, :, na, :]\n        LambD = np.sum(LambD, axis=0)\n        QX = Q @ X.transpose(1, 0, 2)\n        QX = np.abs(QX)\n        QX = QX.transpose(0, 2, 1)\n        QXLambD = (QX / LambD) ** 2\n        DQXLambD = np.sum(D[:, :, na, :] * QXLambD, axis=-1)\n        DLambD = np.sum(D[:, :, na, :] / LambD, axis=-1)\n\n        num = np.sum(T[:, :, :, na] * DQXLambD[:, :, na, :], axis=1)\n        denom = np.sum(T[:, :, :, na] * DLambD[:, :, na, :], axis=1)\n\n        V = V * np.sqrt(num / denom)\n        V = flooring_fn(V)\n\n        self.activation = V\n\n    def update_diagonalizer(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        \"\"\"Update diagonalizer.\n\n        - If ``diagonalizer_algorithm`` is ``IP`` or ``IP1``, \\\n   
         ``update_diagonalizer_ip1`` is called.\n        - If ``diagonalizer_algorithm`` is ``IP2``, \\\n            ``update_diagonalizer_ip2`` is called.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        \"\"\"\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        if self.diagonalizer_algorithm in [\"IP\", \"IP1\"]:\n            self.update_diagonalizer_ip1(flooring_fn=flooring_fn)\n        elif self.diagonalizer_algorithm in [\"IP2\"]:\n            self.update_diagonalizer_ip2(flooring_fn=flooring_fn)\n        else:\n            raise NotImplementedError(\"Not support {}.\".format(self.diagonalizer_algorithm))\n\n    def update_diagonalizer_ip1(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n        r\"\"\"Update diagonalizer once using iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        Diagonalizers are updated sequentially for :math:`m=1,\\ldots,M` as follows:\n\n        .. 
math::\n            \\boldsymbol{q}_{im}\n            &\\leftarrow\\left(\\boldsymbol{Q}_{im}^{\\mathsf{H}}\\boldsymbol{U}_{im}\\right)^{-1} \\\n            \\boldsymbol{e}_{m}, \\\\\n            \\boldsymbol{q}_{im}\n            &\\leftarrow\\frac{\\boldsymbol{q}_{im}}\n            {\\sqrt{\\boldsymbol{q}_{im}^{\\mathsf{H}}\\boldsymbol{U}_{im}\\boldsymbol{q}_{im}}},\n\n        where\n\n        .. math::\n            \\boldsymbol{U}_{im}\n            = \\frac{1}{J}\\sum_{j}\n            \\frac{\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}}\n            {\\sum_{n}\\left(\\sum_{k}z_{nk}t_{ik}v_{kj}\\right)d_{inm}}\n\n        if ``partitioning=True``, otherwise\n\n        .. math::\n            \\boldsymbol{U}_{im}\n            = \\frac{1}{J}\\sum_{j}\n            \\frac{\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}}\n            {\\sum_{n}\\left(\\sum_{k}t_{ikn}v_{kjn}\\right)d_{inm}}.\n        \"\"\"\n        assert not self.partitioning, \"partitioning function is not supported.\"\n\n        na = np.newaxis\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X = self.input\n        T, V = self.basis, self.activation\n        Q, D = self.diagonalizer, self.spatial\n\n        if self.partitioning:\n            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)\n        else:\n            Lamb = self.reconstruct_nmf(T, V)\n\n        XX_Hermite = X[:, na, :, :] * X[na, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        Lamb = Lamb.transpose(1, 0, 2)\n        LambD = np.sum(Lamb[:, :, na, :] * D[:, :, :, na], axis=1)\n        varphi = 1 / LambD\n\n        varphi_XX = varphi[:, :, na, na, :] * XX_Hermite[:, na, :, :, :]\n        U = np.mean(varphi_XX, axis=-1)\n\n        self.diagonalizer = update_by_ip1(Q, U, flooring_fn=flooring_fn)\n\n    def update_diagonalizer_ip2(\n        self,\n        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    ) -> None:\n   
     r\"\"\"Update diagonalizer once using pairwise iterative projection.\n\n        Args:\n            flooring_fn (callable or str, optional):\n                A flooring function for numerical stability.\n                This function is expected to return the same shape tensor as the input.\n                If you explicitly set ``flooring_fn=None``,\n                the identity function (``lambda x: x``) is used.\n                If ``self`` is given as str, ``self.flooring_fn`` is used.\n                Default: ``self``.\n\n        For :math:`m_{1}` and :math:`m_{2}` (:math:`m_{1}\\neq m_{2}`),\n        compute weighted covariance matrix as follows:\n\n        .. math::\n            \\boldsymbol{U}_{im}\n            = \\frac{1}{J}\\sum_{j}\n            \\frac{\\boldsymbol{x}_{ij}\\boldsymbol{x}_{ij}^{\\mathsf{H}}}{\\sum_{n}\\lambda_{ijn}d_{inm}},\n\n        :math:`\\lambda_{ijn}` is computed by\n\n        .. math::\n            \\lambda_{ijn}=\\sum_{k}z_{nk}t_{ik}v_{kj}\n\n        if ``partitioning=True``.\n        Otherwise,\n\n        .. math::\n            \\lambda_{ijn}=\\sum_{k}t_{ikn}v_{kjn}.\n\n        Using :math:`\\boldsymbol{U}_{im_{1}}` and\n        :math:`\\boldsymbol{U}_{im_{2}}`, we compute generalized eigenvectors.\n\n        .. math::\n            \\left({\\boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{im_{1}}\n            \\boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}\\right)\\boldsymbol{h}_{i}\n            = \\mu_{i}\n            \\left({\\boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{im_{2}}\n            \\boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}\\right)\\boldsymbol{h}_{i},\n\n        where\n\n        .. 
math::\n            \\boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}\n            &= (\\boldsymbol{Q}_{i}\\boldsymbol{U}_{im_{1}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{m_{1}} & \\boldsymbol{e}_{m_{2}}\n            \\end{array}\n            ), \\\\\n            \\boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}\n            &= (\\boldsymbol{Q}_{i}\\boldsymbol{U}_{im_{2}})^{-1}\n            (\n            \\begin{array}{cc}\n                \\boldsymbol{e}_{m_{1}} & \\boldsymbol{e}_{m_{2}}\n            \\end{array}\n            ).\n\n        After that, we standardize two eigenvectors :math:`\\boldsymbol{h}_{im_{1}}`\n        and :math:`\\boldsymbol{h}_{im_{2}}`.\n\n        .. math::\n            \\boldsymbol{h}_{im_{1}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{im_{1}}}\n            {\\sqrt{\\boldsymbol{h}_{im_{1}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{im_{1}}\n            \\boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}\\right)\n            \\boldsymbol{h}_{im_{1}}}}, \\\\\n            \\boldsymbol{h}_{im_{2}}\n            &\\leftarrow\\frac{\\boldsymbol{h}_{im_{2}}}\n            {\\sqrt{\\boldsymbol{h}_{im_{2}}^{\\mathsf{H}}\n            \\left({\\boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}}^{\\mathsf{H}}\\boldsymbol{U}_{im_{2}}\n            \\boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}\\right)\n            \\boldsymbol{h}_{im_{2}}}}.\n\n        Then, update :math:`\\boldsymbol{q}_{im_{1}}` and :math:`\\boldsymbol{q}_{im_{2}}`\n        simultaneously.\n\n        .. 
math::\n            \\boldsymbol{q}_{im_{1}}\n            &\\leftarrow \\boldsymbol{P}_{im_{1}}^{(m_{1},m_{2})}\\boldsymbol{h}_{im_{1}} \\\\\n            \\boldsymbol{q}_{im_{2}}\n            &\\leftarrow \\boldsymbol{P}_{im_{2}}^{(m_{1},m_{2})}\\boldsymbol{h}_{im_{2}}\n\n        At each iteration, we update pairs of :math:`m_{1}` and :math:`m_{2}`\n        for :math:`m_{1}\\neq m_{2}`.\n        \"\"\"\n        assert not self.partitioning, \"partitioning function is not supported.\"\n\n        na = np.newaxis\n        flooring_fn = choose_flooring_fn(flooring_fn, method=self)\n\n        X = self.input\n        T, V = self.basis, self.activation\n        Q, D = self.diagonalizer, self.spatial\n\n        if self.partitioning:\n            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)\n        else:\n            Lamb = self.reconstruct_nmf(T, V)\n\n        XX_Hermite = X[:, na, :, :] * X[na, :, :, :].conj()\n        XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n\n        Lamb = Lamb.transpose(1, 0, 2)\n        LambD = np.sum(Lamb[:, :, na, :] * D[:, :, :, na], axis=1)\n        varphi = 1 / LambD\n\n        varphi_XX = varphi[:, :, na, na, :] * XX_Hermite[:, na, :, :, :]\n        U = np.mean(varphi_XX, axis=-1)\n\n        self.diagonalizer = update_by_ip2(\n            Q, U, flooring_fn=flooring_fn, pair_selector=self.pair_selector\n        )\n\n    def update_spatial(self) -> None:\n        r\"\"\"Update diagonal elements of spatial covariance matrix by MM algorithm.\n\n        Update :math:`d_{inm}` as follows:\n\n        .. 
math::\n            d_{inm}\\leftarrow\\left[\n            \\dfrac{\\displaystyle\\sum_{j}\\frac{\\lambda_{ijn}|\\boldsymbol{q}_{im}^{\\mathsf{H}}\\boldsymbol{x}_{ij}|^{2}}\n            {\\left(\\sum_{n'}\\lambda_{ijn'}d_{in'm}\\right)^{2}}}\n            {\\displaystyle\\sum_{j}\\frac{\\lambda_{ijn}}{\\sum_{n'}\\lambda_{ijn'}d_{in'm}}}\n            \\right]^{\\frac{1}{2}}d_{inm}.\n        \"\"\"\n        assert not self.partitioning, \"partitioning function is not supported.\"\n\n        na = np.newaxis\n\n        X = self.input\n        T, V = self.basis, self.activation\n        Q, D = self.diagonalizer, self.spatial\n\n        if self.partitioning:\n            Lamb = self.reconstruct_nmf(T, V, latent=self.latent)\n        else:\n            Lamb = self.reconstruct_nmf(T, V)\n\n        QX = Q @ X.transpose(1, 0, 2)\n        QX = np.abs(QX)\n        QX2 = QX**2\n\n        Lamb = Lamb.transpose(1, 0, 2)\n        LambD = np.sum(Lamb[:, :, na, :] * D[:, :, :, na], axis=1)\n        LambD2 = LambD**2\n        Lamb_LambD2 = Lamb[:, :, na] / LambD2[:, na, :]\n        num = np.sum(Lamb_LambD2 * QX2[:, na, :, :], axis=-1)\n\n        Lamb_LambD = Lamb[:, :, na] / LambD[:, na, :]\n        denom = np.sum(Lamb_LambD, axis=-1)\n\n        D = np.sqrt(num / denom) * D\n\n        self.spatial = D\n"
  },
  {
    "path": "ssspy/bss/pdsbss.py",
    "content": "import warnings\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\n\nfrom ..linalg import prox\nfrom .proxbss import ProxBSSBase\n\nEPS = 1e-10\n\n__all__ = [\"PDSBSS\", \"MaskingPDSBSS\"]\n\n\nclass PDSBSSBase(ProxBSSBase):\n    r\"\"\"Base class of blind source separation \\\n    via proximal splitting algorithm [#yatabe2018determined]_.\n\n    Args:\n        penalty_fn (callable):\n            Penalty function that determines source model.\n        prox_penalty (callable):\n            Proximal operator of penalty function.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``True``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n\n    .. [#yatabe2018determined] K. Yatabe and D. Kitamura,\n        \"Determined blind source separation via proximal splitting algorithm,\"\n        in *Proc. ICASSP*, 2018, pp. 
776-780.\n    \"\"\"\n\n    def __repr__(self) -> str:\n        s = \"PDSBSS(\"\n        s += \"n_penalties={n_penalties}\".format(n_penalties=self.n_penalties)\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n\nclass PDSBSS(PDSBSSBase):\n    r\"\"\"Blind source separation via proximal splitting algorithm [#yatabe2018determined]_.\n\n    Args:\n        mu1 (float):\n            Step size. Default: ``1``.\n        mu2 (float):\n            Step size. Default: ``1``.\n        alpha (float):\n            Relaxation parameter (deprecated). Set ``relaxation`` instead.\n        relaxation (float):\n            Relaxation parameter. Default: ``1``.\n        penalty_fn (callable, optional):\n            Penalty function that determines source model.\n        prox_penalty (callable):\n            Proximal operator of penalty function.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. 
You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool, optional):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``None``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n    \"\"\"\n\n    def __init__(\n        self,\n        mu1: float = 1,\n        mu2: float = 1,\n        alpha: float = None,\n        relaxation: float = 1,\n        penalty_fn: Optional[Callable[[np.ndarray, np.ndarray], float]] = None,\n        prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"PDSBSS\"], None], List[Callable[[\"PDSBSS\"], None]]]\n        ] = None,\n        scale_restoration: bool = True,\n        record_loss: Optional[bool] = None,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            penalty_fn=penalty_fn,\n            prox_penalty=prox_penalty,\n            callbacks=callbacks,\n            scale_restoration=scale_restoration,\n            record_loss=record_loss,\n            reference_id=reference_id,\n        )\n\n        self.mu1, self.mu2 = mu1, mu2\n\n        if alpha is None:\n            self.relaxation = relaxation\n        else:\n            assert relaxation == 1, \"You cannot specify relaxation and alpha simultaneously.\"\n\n            warnings.warn(\"alpha is deprecated. 
Set relaxation instead.\", DeprecationWarning)\n\n            self.relaxation = alpha\n\n    def __call__(self, input, n_iter=100, initial_call: bool = True, **kwargs) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                Mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                Number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of PDSBSSBase's parent, i.e. __call__ of IterativeMethodBase\n        super(PDSBSSBase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"PDSBSS(\"\n        s += \"mu1={mu1}, mu2={mu2}\"\n        s += \", relaxation={relaxation}\"\n        s += \", n_penalties={n_penalties}\".format(n_penalties=self.n_penalties)\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of PDSBSS.\n        
\"\"\"\n        super()._reset(**kwargs)\n\n        n_penalties = self.n_penalties\n        n_sources = self.n_sources\n        n_bins, n_frames = self.n_bins, self.n_frames\n\n        if not hasattr(self, \"dual\"):\n            dual = np.zeros((n_penalties, n_sources, n_bins, n_frames), dtype=np.complex128)\n        else:\n            if self.dual is None:\n                dual = None\n            else:\n                # To avoid overwriting ``dual`` given by keyword arguments.\n                dual = self.dual.copy()\n\n        self.dual = dual\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters and dual parameters once.\"\"\"\n        mu1, mu2 = self.mu1, self.mu2\n        alpha = self.relaxation\n\n        Y = self.dual\n        X, W = self.input, self.demix_filter\n\n        Y_sum = Y.sum(axis=0)\n        XY = Y_sum.transpose(1, 0, 2) @ X.transpose(1, 2, 0).conj()\n        W_tilde = prox.neg_logdet(W - mu1 * mu2 * XY, step_size=mu1)\n        XW = self.separate(X, demix_filter=2 * W_tilde - W)\n        Y_tilde = []\n\n        for Y_q, prox_penalty in zip(Y, self.prox_penalty):\n            Z_q = Y_q + XW\n            Y_tilde_q = Z_q - prox_penalty(Z_q, step_size=1 / mu2)\n            Y_tilde.append(Y_tilde_q)\n\n        Y_tilde = np.stack(Y_tilde, axis=0)\n\n        self.demix_filter = alpha * W_tilde + (1 - alpha) * W\n        self.dual = alpha * Y_tilde + (1 - alpha) * Y\n\n\nclass MaskingPDSBSS(PDSBSSBase):\n    r\"\"\"Blind source separation via proximal splitting algorithm with masking [#yatabe2019time]_.\n\n    Args:\n        mu1 (float):\n            Step size. Default: ``1``.\n        mu2 (float):\n            Step size. Default: ``1``.\n        alpha (float):\n            Relaxation parameter (deprecated). Set ``relaxation`` instead.\n        relaxation (float):\n            Relaxation parameter. 
Default: ``1``.\n        penalty_fn (callable, optional):\n            Penalty function that determines source model.\n        mask_fn (callable):\n            Masking function. It takes an array of shape (n_sources, n_bins, n_frames)\n            and returns a mask that is multiplied with the input.\n            Required; ``ValueError`` is raised if ``None`` is given.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool, optional):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``None``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n\n    .. [#yatabe2019time] K. Yatabe and D. Kitamura,\n        \"Time-frequency-masking-based determined BSS with application to sparse IVA,\"\n        in *Proc. ICASSP*, pp. 
715-719, 2019.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        mu1: float = 1,\n        mu2: float = 1,\n        alpha: float = None,\n        relaxation: float = 1,\n        penalty_fn: Optional[Callable[[np.ndarray, np.ndarray], float]] = None,\n        mask_fn: Callable[[np.ndarray], float] = None,\n        callbacks: Optional[\n            Union[Callable[[\"MaskingPDSBSS\"], None], List[Callable[[\"MaskingPDSBSS\"], None]]]\n        ] = None,\n        scale_restoration: bool = True,\n        record_loss: Optional[bool] = None,\n        reference_id: int = 0,\n    ) -> None:\n        super(ProxBSSBase, self).__init__(\n            callbacks=callbacks,\n            record_loss=record_loss,\n        )\n\n        if penalty_fn is None:\n            # Since penalty_fn is not necessarily written in closed form,\n            # None is acceptable.\n            if record_loss is None:\n                record_loss = False\n\n            assert not record_loss, \"To record loss, set penalty_fn.\"\n        else:\n            assert callable(penalty_fn), \"penalty_fn should be callable.\"\n\n            if record_loss is None:\n                record_loss = True\n\n        if mask_fn is None:\n            raise ValueError(\"Specify masking function.\")\n        else:\n            assert callable(mask_fn), \"mask_fn should be callable.\"\n\n        self.penalty_fn = penalty_fn\n        self.mask_fn = mask_fn\n\n        self.input = None\n        self.scale_restoration = scale_restoration\n\n        if reference_id is None and scale_restoration:\n            raise ValueError(\"Specify 'reference_id' if scale_restoration=True.\")\n        else:\n            self.reference_id = reference_id\n\n        self.mu1, self.mu2 = mu1, mu2\n\n        if alpha is None:\n            self.relaxation = relaxation\n        else:\n            assert relaxation == 1, \"You cannot specify relaxation and alpha simultaneously.\"\n\n            warnings.warn(\"alpha is deprecated. 
Set relaxation instead.\", DeprecationWarning)\n\n            self.relaxation = alpha\n\n    def __call__(self, input, n_iter=100, initial_call: bool = True, **kwargs) -> np.ndarray:\n        r\"\"\"Separate a frequency-domain multichannel signal.\n\n        Args:\n            input (numpy.ndarray):\n                Mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            n_iter (int):\n                Number of iterations of demixing filter updates.\n                Default: ``100``.\n            initial_call (bool):\n                If ``True``, perform callbacks (and computation of loss if necessary)\n                before iterations.\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_channels, n_bins, n_frames).\n        \"\"\"\n        self.input = input.copy()\n\n        self._reset(**kwargs)\n\n        # Call __call__ of PDSBSSBase's parent, i.e. __call__ of IterativeMethodBase\n        super(PDSBSSBase, self).__call__(n_iter=n_iter, initial_call=initial_call)\n\n        if self.scale_restoration:\n            self.restore_scale()\n\n        self.output = self.separate(self.input, demix_filter=self.demix_filter)\n\n        return self.output\n\n    def __repr__(self) -> str:\n        s = \"MaskingPDSBSS(\"\n        s += \"mu1={mu1}, mu2={mu2}\"\n        s += \", relaxation={relaxation}\"\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of MaskingPDSBSS.\n        \"\"\"\n        super()._reset(**kwargs)\n\n        assert 
self.n_penalties == 1, \"Number of penalty function should be one.\"\n\n        n_sources = self.n_sources\n        n_bins, n_frames = self.n_bins, self.n_frames\n\n        if not hasattr(self, \"dual\"):\n            dual = np.zeros((n_sources, n_bins, n_frames), dtype=np.complex128)\n        else:\n            if self.dual is None:\n                dual = None\n            else:\n                # To avoid overwriting ``dual`` given by keyword arguments.\n                dual = self.dual.copy()\n\n        self.dual = dual\n\n    @property\n    def n_penalties(self):\n        r\"\"\"Return number of penalty terms.\"\"\"\n        return 1\n\n    def update_once(self) -> None:\n        r\"\"\"Update demixing filters and dual parameters once.\"\"\"\n        mu1, mu2 = self.mu1, self.mu2\n        alpha = self.relaxation\n\n        Y = self.dual\n        X, W = self.input, self.demix_filter\n\n        XY = Y.transpose(1, 0, 2) @ X.transpose(1, 2, 0).conj()\n        W_tilde = prox.neg_logdet(W - mu1 * mu2 * XY, step_size=mu1)\n        XW = self.separate(X, demix_filter=2 * W_tilde - W)\n\n        Z = Y + XW\n        Y_tilde = Z - self.mask_fn(Z) * Z\n\n        self.demix_filter = alpha * W_tilde + (1 - alpha) * W\n        self.dual = alpha * Y_tilde + (1 - alpha) * Y\n"
  },
  {
    "path": "ssspy/bss/proxbss.py",
    "content": "from typing import Callable, List, Optional, Union\n\nimport numpy as np\n\nfrom ..algorithm import (\n    MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS,\n    PROJECTION_BACK_KEYWORDS,\n    minimal_distortion_principle,\n    projection_back,\n)\nfrom .base import IterativeMethodBase\n\nEPS = 1e-10\n\n\nclass ProxBSSBase(IterativeMethodBase):\n    \"\"\"Base class of blind source separation via proximal gradient method.\n\n    Args:\n        penalty_fn (callable, optional):\n            Penalty function that determines source model.\n        prox_penalty (callable):\n            Proximal operator of penalty function.\n            Default: ``None``.\n        callbacks (callable or list[callable], optional):\n            Callback functions. Each function is called before separation and at each iteration.\n            Default: ``None``.\n        scale_restoration (bool or str):\n            Technique to restore scale ambiguity.\n            If ``scale_restoration=True``, the projection back technique is applied to\n            estimated spectrograms. 
You can also specify ``projection_back`` explicitly.\n            Default: ``True``.\n        record_loss (bool, optional):\n            Record the loss at each iteration of the update algorithm if ``record_loss=True``.\n            Default: ``None``.\n        reference_id (int):\n            Reference channel for projection back.\n            Default: ``0``.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        penalty_fn: Optional[Callable[[np.ndarray, np.ndarray], float]] = None,\n        prox_penalty: Callable[[np.ndarray, float], np.ndarray] = None,\n        callbacks: Optional[\n            Union[Callable[[\"ProxBSSBase\"], None], List[Callable[[\"ProxBSSBase\"], None]]]\n        ] = None,\n        scale_restoration: bool = True,\n        record_loss: Optional[bool] = None,\n        reference_id: int = 0,\n    ) -> None:\n        super().__init__(\n            callbacks=callbacks,\n            record_loss=record_loss,\n        )\n\n        if penalty_fn is None:\n            # Since penalty_fn is not necessarily written in closed form,\n            # None is acceptable.\n            if record_loss is None:\n                record_loss = False\n\n            assert not record_loss, \"To record loss, set penalty_fn.\"\n        else:\n            if callable(penalty_fn):\n                penalty_fn = [penalty_fn]\n\n            if record_loss is None:\n                record_loss = True\n\n        if prox_penalty is None:\n            raise ValueError(\"Specify proximal operator of penalty function.\")\n        else:\n            if callable(prox_penalty):\n                prox_penalty = [prox_penalty]\n\n        self.penalty_fn = penalty_fn\n        self.prox_penalty = prox_penalty\n\n        if self.penalty_fn is not None:\n            assert len(self.penalty_fn) == len(\n                self.prox_penalty\n            ), \"Length of penalty_fn and prox_penalty are different.\"\n\n        self.input = None\n        self.scale_restoration = 
scale_restoration\n\n        if reference_id is None and scale_restoration:\n            raise ValueError(\"Specify 'reference_id' if scale_restoration=True.\")\n        else:\n            self.reference_id = reference_id\n\n    def __repr__(self) -> str:\n        s = \"ProxBSSBase(\"\n        s += \"n_penalties={n_penalties}\".format(n_penalties=self.n_penalties)\n        s += \", scale_restoration={scale_restoration}\"\n        s += \", record_loss={record_loss}\"\n\n        if self.scale_restoration:\n            s += \", reference_id={reference_id}\"\n\n        s += \")\"\n\n        return s.format(**self.__dict__)\n\n    def _reset(self, **kwargs) -> None:\n        r\"\"\"Reset attributes by given keyword arguments.\n\n        Args:\n            kwargs:\n                Keyword arguments to set as attributes of ProxBSSBase.\n\n        \"\"\"\n        assert self.input is not None, \"Specify data!\"\n\n        for key in kwargs.keys():\n            setattr(self, key, kwargs[key])\n\n        X = self.input\n\n        n_channels, n_bins, n_frames = X.shape\n        n_sources = n_channels  # n_channels == n_sources\n\n        self.n_sources, self.n_channels = n_sources, n_channels\n        self.n_bins, self.n_frames = n_bins, n_frames\n\n        if not hasattr(self, \"demix_filter\"):\n            W = np.eye(n_sources, n_channels, dtype=np.complex128)\n            W = np.tile(W, reps=(n_bins, 1, 1))\n        else:\n            if self.demix_filter is None:\n                W = None\n            else:\n                # To avoid overwriting ``demix_filter`` given by keyword arguments.\n                W = self.demix_filter.copy()\n\n        self.demix_filter = W\n        self.output = self.separate(X, demix_filter=W)\n\n    @property\n    def n_penalties(self):\n        r\"\"\"Return number of penalty terms.\"\"\"\n        # asumption of len(self.penalty_fn) == len(self.prox_penalty)\n        return len(self.prox_penalty)\n\n    def separate(self, input: 
np.ndarray, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Separate ``input`` using ``demixing_filter``.\n\n        .. math::\n            \\boldsymbol{y}_{ij}\n            = \\boldsymbol{W}_{i}\\boldsymbol{x}_{ij}\n\n        Args:\n            input (numpy.ndarray):\n                The mixture signal in frequency-domain.\n                The shape is (n_channels, n_bins, n_frames).\n            demix_filter (numpy.ndarray):\n                The demixing filters to separate ``input``.\n                The shape is (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of the separated signal in frequency-domain.\n            The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        X, W = input, demix_filter\n        Y = W @ X.transpose(1, 0, 2)\n        output = Y.transpose(1, 0, 2)\n\n        return output\n\n    def compute_loss(self) -> float:\n        r\"\"\"Compute loss :math:`\\mathcal{L}`.\n\n        Returns:\n            Computed loss.\n        \"\"\"\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)  # (n_sources, n_bins, n_frames)\n        logdet = self.compute_logdet(W)  # (n_bins,)\n        penalty = 0\n\n        for penalty_fn in self.penalty_fn:\n            penalty = penalty + penalty_fn(Y)\n\n        loss = penalty - np.sum(logdet, axis=0)\n        loss = loss.item()\n\n        return loss\n\n    def compute_logdet(self, demix_filter: np.ndarray) -> np.ndarray:\n        r\"\"\"Compute log-determinant of demixing filter\n\n        Args:\n            demix_filter (numpy.ndarray):\n                Demixing filters with shape of (n_bins, n_sources, n_channels).\n\n        Returns:\n            numpy.ndarray of computed log-determinant values.\n        \"\"\"\n        _, logdet = np.linalg.slogdet(demix_filter)  # (n_bins,)\n\n        return logdet\n\n    def normalize_by_spectral_norm(self, input: np.ndarray, n_penalties: int = None) -> np.ndarray:\n        
r\"\"\"Spectral normalization.\n\n        Args:\n            input (numpy.ndarray):\n                Input spectrogram with shape of (n_channels, n_bins, n_frames).\n            n_penalties (int):\n                Number of penalty functions, which determines coefficient of normalization.\n\n        Returns:\n            numpy.ndarray of normalized spectrogram with shape of (n_channels, n_bins, n_frames).\n        \"\"\"\n        if n_penalties is None:\n            n_penalties = self.n_penalties\n\n        norm = np.linalg.norm(input.transpose(1, 0, 2), ord=2, axis=(-2, -1))\n        norm = np.max(norm)\n\n        return input / (np.sqrt(n_penalties) * norm)\n\n    def restore_scale(self) -> None:\n        r\"\"\"Restore scale ambiguity.\n\n        If ``self.scale_restoration=projection_back``, we use projection back technique.\n        \"\"\"\n        scale_restoration = self.scale_restoration\n\n        assert scale_restoration, \"Set self.scale_restoration=True.\"\n\n        if type(scale_restoration) is bool:\n            scale_restoration = \"projection_back\"\n\n        if scale_restoration in PROJECTION_BACK_KEYWORDS:\n            self.apply_projection_back()\n        elif scale_restoration in MINIMAL_DISTORTION_PRINCIPLE_KEYWORDS:\n            self.apply_minimal_distortion_principle()\n        else:\n            raise ValueError(\"{} is not supported for scale restoration.\".format(scale_restoration))\n\n    def apply_projection_back(self) -> None:\n        r\"\"\"Apply projection back technique to estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        W_scaled = projection_back(W, reference_id=self.reference_id)\n        Y_scaled = self.separate(X, demix_filter=W_scaled)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n\n    def apply_minimal_distortion_principle(self) -> None:\n        r\"\"\"Apply minimal distortion principle to 
estimated spectrograms.\"\"\"\n        assert self.scale_restoration, \"Set self.scale_restoration=True.\"\n\n        X, W = self.input, self.demix_filter\n        Y = self.separate(X, demix_filter=W)\n        Y_scaled = minimal_distortion_principle(Y, reference=X, reference_id=self.reference_id)\n        X = X.transpose(1, 0, 2)\n        Y = Y_scaled.transpose(1, 0, 2)\n        X_Hermite = X.transpose(0, 2, 1).conj()\n        W_scaled = Y @ X_Hermite @ np.linalg.inv(X @ X_Hermite)\n\n        self.output, self.demix_filter = Y_scaled, W_scaled\n"
  },
  {
    "path": "ssspy/io/__init__.py",
    "content": "import struct\nfrom io import BufferedReader, BufferedWriter\nfrom typing import Optional, Tuple\n\nimport numpy as np\n\n\ndef wavread(\n    path: str,\n    frame_offset: int = 0,\n    num_frames: Optional[int] = None,\n    return_2d: Optional[bool] = None,\n    channels_first: Optional[bool] = None,\n) -> Tuple[np.ndarray, int]:\n    with open(path, mode=\"rb\") as f:\n        riff = f.read(4)\n\n        # ensure byte order is little endian\n        if riff != b\"RIFF\":\n            raise NotImplementedError(f\"Not support {repr(riff)}.\")\n\n        # total file size\n        _ = struct.unpack(\"<I\", f.read(4))[0] + 4 + 4\n\n        ftype = f.read(4)\n\n        # ensure file type is WAV\n        if ftype != b\"WAVE\":\n            raise NotImplementedError(f\"Not support {repr(ftype)}.\")\n\n        chunk_marker = f.read(4)\n\n        if chunk_marker != b\"fmt \":\n            raise NotImplementedError(f\"Not support {repr(chunk_marker)}.\")\n\n        n_channels, sample_rate, block_align = _read_fmt_chunk(f)\n\n        chunk_marker = f.read(4)\n\n        if chunk_marker != b\"data\":\n            raise NotImplementedError(f\"Not support {repr(chunk_marker)}.\")\n\n        data = _read_data_chunk(\n            f,\n            n_channels,\n            block_align,\n            frame_offset=frame_offset,\n            num_frames=num_frames,\n            return_2d=return_2d,\n            channels_first=channels_first,\n        )\n\n    return data, sample_rate\n\n\ndef wavwrite(\n    path: str,\n    waveform: np.ndarray,\n    sample_rate: int,\n    channels_first: Optional[bool] = None,\n) -> None:\n    assert path[-4:] == \".wav\", \"Only wav file is supported.\"\n\n    if waveform.ndim == 1:\n        _waveform = waveform\n        n_channels = 1\n    elif waveform.ndim == 2:\n        if channels_first:\n            _waveform = waveform.transpose(1, 0)\n        else:\n            _waveform = waveform\n\n        n_channels = _waveform.shape[1]\n\n   
     if n_channels < 1 or 2 < n_channels:\n            raise ValueError(f\"{n_channels}channel-input is not supported.\")\n    else:\n        raise ValueError(f\"waveform.ndim should be less or equal to 2, but given {waveform.ndim}.\")\n\n    if _waveform.dtype in [\"f2\", \"f4\", \"f8\", \"f16\"]:\n        bits_per_sample = 16\n\n        # float to int\n        _waveform = _waveform * 2 ** (bits_per_sample - 1)\n        _waveform = _waveform.astype(\"<i2\")\n    elif _waveform.dtype == \"i1\":\n        bits_per_sample = 8\n    elif _waveform.dtype == \"i2\":\n        bits_per_sample = 16\n    else:\n        raise ValueError(f\"Invalid dtype={_waveform.dtype} is detected.\")\n\n    assert (\n        bits_per_sample % 8 == 0\n    ), f\"bits_per_sample should be divisible by 8, but given {bits_per_sample}.\"\n\n    byte_rate = (bits_per_sample * sample_rate * n_channels) // 8\n    block_align = byte_rate // sample_rate\n\n    with open(path, mode=\"wb\") as f:\n        valid_file_size = 0\n\n        data = b\"RIFF\"\n        f.write(data)\n        valid_file_size += 4\n\n        filesize_position = f.tell()\n        data = struct.pack(\"<I\", 0)  # calculate file size at last\n        f.write(data)\n\n        data = b\"WAVE\"\n        f.write(data)\n\n        _write_fmt_chunk(f, n_channels, sample_rate, byte_rate, block_align, bits_per_sample)\n\n        _write_data_chunk(f, _waveform)\n\n        total_file_size = f.tell()\n        data = struct.pack(\"<I\", total_file_size - 8)\n        f.seek(filesize_position)\n        f.write(data)\n\n\ndef _read_fmt_chunk(\n    f: BufferedReader,\n) -> Tuple[int, int, int]:\n    fmt_chunk_size = struct.unpack(\"<I\", f.read(4))[0]\n\n    if fmt_chunk_size != 16:\n        raise NotImplementedError(\"Invalid header is detected.\")\n\n    fmt = struct.unpack(\"<H\", f.read(2))[0]\n\n    # ensure format is PCM\n    if fmt != 1:\n        raise NotImplementedError(f\"Invalid header {fmt} is detected.\")\n\n    n_channels, sample_rate, 
byte_rate, block_align, bits_per_sample = struct.unpack(\n        \"<HIIHH\", f.read(2 + 4 + 4 + 2 + 2)\n    )\n\n    if bits_per_sample * sample_rate * n_channels != 8 * byte_rate:\n        raise ValueError(\"Invalid header is detected.\")\n\n    return n_channels, sample_rate, block_align\n\n\ndef _read_data_chunk(\n    f: BufferedReader,\n    n_channels: int,\n    block_align: int,\n    frame_offset: int = 0,\n    num_frames: Optional[int] = None,\n    return_2d: Optional[bool] = None,\n    channels_first: Optional[bool] = None,\n) -> np.ndarray:\n    data_chunk_size = struct.unpack(\"<I\", f.read(4))[0]\n    bytes_per_sample = block_align // n_channels\n    n_full_samples = data_chunk_size // bytes_per_sample\n\n    start = f.tell() + block_align * frame_offset\n    max_frame = data_chunk_size // block_align\n\n    if num_frames is None:\n        shape = (n_full_samples - n_channels * frame_offset,)\n        end_frame = data_chunk_size // block_align\n    elif num_frames >= 0:\n        shape = (n_channels * num_frames,)\n        end_frame = frame_offset + num_frames\n    else:\n        raise ValueError(f\"Invalid num_frames={num_frames} is given. 
Set nonnegative integer.\")\n\n    if end_frame > max_frame:\n        raise ValueError(f\"num_frames={num_frames} exceeds maximum frame {max_frame}.\")\n\n    data = np.memmap(f, dtype=f\"<i{bytes_per_sample}\", mode=\"c\", offset=start, shape=shape)\n\n    if n_channels > 1:\n        data = data.reshape(-1, n_channels)\n\n        if channels_first:\n            data = data.transpose(1, 0)\n    else:\n        if return_2d:\n            data = data.reshape(-1, n_channels)\n\n            if channels_first:\n                data = data.transpose(1, 0)\n\n    vmax = 2 ** (8 * bytes_per_sample - 1)\n\n    return data / vmax\n\n\ndef _write_fmt_chunk(\n    f: BufferedWriter,\n    n_channels: int,\n    sample_rate: int,\n    byte_rate: int,\n    block_align: int,\n    bits_per_sample: int,\n) -> None:\n    data = b\"fmt \"\n    f.write(data)\n\n    data = struct.pack(\"<I\", 16)\n    f.write(data)\n\n    data = struct.pack(\"<H\", 1)\n    f.write(data)\n\n    data = struct.pack(\"<HIIHH\", n_channels, sample_rate, byte_rate, block_align, bits_per_sample)\n    f.write(data)\n\n\ndef _write_data_chunk(f: BufferedWriter, waveform: np.ndarray) -> None:\n    data = b\"data\"\n    f.write(data)\n\n    data_chunk_size = waveform.nbytes\n    data = struct.pack(\"<I\", data_chunk_size)\n    f.write(data)\n\n    _waveform = waveform.flatten()\n    data = _waveform.view(\"b\").data\n    f.write(data)\n"
  },
  {
    "path": "ssspy/linalg/__init__.py",
    "content": "from ._solve import solve\nfrom .cubic import cbrt\nfrom .eigh import eigh, eigh2\nfrom .inv import inv2\nfrom .lqpqm import lqpqm2\nfrom .mean import gmeanmh\nfrom .polynomial import solve_cubic\nfrom .quadratic import quadratic\nfrom .sqrtm import invsqrtmh, sqrtmh\n\n__all__ = [\n    \"cbrt\",\n    \"quadratic\",\n    \"inv2\",\n    \"eigh\",\n    \"eigh2\",\n    \"sqrtmh\",\n    \"invsqrtmh\",\n    \"gmeanmh\",\n    \"solve_cubic\",\n    \"lqpqm2\",\n    \"solve\",\n]\n"
  },
  {
    "path": "ssspy/linalg/_solve.py",
    "content": "import numpy as np\nfrom packaging import version\n\nnp_version = np.__version__\n\nIS_NUMPY_GE_2 = version.parse(np.__version__) >= version.parse(\"2\")\n\n\ndef solve(a: np.ndarray, b: np.ndarray) -> np.ndarray:\n    requires_new_axis = IS_NUMPY_GE_2 and a.ndim == b.ndim + 1\n\n    if requires_new_axis:\n        b = b[..., np.newaxis]\n\n    x = np.linalg.solve(a, b)\n\n    if requires_new_axis:\n        x = x[..., 0]\n        b = b[..., 0]\n\n    return x\n"
  },
  {
    "path": "ssspy/linalg/cubic.py",
    "content": "import numpy as np\n\n\ndef cbrt(x: np.ndarray) -> np.ndarray:\n    \"\"\"Return cube-root of an array.\n\n    Args:\n        x (np.ndarray):\n            Values to compute cube-root. Complex value is available.\n\n    Returns:\n        np.ndarray of cube-root.\n\n    \"\"\"\n    if np.iscomplexobj(x):\n        amplitude = np.abs(x)\n        phase = np.angle(x)\n        x_cbrt = np.cbrt(amplitude) * np.exp(1j * phase / 3)\n    else:\n        x_cbrt = np.cbrt(x)\n\n    return x_cbrt\n"
  },
  {
    "path": "ssspy/linalg/eigh.py",
    "content": "from typing import Callable, Optional, Tuple, Union\n\nimport numpy as np\n\nfrom .inv import inv2\n\n\ndef eigh(\n    A: np.ndarray, B: Optional[np.ndarray] = None, type: Optional[int] = 1\n) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:\n    r\"\"\"Compute the (generalized) eigenvalues and eigenvectors of a complex Hermitian \\\n    (conjugate symmetric) or a real symmetric matrix.\n\n    If ``B`` is ``None``, solve :math:`\\boldsymbol{A}\\boldsymbol{z} = \\lambda\\boldsymbol{z}`.\n\n    If ``B`` is given,\n    solve :math:`\\boldsymbol{A}\\boldsymbol{z} = \\lambda\\boldsymbol{B}\\boldsymbol{z}`.\n\n    Args:\n        A (numpy.ndarray):\n            A complex Hermitian matrix with shape of (\\*, n_channels, n_channels).\n        B (numpy.ndarray, optional):\n            A complex Hermitian matrix with shape of (\\*, n_channels, n_channels).\n        type (int):\n            For the generalized eigenproblem, this value specifies the type of problem.\n            Only ``1``, ``2``, and ``3`` are supported.\n\n            - When ``type=1``, solve :math:`\\boldsymbol{Az}=\\lambda\\boldsymbol{Bz}`.\n            - When ``type=2``, solve :math:`\\boldsymbol{ABz}=\\lambda\\boldsymbol{z}`.\n            - When ``type=3``, solve :math:`\\boldsymbol{BAz}=\\lambda\\boldsymbol{z}`.\n\n    Returns:\n        A tuple of (eigenvalues, eigenvectors)\n            - Eigenvalues have shape of (\\*, n_channels).\n            - Eigenvectors have shape of (\\*, n_channels, n_channels).\n\n    .. 
note::\n        If ``B`` is given, we use Cholesky decomposition to\n        satisfy :math:`\\boldsymbol{L}\\boldsymbol{L}^{\\mathsf{H}}=\\boldsymbol{B}`.\n\n        Then, solve :math:`\\boldsymbol{C}\\boldsymbol{y} = \\lambda\\boldsymbol{y}`,\n        where :math:`\\boldsymbol{C}=\\boldsymbol{L}^{-1}\\boldsymbol{A}\\boldsymbol{L}^{-\\mathsf{H}}`.\n\n        The generalized eigenvectors of :math:`\\boldsymbol{A}` and :math:`\\boldsymbol{B}`\n        are computed by :math:`\\boldsymbol{L}^{-\\mathsf{H}}\\boldsymbol{y}`.\n\n    Examples:\n\n        .. code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.linalg import eigh\n            >>> A = np.array([[1, -2j], [2j, 3]])\n            >>> lamb, z = eigh(A)\n            >>> lamb; z\n            array([-0.23606798,  4.23606798])\n            array([[-0.85065081+0.j        , -0.52573111+0.j        ],\n                [ 0.        +0.52573111j,  0.        -0.85065081j]])\n            >>> np.allclose(A @ z, lamb * z)\n            True\n\n        .. code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.linalg import eigh\n            >>> A = np.array([[1, -2j], [2j, 3]])\n            >>> B = np.array([[2, -3j], [3j, 5]])\n            >>> lamb, z = eigh(A, B)\n            >>> lamb; z\n            array([-1.61803399,  0.61803399])\n            array([[ 2.22703273+0.j        , -0.20081142+0.j        ],\n                [ 0.        -1.37638192j,  0.        
-0.3249197j ]])\n            >>> np.allclose(A @ z, lamb * (B @ z))\n            True\n    \"\"\"\n    if B is None:\n        return np.linalg.eigh(A)\n\n    lamb, z = _eigh(A, B, type=type, inv=np.linalg.inv)\n\n    return lamb, z\n\n\ndef eigh2(\n    A: np.ndarray, B: Optional[np.ndarray] = None, type: Optional[int] = 1\n) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:\n    r\"\"\"Compute the (generalized) eigenvalues and eigenvectors of a 2x2 complex Hermitian \\\n    (conjugate symmetric) or a real symmetric matrix.\n\n    If ``B`` is ``None``, solve :math:`\\boldsymbol{A}\\boldsymbol{z} = \\lambda\\boldsymbol{z}`.\n\n    If ``B`` is given,\n    solve :math:`\\boldsymbol{A}\\boldsymbol{z} = \\lambda\\boldsymbol{B}\\boldsymbol{z}`.\n\n    Args:\n        A (numpy.ndarray):\n            A complex Hermitian matrix with shape of (\\*, 2, 2).\n        B (numpy.ndarray, optional):\n            A complex Hermitian matrix with shape of (\\*, 2, 2).\n        type (int):\n            For the generalized eigenproblem, this value specifies the type of problem.\n            Only ``1``, ``2``, and ``3`` are supported.\n\n            - When ``type=1``, solve :math:`\\boldsymbol{Az}=\\lambda\\boldsymbol{Bz}`.\n            - When ``type=2``, solve :math:`\\boldsymbol{ABz}=\\lambda\\boldsymbol{z}`.\n            - When ``type=3``, solve :math:`\\boldsymbol{BAz}=\\lambda\\boldsymbol{z}`.\n\n    Returns:\n        A tuple of (eigenvalues, eigenvectors)\n            - Eigenvalues have shape of (\\*, 2).\n            - Eigenvectors have shape of (\\*, 2, 2).\n\n    .. 
note::\n        If ``B`` is given, we use Cholesky decomposition to obtain :math:`\\boldsymbol{L}`\n        satisfying :math:`\\boldsymbol{L}\\boldsymbol{L}^{\\mathsf{H}}=\\boldsymbol{B}`.\n\n        Then, solve :math:`\\boldsymbol{C}\\boldsymbol{y} = \\lambda\\boldsymbol{y}`,\n        where :math:`\\boldsymbol{C}=\\boldsymbol{L}^{-1}\\boldsymbol{A}\\boldsymbol{L}^{-\\mathsf{H}}`.\n\n        The generalized eigenvectors of :math:`\\boldsymbol{A}` and :math:`\\boldsymbol{B}`\n        are computed by :math:`\\boldsymbol{z}=\\boldsymbol{L}^{-\\mathsf{H}}\\boldsymbol{y}`.\n\n        See also https://github.com/tky823/ssspy/issues/115 for this implementation.\n\n    Examples:\n\n        .. code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.linalg import eigh2\n            >>> A = np.array([[1, -2j], [2j, 3]])\n            >>> lamb, z = eigh2(A)\n            >>> lamb; z\n            array([-0.23606798,  4.23606798])\n            array([[-0.85065081+0.j        , -0.52573111+0.j        ],\n                [ 0.        +0.52573111j,  0.        -0.85065081j]])\n            >>> np.allclose(A @ z, lamb * z)\n            True\n\n        .. code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.linalg import eigh2\n            >>> A = np.array([[1, -2j], [2j, 3]])\n            >>> B = np.array([[2, -3j], [3j, 5]])\n            >>> lamb, z = eigh2(A, B)\n            >>> lamb; z\n            array([-1.61803399,  0.61803399])\n            array([[ 2.22703273+0.j        , -0.20081142+0.j        ],\n                [ 0.        -1.37638192j,  0.        
-0.3249197j ]])\n            >>> np.allclose(A @ z, lamb * (B @ z))\n            True\n    \"\"\"\n    assert A.shape[-2:] == (2, 2), \"2x2 matrix is expected, but given shape of {}.\".format(A.shape)\n\n    if B is None:\n        return np.linalg.eigh(A)\n\n    lamb, z = _eigh(A, B, type=type, inv=inv2)\n\n    return lamb, z\n\n\ndef _eigh(\n    A: np.ndarray,\n    B: np.ndarray,\n    type: int = 1,\n    inv: Callable[[np.ndarray], np.ndarray] = None,\n) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:\n    if inv is None:\n        inv = np.linalg.inv\n\n    L = np.linalg.cholesky(B)\n\n    if type == 1:\n        L_inv = inv(L)\n        L_inv_Hermite = np.swapaxes(L_inv, -2, -1)\n\n        if np.iscomplexobj(L_inv_Hermite):\n            L_inv_Hermite = L_inv_Hermite.conj()\n\n        C = L_inv @ A @ L_inv_Hermite\n    elif type in [2, 3]:\n        L_Hermite = np.swapaxes(L, -2, -1)\n\n        if np.iscomplexobj(L_Hermite):\n            L_Hermite = L_Hermite.conj()\n\n        C = L_Hermite @ A @ L\n\n        if type == 2:\n            L_inv_Hermite = inv(L_Hermite)\n        else:\n            L_inv_Hermite = None\n    else:\n        raise ValueError(\"Invalid type={} is given.\".format(type))\n\n    lamb, y = np.linalg.eigh(C)\n\n    if type in [1, 2]:\n        z = L_inv_Hermite @ y\n    elif type == 3:\n        z = L @ y\n    else:\n        raise ValueError(\"Invalid type={} is given.\".format(type))\n\n    return lamb, z\n"
  },
  {
    "path": "ssspy/linalg/inv.py",
    "content": "import numpy as np\n\n\ndef inv2(X: np.ndarray) -> np.ndarray:\n    r\"\"\"Compute the (multiplicative) inverse of a 2x2 matrix.\n\n    Args:\n        X (numpy.ndarray):\n            2x2 matrix to be inverted. The shape is (\\*, 2, 2).\n\n    Returns:\n        numpy.ndarray:\n            (Multiplicative) inverse of the matrix X.\n\n    Examples:\n        .. code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.linalg import inv2\n            >>> X = np.array([[0, 1], [2, 3]])\n            >>> X_inv = inv2(X)\n            >>> np.allclose(X @ X_inv, np.eye(2))\n            True\n            >>> np.allclose(X_inv @ X, np.eye(2))\n            True\n\n        .. code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.linalg import inv2\n            >>> X = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])\n            >>> inv2(X)\n            array([[[-1.5,  0.5],\n                    [ 1. , -0. ]],\n\n                [[-3.5,  2.5],\n                    [ 3. , -2. ]]])\n    \"\"\"\n    shape = X.shape\n\n    assert shape[-2:] == (2, 2), \"2x2 matrix is expected, but given shape of {}.\".format(shape)\n\n    a = X[..., 0, 0]\n    b = X[..., 0, 1]\n    c = X[..., 1, 0]\n    d = X[..., 1, 1]\n\n    det = a * d - b * c\n\n    X_adj = np.stack([d, -b, -c, a], axis=-1)\n    X_adj = X_adj.reshape(shape[:-2] + (2, 2))\n    X_inv = X_adj / det[..., np.newaxis, np.newaxis]\n\n    return X_inv\n"
  },
  {
    "path": "ssspy/linalg/lqpqm.py",
    "content": "import functools\nimport warnings\nfrom typing import Callable, Optional, Union\n\nimport numpy as np\n\nfrom ..special.flooring import identity, max_flooring\nfrom .cubic import cbrt\n\nEPS = 1e-10\n\n\ndef lqpqm2(\n    H: np.ndarray,\n    v: np.ndarray,\n    z: np.ndarray,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n    singular_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"flooring\",\n    max_iter: int = 10,\n) -> None:\n    r\"\"\"Solve of log-quadratically penelized quadratic minimization (type 2).\n\n    .. math::\n\n        \\check{\\boldsymbol{q}}_{in}\n        = \\min_{\\check{\\boldsymbol{q}}_{in}}\n        ~~\\check{\\boldsymbol{q}}_{in}^{\\mathsf{H}}\\check{\\boldsymbol{q}}_{in}\n        - \\log\\left((\\check{\\boldsymbol{q}}_{in}+\\boldsymbol{v}_{in})^{\\mathsf{H}}\n        \\boldsymbol{H}_{in}(\\check{\\boldsymbol{q}}_{in}+\\boldsymbol{v}_{in})\n        + z_{in}\n        \\right)\n\n    Args:\n        H (numpy.ndarray): Positive semidefinite matrices of shape\n            (n_bins, n_sources - 1, n_sources - 1).\n        v (numpy.ndarray): Linear terms in LQPQM of shape (n_bins, n_sources - 1).\n        z (numpy.ndarray): Constant terms in LQPQM of shape (n_bins,).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        singular_fn (callable, optional):\n            A flooring function to return singular condition.\n            This function is expected to return the same shape bool tensor as the input.\n            If ``singular_fn=None``, ``lambda x: x == 0`` is used.\n            Default: ``flooring``.\n        
max_iter (int):\n            Maximum number of Newton-Raphson method. Default: ``10``.\n\n    Returns:\n        np.ndarray: Solutions of LQPQM type-2 of shape (n_bins, n_sources - 1).\n\n    \"\"\"\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    if singular_fn is None:\n\n        def _is_zero(x: np.ndarray) -> np.ndarray:\n            return x == 0\n\n        singular_fn = _is_zero\n    elif singular_fn == \"flooring\":\n\n        def _is_lower_than_floor(x: np.ndarray) -> np.ndarray:\n            return x < flooring_fn(0)\n\n        singular_fn = _is_lower_than_floor\n    else:\n        assert callable(singular_fn), \"singular_fn should be callable.\"\n\n    phi, sigma = np.linalg.eigh(H)\n    norm = np.linalg.norm(v, axis=-1)\n    is_singular = singular_fn(norm)\n\n    # when v = 0\n    phi_singular = phi[is_singular]\n    sigma_singular = sigma[is_singular]\n    z_singular = z[is_singular]\n\n    phi_max_singular = phi_singular[:, -1]\n    sigma_max_singular = sigma_singular[:, -1]\n    lamb_singular = np.maximum(z_singular, phi_max_singular)\n    scale = (lamb_singular - z_singular) / phi_max_singular\n    scale = np.maximum(scale, 0)\n    scale = np.sqrt(scale)\n    y_singular = scale[..., np.newaxis] * sigma_max_singular\n\n    # when v != 0\n    phi_non_singular = phi[~is_singular]\n    sigma_non_singular = sigma[~is_singular]\n    v_non_singular = v[~is_singular]\n    z_non_singular = z[~is_singular]\n\n    v_tilde_non_singular = np.sum(\n        sigma_non_singular.conj() * v_non_singular[:, :, np.newaxis], axis=-2\n    )\n    lamb_non_singular = solve_equation(\n        phi_non_singular,\n        v_tilde_non_singular,\n        z_non_singular,\n        flooring_fn=flooring_fn,\n        max_iter=max_iter,\n        normalization=True,\n    )\n\n    num = phi_non_singular * v_tilde_non_singular\n    denom = lamb_non_singular[..., np.newaxis] - phi_non_singular\n    v_nonsingular = num / denom\n    y_non_singular = np.sum(sigma_non_singular 
* v_nonsingular[:, np.newaxis, :], axis=-1)\n\n    y = np.zeros_like(v)\n    y[is_singular] = y_singular\n    y[~is_singular] = y_non_singular\n\n    return y\n\n\ndef solve_equation(\n    phi: np.ndarray,\n    v: np.ndarray,\n    z: np.ndarray,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n    max_iter: int = 10,\n    normalization: bool = True,\n):\n    r\"\"\"Find largest root of :math:`f(\\lambda_{in})`, where\n\n    .. math::\n\n        f(\\lambda_{in})\n        = \\lambda_{in}^{2}\\sum_{n'}\n        \\frac{\\phi_{inn'}|\\tilde{v}_{inn'}|^{2}}{(\\lambda_{in}-\\phi_{inn'})^{2}}\n        - \\lambda_{in} + z_{in}\n\n    Args:\n        phi (numpy.ndarray): Eigen values defined in LQPQM of shape (n_bins, n_sources).\n        v (numpy.ndarray): Linear term defined in LQPQM of shape (n_bins, n_sources).\n        z (numpy.ndarray): Constant term defined in LQPQM of shape (n_bins,).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n        max_iter (int): Maximum iteration of Newton-Raphson method. 
Default: ``10``.\n        normalization (bool): If ``True``, coefficients are normalized by ``phi_max``.\n\n    Returns:\n        numpy.ndarray of largest root of :math:`f(\\lambda_{in})`.\n        The shape is (n_bins,).\n\n    \"\"\"\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    n_bins, n_sources = phi.shape\n\n    non_zero_mask = phi * np.abs(v) ** 2 >= flooring_fn(0)\n    phi = non_zero_mask * phi\n    v = non_zero_mask * v\n\n    max_index = np.argmax(phi, axis=-1) + np.arange(0, n_bins * n_sources, n_sources)\n    phi_flatten = phi.flatten()\n    v_flatten = v.flatten()\n    phi_max = phi_flatten[max_index]\n    v_max = v_flatten[max_index]\n    phi_max = flooring_fn(phi_max)\n\n    if normalization:\n        phi_max_original = phi_max\n        phi = phi / phi_max[:, np.newaxis]\n        v = v / phi_max[:, np.newaxis]\n        v_max = v_max / phi_max\n        z = z / phi_max\n        phi_max = phi_max / phi_max  # i.e. phi_max = 1\n    else:\n        phi_max_original = None\n\n    # Find largest root of cubic polynomial for initialization\n    A = -(phi_max * np.abs(v_max) ** 2 + 2 * phi_max + z)\n    B = (phi_max + 2 * z) * phi_max\n    C = -(phi_max**2) * z\n    lamb = _find_largest_root(A, B, C)\n\n    is_valid = lamb > phi_max\n    lamb[~is_valid] = phi_max[~is_valid] + flooring_fn(0)\n    lamb = np.maximum(lamb, z)\n\n    for iter_idx in range(max_iter):\n        f = _fn(lamb, phi, v, z)\n        is_convergence = np.abs(f) <= flooring_fn(0)\n\n        if np.all(is_convergence):\n            break\n\n        df = _d_fn(lamb, phi, v, z)\n        mu = lamb - f / df\n        lamb = np.where(mu > phi_max, mu, (phi_max + lamb) / 2)\n\n    if iter_idx == max_iter - 1:\n        f = _fn(lamb, phi, v, z)\n        is_convergence = np.abs(f) <= flooring_fn(0)\n\n        if not np.all(is_convergence):\n            warnings.warn(\n                f\"Newton-Raphson method did not converge in {max_iter} iterations.\", UserWarning\n            )\n\n 
   if normalization:\n        lamb = lamb * phi_max_original\n\n    return lamb\n\n\ndef _find_largest_root(A: np.ndarray, B: np.ndarray, C: np.ndarray) -> np.ndarray:\n    r\"\"\"Find the largest real root of each of the following cubic equations:\n\n    .. math::\n\n        x^{3} + Ax^{2} + Bx + C = 0.\n\n    Args:\n        A (numpy.ndarray): Coefficients of quadratic terms with shape of (\\*).\n        B (numpy.ndarray): Coefficients of linear terms with shape of (\\*).\n        C (numpy.ndarray): Coefficients of constant terms with shape of (\\*).\n\n    Returns:\n        numpy.ndarray of the largest real roots with shape of (\\*).\n\n    .. note::\n\n        :math:`x^{3} + Ax^{2} + Bx + C = 0` can be transformed into\n        :math:`t^{3} + pt + q = 0` by :math:`t=x+\\frac{A}{3}`.\n        When :math:`p<0` and :math:`\\frac{q^{2}}{4}+\\frac{p^{3}}{27}\\leq 0`,\n        there exist three real solutions: :math:`t=u-\\frac{p}{3u}`,\n        :math:`t=u\\omega-\\frac{p\\omega^{*}}{3u}`, and\n        :math:`t=u\\omega^{*}-\\frac{p\\omega}{3u}`, where\n\n        .. 
math::\n\n            u\n            &=\\sqrt[3]{-\\frac{q}{2}+\\sqrt{\\frac{q^{2}}{4} + \\frac{p^{3}}{27}}}, \\\\\n            \\omega\n            &=\\frac{-1+j\\sqrt{3}}{2}.\n\n        When :math:`p<0` and :math:`\\frac{q^{2}}{4}+\\frac{p^{3}}{27}>0`,\n        :math:`t=u-p/(3u)` is a unique real solution.\n        When :math:`p>0`, :math:`t=u-p/(3u)` is a unique real solution.\n        Otherwise (when :math:`p=0`), :math:`t=\\sqrt[3]{-q}` is a unique real solution.\n\n    \"\"\"\n    P = -(A**2) / 3 + B\n    Q = (2 * A**3) / 27 - (A * B) / 3 + C\n\n    omega = (-1 + 1j * np.sqrt(3)) / 2\n    omega_conj = (-1 - 1j * np.sqrt(3)) / 2\n\n    discriminant = (Q / 2) ** 2 + (P / 3) ** 3\n    discriminant = discriminant.astype(np.complex128)\n    U = cbrt(-Q / 2 + np.sqrt(discriminant))\n    # When U = 0, P is always 0 in real coefficients cases.\n    is_singular = U == 0\n    U = np.where(is_singular, 1, U)\n    V = -P / (3 * U)\n\n    X1 = U + V\n    X1 = np.where(is_singular, cbrt(-Q), X1)\n    X2 = np.real(U * omega + V * omega_conj)\n    X3 = np.real(U * omega_conj + V * omega)\n\n    roots = np.stack([X1, X2, X3], axis=-1)\n    roots = np.real(roots)\n\n    is_monotonic = P >= 0\n    is_unique = np.array([True, False, False])\n\n    imaginary_mask = is_monotonic[..., np.newaxis] & ~is_unique\n    roots = np.where(imaginary_mask, -float(\"inf\"), roots)\n    imaginary_mask = ~is_monotonic[..., np.newaxis] & ~is_unique\n    is_positive = discriminant > 0\n    roots = np.where(imaginary_mask & is_positive[..., np.newaxis], -float(\"inf\"), roots)\n    root = np.max(roots, axis=-1)\n    root = root - A / 3\n\n    return root\n\n\ndef _fn(lamb: np.ndarray, phi: np.ndarray, v: np.ndarray, z: np.ndarray) -> np.ndarray:\n    r\"\"\"Compute values of :math:`f(\\lambda_{in})`, where\n\n    .. 
math::\n\n        f(\\lambda_{in})\n        = \\lambda_{in}^{2}\\sum_{n'}\n        \\frac{\\phi_{inn'}|\\tilde{v}_{inn'}|^{2}}{(\\lambda_{in}-\\phi_{inn'})^{2}}\n        - \\lambda_{in} + z_{in}\n\n    Args:\n        lamb (numpy.ndarray): Argument of :math:`f(\\lambda_{in})` with shape of (n_bins,).\n        phi (numpy.ndarray): Eigenvalues defined in LQPQM of shape (n_bins, n_sources).\n        v (numpy.ndarray): Linear term defined in LQPQM of shape (n_bins, n_sources).\n        z (numpy.ndarray): Constant term defined in LQPQM of shape (n_bins,).\n\n    Returns:\n        numpy.ndarray of values :math:`f(\\lambda_{in})` of shape (n_bins,).\n\n    \"\"\"\n    num = phi * np.abs(v) ** 2\n    denom = (lamb[..., np.newaxis] - phi) ** 2\n    f = lamb**2 * np.sum(num / denom, axis=-1) - lamb + z\n\n    return f\n\n\ndef _d_fn(\n    lamb: np.ndarray,\n    phi: np.ndarray,\n    v: np.ndarray,\n    z: Optional[np.ndarray] = None,\n):\n    r\"\"\"Compute values of :math:`f'(\\lambda_{in})`, where\n\n    .. math::\n\n        f'(\\lambda_{in})\n        = -2\\lambda_{in}\\sum_{n'}\n        \\frac{\\phi_{inn'}^{2}|\\tilde{v}_{inn'}|^{2}}{(\\lambda_{in}-\\phi_{inn'})^{3}}\n        - 1\n\n    Args:\n        lamb (numpy.ndarray): Argument of :math:`f'(\\lambda_{in})` with shape of (n_bins,).\n        phi (numpy.ndarray): Eigenvalues defined in LQPQM of shape (n_bins, n_sources).\n        v (numpy.ndarray): Linear term defined in LQPQM of shape (n_bins, n_sources).\n        z (numpy.ndarray, optional): Constant term defined in LQPQM of shape (n_bins,).\n            This argument is not used in this function.\n\n    Returns:\n        numpy.ndarray of values :math:`f'(\\lambda_{in})` of shape (n_bins,).\n\n    \"\"\"\n    num = (phi * np.abs(v)) ** 2\n    denom = (lamb[..., np.newaxis] - phi) ** 3\n    df = -2 * lamb * np.sum(num / denom, axis=-1) - 1\n\n    return df\n"
  },
  {
    "path": "ssspy/linalg/mean.py",
    "content": "import numpy as np\n\nfrom .eigh import eigh\n\n\ndef gmeanmh(A: np.ndarray, B: np.ndarray, type: int = 1) -> np.ndarray:\n    r\"\"\"Compute the geometric mean of complex Hermitian \\\n    (conjugate symmetric) or real symmetric matrices.\n\n    The geometric mean of :math:`\\boldsymbol{A}` and :math:`\\boldsymbol{B}`\n    is defined as follows [#bhatia2009positive]_:\n\n    .. math::\n        \\boldsymbol{A}\\#\\boldsymbol{B}\n        &= \\boldsymbol{A}^{1/2}\n        (\\boldsymbol{A}^{-1/2}\\boldsymbol{B}\\boldsymbol{A}^{-1/2})^{1/2}\n        \\boldsymbol{A}^{1/2} \\\\\n        &= \\boldsymbol{A}(\\boldsymbol{A}^{-1}\\boldsymbol{B})^{1/2} \\\\\n        &= (\\boldsymbol{A}\\boldsymbol{B}^{-1})^{1/2}\\boldsymbol{B}.\n\n    This is a solution of the following equation for\n    complex Hermitian or real symmetric matrices,\n    :math:`\\boldsymbol{A}`, :math:`\\boldsymbol{B}`, and :math:`\\boldsymbol{X}`:\n\n    .. math::\n        \\boldsymbol{X}\\boldsymbol{A}^{-1}\\boldsymbol{X} = \\boldsymbol{B}.\n\n    .. note::\n        In this toolkit, :math:`\\boldsymbol{A}\\#\\boldsymbol{B}` is computed by\n        :math:`\\boldsymbol{B}(\\boldsymbol{B}^{-1}\\boldsymbol{A})^{1/2}`\n        in terms of computational speed.\n        Note that :math:`\\boldsymbol{A}\\#\\boldsymbol{B}` is equal to\n        :math:`\\boldsymbol{B}\\#\\boldsymbol{A}`.\n        For comparison of computational time, see https://github.com/tky823/ssspy/issues/210.\n\n    .. 
note::\n        :math:`(\\boldsymbol{B}^{-1}\\boldsymbol{A})^{1/2}` is computed by\n        generalized eigendecomposition.\n        Let :math:`\\lambda` and :math:`\\boldsymbol{z}` be the eigenvalue and eigenvector of\n        the generalized eigenproblem :math:`\\boldsymbol{Az}=\\lambda\\boldsymbol{Bz}`.\n        Then, :math:`(\\boldsymbol{B}^{-1}\\boldsymbol{A})^{1/2}` is computed by\n        :math:`\\boldsymbol{Z}\\boldsymbol{\\Lambda}^{1/2}\\boldsymbol{Z}^{-1}`,\n        where the diagonal entries of :math:`\\boldsymbol{\\Lambda}` are the :math:`\\lambda` s\n        and the columns of :math:`\\boldsymbol{Z}` are the :math:`\\boldsymbol{z}` s.\n\n    Args:\n        A (numpy.ndarray):\n            A complex Hermitian matrix with shape of (\\*, n_channels, n_channels).\n        B (numpy.ndarray):\n            A complex Hermitian matrix with shape of (\\*, n_channels, n_channels).\n        type (int):\n            This value specifies the type of geometric mean.\n            Only ``1``, ``2``, and ``3`` are supported.\n\n            - When ``type=1``, return :math:`\\boldsymbol{A}\\#\\boldsymbol{B}`.\n            - When ``type=2``, return :math:`\\boldsymbol{A}^{-1}\\#\\boldsymbol{B}`.\n            - When ``type=3``, return :math:`\\boldsymbol{A}\\#\\boldsymbol{B}^{-1}`.\n\n    Returns:\n        Geometric mean of matrices with shape of (\\*, n_channels, n_channels).\n\n    .. [#bhatia2009positive] R. Bhatia,\n        \"Positive Definite Matrices,\"\n        Princeton University Press, 2009.\n    \"\"\"  # noqa: W605\n    lamb, Z = eigh(A, B, type=type)\n    lamb = np.sqrt(lamb)\n    Lamb = lamb[..., np.newaxis] * np.eye(Z.shape[-1])\n    ZLZ = Z @ Lamb @ np.linalg.inv(Z)\n\n    if type == 1:\n        BA = ZLZ\n        G = B @ BA\n    elif type == 2:\n        AB = ZLZ\n        G = np.linalg.inv(A) @ AB\n    elif type == 3:\n        BA = ZLZ\n        G = np.linalg.inv(B) @ BA\n    else:\n        raise ValueError(\"Invalid type={} is given.\".format(type))\n\n    return G\n"
  },
  {
    "path": "ssspy/linalg/polynomial.py",
    "content": "from typing import Optional\n\nimport numpy as np\nfrom numpy.linalg import LinAlgError\n\nfrom .cubic import cbrt\n\n\ndef solve_cubic(\n    A: np.ndarray,\n    B: np.ndarray,\n    C: np.ndarray,\n    D: Optional[np.ndarray] = None,\n    all: bool = True,\n) -> np.ndarray:\n    r\"\"\"Find roots of cubic equations.\n\n    Args:\n        A (numpy.ndarray): Coefficients of cubic or quadratic terms.\n        B (numpy.ndarray): Coefficients of quadratic or linear terms.\n        C (numpy.ndarray): Coefficients of linear or constant terms.\n        D (numpy.ndarray, optional): Constant terms.\n        all (bool): If ``all=True``, returns all roots. Otherwise, returns one of them.\n            Default: ``True``.\n\n    Returns:\n        numpy.ndarray: All roots of cuadratic equations of shape (3, \\*) if ``all=True``.\n            Otherwise, (\\*).\n\n    This function solves the following equations if ``D`` is given:\n\n    .. math::\n\n        Ax^{3} + Bx^{2} + Cx + D = 0.\n\n    If ``D`` is not given, solves\n\n    .. math::\n\n        x^{3} + Ax^{2} + Bx + C = 0.\n\n    \"\"\"\n    if D is None:\n        P = -(A**2) / 3 + B\n        Q = (2 * A**3) / 27 - (A * B) / 3 + C\n\n        X = _find_cubic_roots(P, Q)\n        x = X - A / 3\n\n        return x if all else x[0]\n    else:\n        if np.any(A == 0):\n            raise LinAlgError(\"Coefficients include zero.\")\n\n        return solve_cubic(B / A, C / A, D / A, all=all)\n\n\ndef _find_cubic_roots(P: np.ndarray, Q: np.ndarray) -> np.ndarray:\n    r\"\"\"Find roots of the following cubic equations:\n\n    .. 
math::\n\n        x^{3} + px + q = 0\n\n    Args:\n        P (np.ndarray): Coefficients :math:`p` of linear terms.\n        Q (np.ndarray): Coefficients :math:`q` of constant terms.\n\n    Returns:\n        numpy.ndarray of the three roots.\n        The shape is (3, \\*).\n\n    \"\"\"\n    P = P.astype(np.complex128)\n    Q = Q.astype(np.complex128)\n    omega = (-1 + 1j * np.sqrt(3)) / 2\n    omega_conj = (-1 - 1j * np.sqrt(3)) / 2\n\n    discriminant = (Q / 2) ** 2 + (P / 3) ** 3\n\n    U = cbrt(-Q / 2 + np.sqrt(discriminant))\n    # U can be 0 when P = 0, so the P = 0 case is handled separately.\n    is_singular = P == 0\n    U = np.where(is_singular, 1, U)\n    V = -P / (3 * U)\n\n    X1 = U + V\n    X1 = np.where(is_singular, cbrt(-Q), X1)\n    X2 = U * omega + V * omega_conj\n    X2 = np.where(is_singular, X1 * omega, X2)\n    X3 = U * omega_conj + V * omega\n    X3 = np.where(is_singular, X1 * omega_conj, X3)\n\n    return np.stack([X1, X2, X3], axis=0)\n"
  },
  {
    "path": "ssspy/linalg/prox.py",
    "content": "import numpy as np\n\n__all__ = [\"l21\", \"neg_log\", \"neg_logdet\"]\n\n\ndef l1(x, step_size: float = 1) -> np.ndarray:\n    norm = np.abs(x)\n\n    # to suppress warning RuntimeWarning\n    norm = np.where(norm < step_size, step_size, norm)\n\n    return np.maximum(1 - step_size / norm, 0) * x\n\n\ndef l21(x: np.ndarray, step_size: float = 1, axis1: int = -2, axis2: int = -1):\n    r\"\"\"Proximal operator of L21 norm.\n\n    Args:\n        x (numpy.ndarray):\n            Input tensor.\n        step_size (float):\n            Step size parameter.\n\n    Returns:\n        numpy.ndarray:\n            Output tensor. The shape is same as input.\n    \"\"\"\n    norm = np.linalg.norm(x, axis=axis2, keepdims=True)\n\n    # to suppress warning RuntimeWarning\n    norm = np.where(norm < step_size, step_size, norm)\n\n    return np.maximum(1 - step_size / norm, 0) * x\n\n\ndef neg_log(x: np.ndarray, step_size: float = 1):\n    r\"\"\"Proximal operator of negative logarithm function.\n\n    Proximal operator of :math:`-\\log(x)` is defined as follows:\n\n    .. math::\n        \\mathrm{prox}_{-\\mu\\log}(x)\n        = \\frac{x + \\sqrt{x^{2} + 4\\mu}}{2}\n\n    Args:\n        x (np.ndarray):\n            Shape is (n_bins, n_sources, n_channels).\n        step_size (float):\n            Step size parameter. Default: 1.\n\n    Returns:\n        np.ndarray:\n            Proximal operator of negative logarithm function.\n    \"\"\"\n    assert np.all(x >= 0)\n\n    output = (x + np.sqrt(x**2 + 4 * step_size)) / 2\n\n    return output\n\n\ndef neg_logdet(X: np.ndarray, step_size=1):\n    r\"\"\"Proximal operator of negative log-determinant.\n\n    :math:`X\\in\\mathbb{C}^{N\\times M}`\n\n    .. 
math::\n        \\mathrm{prox}_{-\\mu\\log\\det}(\\boldsymbol{X})\n        &= \\boldsymbol{U}\\tilde{\\boldsymbol{\\Sigma}}\\boldsymbol{V}^{\\mathsf{H}} \\\\\n        \\tilde{\\boldsymbol{\\Sigma}}\n        &= \\mathrm{diag}(\\mathrm{prox}_{-\\mu\\log}(\\sigma_{1}),\n        \\ldots,\\mathrm{prox}_{-\\mu\\log}(\\sigma_{M}))\n\n    Args:\n        X (np.ndarray):\n            Shape is (n_bins, n_sources, n_channels).\n        step_size (float):\n            Step size parameter. Default: 1.\n\n    Returns:\n        np.ndarray:\n            Proximal operator of negative log-determinant applied to ``X``.\n    \"\"\"\n    n_channels = X.shape[-1]\n\n    U, Sigma, V = np.linalg.svd(X)\n    Sigma = neg_log(Sigma, step_size=step_size)\n    Sigma = Sigma[..., np.newaxis] * np.eye(n_channels)\n    USV = U @ Sigma @ V\n\n    return USV\n"
  },
  {
    "path": "ssspy/linalg/quadratic.py",
    "content": "import numpy as np\n\n\ndef quadratic(X: np.ndarray, A: np.ndarray) -> np.ndarray:\n    r\"\"\"Compute values of quadratic forms.\n\n    Args:\n        X (np.ndarray):\n            Input vectors with shape of (\\*, n_channels).\n        A (np.ndarray):\n            Input matrices with shape of (\\*, n_channels, n_channels).\n\n    Returns:\n        Computed values of quadratic forms.\n        The shape is (\\*,).\n    \"\"\"\n    if np.iscomplexobj(X):\n        X_Hermite = X.conj()\n    else:\n        X_Hermite = X\n\n    Y = X_Hermite[..., np.newaxis, :] @ A @ X[..., np.newaxis]\n    Y = Y[..., 0, 0]\n\n    return Y\n"
  },
  {
    "path": "ssspy/linalg/sqrtm.py",
    "content": "from typing import Callable, Optional\n\nimport numpy as np\n\nfrom .eigh import eigh\n\n\ndef sqrtmh(X: np.ndarray) -> np.ndarray:\n    r\"\"\"Compute square root of a positive semidefinite Hermitian or symmetric matrix.\n\n    Args:\n        X (numpy.ndarray):\n            A complex Hermitian or symmetric matrix with shape of (\\*, n_channels, n_channels).\n\n    Returns:\n        numpy.ndarray of square root. The shape is same as that of input.\n    \"\"\"\n    Lamb, P = eigh(X)\n\n    P_Hermite = P.swapaxes(-2, -1)\n\n    if np.iscomplexobj(X):\n        P_Hermite = P_Hermite.conj()\n\n    Lamb = np.sqrt(Lamb)[..., np.newaxis] * np.eye(Lamb.shape[-1])\n\n    return P @ Lamb @ P_Hermite\n\n\ndef invsqrtmh(\n    X: np.ndarray,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,\n) -> np.ndarray:\n    r\"\"\"Compute inversion of square root for a positive definite Hermitian or symmetric matrix.\n\n    Args:\n        X (numpy.ndarray):\n            A complex Hermitian matrix with shape of (\\*, n_channels, n_channels).\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to receive and return the same shape as that of X.\n            By default, the identity function (``lambda x: x``) is used.\n\n    Returns:\n        numpy.ndarray of inversion of square root. The shape is same as that of input.\n    \"\"\"\n\n    def _identity(x):\n        return x\n\n    if flooring_fn is None:\n        flooring_fn = _identity\n\n    Lamb, P = eigh(X)\n\n    P_Hermite = P.swapaxes(-2, -1)\n\n    if np.iscomplexobj(X):\n        P_Hermite = P_Hermite.conj()\n\n    Lamb = 1 / flooring_fn(np.sqrt(Lamb))\n    Lamb = Lamb[..., np.newaxis] * np.eye(Lamb.shape[-1])\n\n    return P @ Lamb @ P_Hermite\n"
  },
  {
    "path": "ssspy/special/__init__.py",
    "content": "from .flooring import add_flooring, identity, max_flooring\nfrom .logsumexp import logsumexp\nfrom .psd import to_psd\nfrom .softmax import softmax\n\n__all__ = [\"add_flooring\", \"max_flooring\", \"identity\", \"to_psd\", \"logsumexp\", \"softmax\"]\n"
  },
  {
    "path": "ssspy/special/flooring.py",
    "content": "import numpy as np\n\nEPS = 1e-10\n\n\ndef identity(input: np.ndarray) -> np.ndarray:\n    r\"\"\"Identity function.\"\"\"\n    return input\n\n\ndef max_flooring(input: np.ndarray, eps: float = EPS) -> np.ndarray:\n    r\"\"\"Max flooring operation.\"\"\"\n    return np.maximum(input, eps)\n\n\ndef add_flooring(input: np.ndarray, eps: float = EPS) -> np.ndarray:\n    r\"\"\"Add flooring operation.\"\"\"\n    return input + eps\n"
  },
  {
    "path": "ssspy/special/logsumexp.py",
    "content": "import numpy as np\n\n\ndef logsumexp(X: np.ndarray, axis: int = None, keepdims: bool = False) -> np.ndarray:\n    r\"\"\"Compute log-sum-exp values.\n\n    Args:\n        X (np.ndarray):\n            Elements to compute log-sum-exp.\n        axis (int or tuple[int], optional):\n            Axis or axes over which the sum is performed.\n            Default: ``None``.\n        keepdims (bool):\n            If ``True`` is given, ``axis`` dimension(s) is reduced.\n            Default: ``False``.\n\n    Returns:\n        np.ndarray of log-sum-exp values.\n\n    Examples:\n\n        .. code-block:: python\n\n            >>> import numpy as np\n\n            >>> X = np.array([[1, 2, 3], [4, 5, 6]])\n            >>> logsumexp(X, axis=0)\n            array([4.04858735, 5.04858735, 6.04858735])\n            >>> logsumexp(X, axis=1)\n            array([3.40760596, 6.40760596])\n    \"\"\"\n    vmax = np.max(X, axis=axis, keepdims=True)\n    exp = np.exp(X - vmax)\n    sum_exp = exp.sum(axis=axis, keepdims=True)\n    v = np.log(sum_exp) + vmax\n\n    if not keepdims:\n        v = np.squeeze(v, axis=axis)\n\n    return v\n"
  },
  {
    "path": "ssspy/special/psd.py",
    "content": "import functools\nfrom typing import Callable, Optional\n\nimport numpy as np\n\nfrom ..special.flooring import identity, max_flooring\n\nEPS = 1e-10\n\n\ndef to_psd(\n    X: np.ndarray,\n    axis1: int = -2,\n    axis2: int = -1,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(\n        max_flooring, eps=EPS\n    ),\n) -> np.ndarray:\n    r\"\"\"Ensure matrix to be positive semidefinite.\n\n    Args:\n        X (np.ndarray):\n            A complex Hermitian matrix.\n        axis1 (int):\n            Axis to be used as first axis of 2D sub-arrays.\n        axis2 (int):\n            Axis to be used as second axis of 2D sub-arrays.\n        flooring_fn (callable, optional):\n            A flooring function for numerical stability.\n            This function is expected to return the same shape tensor as the input.\n            If you explicitly set ``flooring_fn=None``,\n            the identity function (``lambda x: x``) is used.\n            Default: ``functools.partial(max_flooring, eps=1e-10)``.\n\n    Returns:\n        Positive semidefinite matrix.\n    \"\"\"\n    if flooring_fn is None:\n        flooring_fn = identity\n\n    shape = X.shape\n    n_dims = len(shape)\n\n    axis1 = n_dims + axis1 if axis1 < 0 else axis1\n    axis2 = n_dims + axis2 if axis2 < 0 else axis2\n\n    assert axis1 == n_dims - 2 and axis2 == n_dims - 1, \"axis1 == -2 and axis2 == -1\"\n\n    if np.iscomplexobj(X):\n        X = (X + X.swapaxes(axis1, axis2).conj()) / 2\n    else:\n        X = (X + X.swapaxes(axis1, axis2)) / 2\n\n    Lamb, P = np.linalg.eigh(X)\n\n    P_Hermite = P.swapaxes(-2, -1)\n\n    if np.iscomplexobj(X):\n        P_Hermite = P_Hermite.conj()\n\n    Lamb = flooring_fn(Lamb)\n    Lamb = Lamb[..., np.newaxis] * np.eye(Lamb.shape[-1])\n\n    X = P @ Lamb @ P_Hermite\n\n    if np.iscomplexobj(X):\n        X = (X + X.swapaxes(axis1, axis2).conj()) / 2\n    else:\n        X = (X + X.swapaxes(axis1, axis2)) / 2\n\n    
return X\n"
  },
  {
    "path": "ssspy/special/softmax.py",
    "content": "import numpy as np\n\n\ndef softmax(X: np.ndarray, axis: int = None) -> np.ndarray:\n    r\"\"\"Compute softmax values.\n\n    The maximum along ``axis`` is subtracted before exponentiation\n    for numerical stability.\n\n    Args:\n        X (np.ndarray):\n            Input array to which softmax is applied.\n        axis (int or tuple[int], optional):\n            Axis or axes over which the sum is performed.\n            If ``None``, softmax is computed over all elements.\n            Default: ``None``.\n\n    Returns:\n        np.ndarray of softmax values with the same shape as ``X``.\n\n    Examples:\n\n        .. code-block:: python\n\n            >>> import numpy as np\n\n            >>> X = np.array([[1, 2, 3], [4, 5, 6]])\n            >>> softmax(X, axis=0)\n            array([[0.04742587, 0.04742587, 0.04742587],\n                   [0.95257413, 0.95257413, 0.95257413]])\n            >>> softmax(X, axis=1)\n            array([[0.09003057, 0.24472847, 0.66524096],\n                   [0.09003057, 0.24472847, 0.66524096]])\n    \"\"\"\n    vmax = np.max(X, axis=axis, keepdims=True)\n    Y = X - vmax\n    exp = np.exp(Y)\n    v = exp / np.sum(exp, axis=axis, keepdims=True)\n\n    return v\n"
  },
  {
    "path": "ssspy/transform/__init__.py",
    "content": "from .pca import pca\nfrom .whiten import whiten\n\n__all__ = [\"pca\", \"whiten\"]\n"
  },
  {
    "path": "ssspy/transform/pca.py",
    "content": "import numpy as np\n\n\ndef pca(input: np.ndarray, ascend: bool = True) -> np.ndarray:\n    r\"\"\"Apply principal component analysis (PCA).\n\n    Args:\n        input (numpy.ndarray):\n            Input tensor.\n        ascend (bool):\n            If ``ascend=True``, first channel corresponds to first principal component. \\\n            Otherwise, last channel corresponds to first principal component.\n\n    Returns:\n        numpy.ndarray:\n            Output tensor. The type (real or complex) and shape are the same as input.\n\n    .. note::\n        - If ``input`` is 2D real tensor, it is regarded as (n_channels, n_samples).\n        - If ``input`` is 3D complex tensor, it is regarded as (n_channels, n_bins, n_frames).\n        - If ``input`` is 3D real tensor, it is regarded as (batch_size, n_channels, n_samples).\n        - If ``input`` is 4D complex tensor, it is regarded as \\\n          (batch_size, n_channels, n_bins, n_frames).\n\n    Examples:\n        .. code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.transform import pca\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> n_sources = n_channels\n            >>> rng = np.random.default_rng(42)\n\n            >>> spectrogram_mix = \\\n            ...     rng.standard_normal((n_channels, n_bins, n_frames)) \\\n            ...     
+ 1j * rng.standard_normal((n_channels, n_bins, n_frames))\n            >>> spectrogram_mix_ortho = pca(spectrogram_mix)\n            >>> spectrogram_mix_ortho.shape\n            (2, 2049, 128)\n    \"\"\"\n    if input.ndim == 2:\n        if np.iscomplexobj(input):\n            raise ValueError(\"Real tensor is expected, but given complex tensor.\")\n        else:\n            X = input.transpose(1, 0)\n            covariance = np.mean(X[:, :, np.newaxis] * X[:, np.newaxis, :], axis=0)\n            _, V = np.linalg.eigh(covariance)\n\n            if ascend:\n                V = V[..., ::-1]\n\n            Y = X @ V\n            output = Y.transpose(1, 0)\n    elif input.ndim == 3:\n        if np.iscomplexobj(input):\n            X = input.transpose(1, 2, 0)\n            covariance = np.mean(X[:, :, :, np.newaxis] * X[:, :, np.newaxis, :].conj(), axis=1)\n            _, V = np.linalg.eigh(covariance)\n\n            if ascend:\n                V = V[..., ::-1]\n\n            Y = X @ V.conj()\n            output = Y.transpose(2, 0, 1)\n        else:\n            X = input.transpose(0, 2, 1)\n            covariance = np.mean(X[:, :, :, np.newaxis] * X[:, :, np.newaxis, :], axis=1)\n            _, V = np.linalg.eigh(covariance)\n\n            if ascend:\n                V = V[..., ::-1]\n\n            Y = X @ V\n            output = Y.transpose(0, 2, 1)\n    elif input.ndim == 4:\n        if np.iscomplexobj(input):\n            X = input.transpose(0, 2, 3, 1)\n            covariance = np.mean(\n                X[:, :, :, :, np.newaxis] * X[:, :, :, np.newaxis, :].conj(), axis=2\n            )\n            _, V = np.linalg.eigh(covariance)\n\n            if ascend:\n                V = V[..., ::-1]\n\n            Y = X @ V.conj()\n            output = Y.transpose(0, 3, 1, 2)\n        else:\n            raise ValueError(\"Complex tensor is expected, but given real tensor.\")\n    else:\n        raise ValueError(\n            \"The dimension of input is expected 3 or 4, 
but given {}.\".format(input.ndim)\n        )\n\n    return output\n"
  },
  {
    "path": "ssspy/transform/whiten.py",
    "content": "import numpy as np\n\n\ndef whiten(input: np.ndarray) -> np.ndarray:\n    r\"\"\"Apply whitening (a.k.a sphering).\n\n    Args:\n        input (numpy.ndarray):\n            Input tensor to be whitened.\n\n    Returns:\n        numpy.ndarray of Whitened tensor.\n        The type (real or complex) and shape are same as input.\n\n    .. note::\n        - If ``input`` is 2D real tensor, it is regarded as (n_channels, n_samples).\n        - If ``input`` is 3D complex tensor, it is regarded as (n_channels, n_bins, n_frames).\n        - If ``input`` is 3D real tensor, it is regarded as (batch_size, n_channels, n_samples).\n        - If ``input`` is 4D complex tensor, it is regarded as\n          (batch_size, n_channels, n_bins, n_frames).\n\n    Examples:\n        .. code-block:: python\n\n            >>> import numpy as np\n            >>> from ssspy.transform import whiten\n\n            >>> n_channels, n_bins, n_frames = 2, 2049, 128\n            >>> n_sources = n_channels\n            >>> rng = np.random.default_rng(42)\n\n            >>> spectrogram_mix = \\\n            ...     rng.standard_normal((n_channels, n_bins, n_frames)) \\\n            ...     
+ 1j * rng.standard_normal((n_channels, n_bins, n_frames))\n            >>> spectrogram_mix_whitened = whiten(spectrogram_mix)\n            >>> spectrogram_mix_whitened.shape\n            (2, 2049, 128)\n    \"\"\"\n    if input.ndim == 2:\n        if np.iscomplexobj(input):\n            raise ValueError(\"Real tensor is expected, but given complex tensor.\")\n        else:\n            n_channels = input.shape[0]\n            X = input.transpose(1, 0)\n            covariance = np.mean(X[:, :, np.newaxis] * X[:, np.newaxis, :], axis=0)\n            W, V = np.linalg.eigh(covariance)\n            D_diag = 1 / np.sqrt(W)\n            D_diag = np.diag(D_diag)\n            V_transpose = V.transpose(1, 0)\n            output = D_diag @ V_transpose @ X.transpose(1, 0)\n    elif input.ndim == 3:\n        if np.iscomplexobj(input):\n            n_channels = input.shape[0]\n            X = input.transpose(1, 2, 0)\n            covariance = np.mean(X[:, :, :, np.newaxis] * X[:, :, np.newaxis, :].conj(), axis=1)\n            W, V = np.linalg.eigh(covariance)\n            D_diag = 1 / np.sqrt(W)\n            D_diag = D_diag[:, :, np.newaxis]\n            D_diag = D_diag * np.eye(n_channels)\n            V_Hermite = V.transpose(0, 2, 1).conj()\n            Y = D_diag @ V_Hermite @ X.transpose(0, 2, 1)\n            output = Y.transpose(1, 0, 2)\n        else:\n            n_channels = input.shape[1]\n            X = input.transpose(0, 2, 1)\n            covariance = np.mean(X[:, :, :, np.newaxis] * X[:, :, np.newaxis, :], axis=1)\n            W, V = np.linalg.eigh(covariance)\n            D_diag = 1 / np.sqrt(W)\n            D_diag = D_diag[:, :, np.newaxis]\n            D_diag = D_diag * np.eye(n_channels)\n            V_transpose = V.transpose(0, 2, 1)\n            output = D_diag @ V_transpose @ X.transpose(0, 2, 1)\n    elif input.ndim == 4:\n        if np.iscomplexobj(input):\n            n_channels = input.shape[1]\n            X = input.transpose(0, 2, 3, 1)\n            
covariance = np.mean(\n                X[:, :, :, :, np.newaxis] * X[:, :, :, np.newaxis, :].conj(), axis=2\n            )\n            W, V = np.linalg.eigh(covariance)\n            D_diag = 1 / np.sqrt(W)\n            D_diag = D_diag[:, :, :, np.newaxis]\n            D_diag = D_diag * np.eye(n_channels)\n            V_Hermite = V.transpose(0, 1, 3, 2).conj()\n            Y = D_diag @ V_Hermite @ X.transpose(0, 1, 3, 2)\n            output = Y.transpose(0, 2, 1, 3)\n        else:\n            raise ValueError(\"Complex tensor is expected, but given real tensor.\")\n    else:\n        raise ValueError(\n            \"The dimension of input is expected 2, 3, or 4, but given {}.\".format(input.ndim)\n        )\n\n    return output\n"
  },
  {
    "path": "ssspy/utils/__init__.py",
    "content": ""
  },
  {
    "path": "ssspy/utils/dataset/__init__.py",
    "content": "from typing import Tuple\n\nimport numpy as np\n\nfrom .mird import download as download_mird\nfrom .sisec2010 import download as download_sisec2010\n\n__all__ = [\"download_sample_speech_data\"]\n\nsisec2010_tags = [\"dev1_female3\", \"dev1_female4\"]\n\n\ndef download_sample_speech_data(\n    sisec2010_root: str = \".data/SiSEC2010\",\n    mird_root: str = \".data/MIRD\",\n    n_sources: int = 3,\n    sisec2010_tag: str = \"dev1_female3\",\n    max_duration: float = 10,\n    reverb_duration: float = 0.16,\n    conv: bool = True,\n) -> Tuple[np.ndarray, int]:\n    r\"\"\"Download sample speech data to test separation methods.\n\n    This function returns source images of sample speech data.\n\n    Args:\n        sisec2010_root (str):\n            Path to save SiSEC2010 dataset. Default: \".data/SiSEC2010\".\n        mird_root (str):\n            Path to save MIRD dataset. Default: \".data/MIRD\".\n        n_sources (int):\n            Number of sources included in sample data. Default: ``3``.\n        sisec2010_tag (str):\n            Tag of SiSEC 2010 data.\n            Choose ``dev1_female3`` or ``dev1_female4``.\n            Default: ``dev1_female3``.\n        max_duration (float):\n            Maximum duration in seconds. Each source is truncated to\n            ``int(16000 * max_duration)`` samples. Default: ``10``.\n        reverb_duration (float):\n            Duration of reverberation in MIRD, used only when ``conv=True``.\n            Choose ``0.16``, ``0.36``, ``0.61``. Default: ``0.16``.\n        conv (bool):\n            Convolutive mixture or not. 
If ``False``, sources are mixed instantaneously\n            by a random mixing matrix. Default: ``True``.\n\n    Returns:\n        Tuple of source images and sampling rate.\n        The source images are numpy.ndarray with shape of (n_channels, n_sources, n_samples).\n    \"\"\"\n    assert sisec2010_tag in sisec2010_tags, \"Choose sisec2010_tag from {}\".format(sisec2010_tags)\n    sample_rate = 16000  # Only 16khz is supported.\n    max_samples = int(sample_rate * max_duration)\n\n    sisec2010_npz_path = download_sisec2010(\n        root=sisec2010_root, n_sources=n_sources, tag=sisec2010_tag\n    )\n    sisec2010_npz = np.load(sisec2010_npz_path)\n\n    assert sample_rate == sisec2010_npz[\"sample_rate\"].item(), \"Invalid sampling rate is detected.\"\n\n    if conv:\n        mird_npz_path = download_mird(\n            root=mird_root, n_sources=n_sources, reverb_duration=reverb_duration\n        )\n        mird_npz = np.load(mird_npz_path)\n\n        assert sample_rate == mird_npz[\"sample_rate\"].item(), \"Invalid sampling rate is detected.\"\n\n        waveform_src_img = []\n\n        for src_idx in range(n_sources):\n            key = \"src_{}\".format(src_idx + 1)\n            waveform_src = sisec2010_npz[key][:max_samples]\n            n_samples = len(waveform_src)\n            _waveform_src_img = []\n\n            for waveform_rir in mird_npz[key]:\n                waveform_conv = np.convolve(waveform_src, waveform_rir)[:n_samples]\n                _waveform_src_img.append(waveform_conv)\n\n            _waveform_src_img = np.stack(_waveform_src_img, axis=0)  # (n_channels, n_samples)\n            waveform_src_img.append(_waveform_src_img)\n\n        waveform_src_img = np.stack(waveform_src_img, axis=1)  # (n_channels, n_sources, n_samples)\n    else:\n        waveform_src_img = []\n\n        rng = np.random.default_rng(seed=42)\n        mixing = rng.standard_normal((n_sources, n_sources))\n\n        for src_idx in range(n_sources):\n            key = \"src_{}\".format(src_idx + 1)\n            _mixing = mixing[:, src_idx]\n            
waveform_src = sisec2010_npz[key][:max_samples]\n            _waveform_src_img = _mixing[:, np.newaxis] * waveform_src\n            waveform_src_img.append(_waveform_src_img)\n\n        waveform_src_img = np.stack(waveform_src_img, axis=1)  # (n_channels, n_sources, n_samples)\n\n    return waveform_src_img, sample_rate\n"
  },
  {
    "path": "ssspy/utils/dataset/mird.py",
    "content": "import os\nimport shutil\nimport urllib.request\n\nimport numpy as np\n\nreverb_durations = [0.16, 0.36, 0.61]\n\n\ndef download(root: str = \".data/MIRD\", n_sources: int = 3, reverb_duration: float = 0.16) -> str:\n    assert reverb_duration in reverb_durations, \"reverb_duration should be chosen from {}.\".format(\n        reverb_durations\n    )\n\n    filename = (\n        \"Impulse_response_Acoustic_Lab_Bar-Ilan_University__\"\n        \"Reverberation_{reverb_duration:.3f}s__3-3-3-8-3-3-3.zip\"\n    )\n    filename = filename.format(reverb_duration=reverb_duration)\n    url = (\n        \"https://www.iks.rwth-aachen.de/fileadmin/user_upload/downloads/\"\n        \"forschung/tools-downloads/{filename}\"\n    )\n    url = url.format(filename=filename)\n    zip_path = os.path.join(root, filename)\n\n    degrees = [30, 345, 0, 60, 315]\n    channels = [3, 4, 2, 5, 1, 6, 0, 7]\n    sample_rate = 16000\n    duration = reverb_duration\n\n    degrees = degrees[:n_sources]\n    channels = channels[:n_sources]\n\n    n_channels = len(channels)\n    n_samples = int(sample_rate * duration)\n\n    template_rir_name = (\n        \"Impulse_response_Acoustic_Lab_Bar-Ilan_University_\"\n        \"(Reverberation_{:.3f}s)_3-3-3-8-3-3-3_1m_{:03d}.mat\"\n    )\n\n    os.makedirs(root, exist_ok=True)\n\n    if not os.path.exists(zip_path):\n        urllib.request.urlretrieve(url, zip_path)\n\n    rir_path = os.path.join(root, template_rir_name.format(reverb_duration, 0))\n\n    if not os.path.exists(rir_path):\n        shutil.unpack_archive(zip_path, root)\n\n    npz_path = os.path.join(root, \"MIRD-{}ch.npz\".format(n_channels))\n\n    assert n_channels == n_sources, \"Mixing system should be determined.\"\n\n    if not os.path.exists(npz_path):\n        rirs = {}\n\n        for src_idx, degree in enumerate(degrees):\n            rir_path = os.path.join(root, template_rir_name.format(duration, degree))\n            rir = resample_mird_rir(rir_path, 
sample_rate_out=sample_rate)\n            rirs[\"src_{}\".format(src_idx + 1)] = rir[channels, :n_samples]\n\n        np.savez(\n            npz_path, sample_rate=sample_rate, n_sources=n_sources, n_channels=n_channels, **rirs\n        )\n\n    return npz_path\n\n\ndef resample_mird_rir(rir_path: str, sample_rate_out: int) -> np.ndarray:\n    import scipy.signal as ss\n    from scipy.io import loadmat\n\n    sample_rate_in = 48000\n    rir_mat = loadmat(rir_path)\n    rir = rir_mat[\"impulse_response\"]\n\n    rir_resampled = ss.resample_poly(rir, sample_rate_out, sample_rate_in, axis=0)\n\n    return rir_resampled.T\n"
  },
  {
    "path": "ssspy/utils/dataset/sisec2010.py",
    "content": "import os\nimport shutil\nimport urllib.request\n\nimport numpy as np\n\nfrom ...io import wavread\n\n\ndef download(root: str = \".data/SiSEC2010\", n_sources: int = 3, tag: str = \"dev1_female3\") -> str:\n    filename = \"dev1.zip\"\n    url = \"http://www.irisa.fr/metiss/SiSEC10/underdetermined/{}\".format(filename)\n    zip_path = os.path.join(root, filename)\n\n    os.makedirs(root, exist_ok=True)\n\n    if not os.path.exists(zip_path):\n        urllib.request.urlretrieve(url, zip_path)\n\n    if not os.path.exists(os.path.join(root, \"{}_inst_matrix.mat\".format(tag))):\n        shutil.unpack_archive(zip_path, root)\n\n    source_paths = []\n\n    for src_idx in range(n_sources):\n        source_path = os.path.join(root, \"{}_src_{}.wav\".format(tag, src_idx + 1))\n        source_paths.append(source_path)\n\n    channels = [3, 4, 2, 5]\n    sample_rate = 16000\n\n    source_paths = source_paths[:n_sources]\n    channels = channels[:n_sources]\n\n    n_channels = len(channels)\n    npz_path = os.path.join(root, \"SiSEC2010-{}ch.npz\".format(n_channels))\n\n    assert n_channels == n_sources, \"Mixing system should be determined.\"\n\n    if not os.path.exists(npz_path):\n        dry_sources = {}\n\n        for src_idx, source_path in enumerate(source_paths):\n            data, _ = wavread(source_path, return_2d=False)\n            dry_sources[\"src_{}\".format(src_idx + 1)] = data\n\n        np.savez(\n            npz_path,\n            sample_rate=sample_rate,\n            n_sources=n_sources,\n            n_channels=n_channels,\n            **dry_sources,\n        )\n\n    return npz_path\n"
  },
  {
    "path": "ssspy/utils/flooring.py",
    "content": "from typing import Any, Callable, Optional, Union\n\nimport numpy as np\n\nfrom ..special.flooring import identity\n\n\ndef choose_flooring_fn(\n    flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = \"self\",\n    method: Optional[Any] = None,\n) -> Callable[[np.ndarray], np.ndarray]:\n    if flooring_fn is None:\n        assert method is None, \"method is given, but flooring function is not specified.\"\n\n        flooring_fn = identity\n    elif type(flooring_fn) is str and flooring_fn == \"self\":\n        if method is None or not hasattr(method, \"flooring_fn\"):\n            flooring_fn = identity\n        else:\n            flooring_fn = method.flooring_fn\n\n    assert callable(flooring_fn), \"flooring_fn should be callable.\"\n\n    return flooring_fn\n"
  },
  {
    "path": "ssspy/utils/select_pair.py",
    "content": "import itertools\nfrom typing import Iterable, Optional, Tuple\n\n\ndef sequential_pair_selector(\n    n_sources: int, stop: Optional[int] = None, step: int = 1, sort: bool = False\n) -> Iterable[Tuple[int, int]]:\n    r\"\"\"Select pairs of indices for pairwise updates.\n\n    Args:\n        n_sources (int):\n            Number of sources.\n        stop (int, optional):\n            Exclusive upper bound of the counter ``m`` iterated as ``range(0, stop, step)``;\n            ``m`` is wrapped modulo ``n_sources``. Default: ``n_sources``.\n        step (int):\n            This parameter determines step size.\n            For instance,\n            ``sequential_pair_selector(n_sources=6, stop=12, step=2, sort=False)``\n            yields ``0, 1``, ``2, 3``, ``4, 5``, ``0, 1``, ``2, 3``, ``4, 5``.\n            Default: ``1``.\n        sort (bool):\n            Sort pair to ensure :math:`m<n` if ``sort=True``.\n            Default: ``False``.\n\n    Yields:\n        Pair (tuple) of indices.\n\n    Examples:\n        .. code-block:: python\n\n            >>> for m, n in sequential_pair_selector(4):\n            ...     print(m, n)\n            0 1\n            1 2\n            2 3\n            3 0\n    \"\"\"\n    if stop is None:\n        stop = n_sources\n\n    for m in range(0, stop, step):\n        m, n = m % n_sources, (m + 1) % n_sources\n\n        if sort:\n            m, n = (n, m) if m > n else (m, n)\n\n        yield m, n\n\n\ndef combination_pair_selector(n_sources: int, sort: bool = False) -> Iterable[Tuple[int, int]]:\n    r\"\"\"Select pairs of indices for pairwise updates.\n\n    Args:\n        n_sources (int):\n            Number of sources.\n        sort (bool):\n            Sort pair to ensure :math:`m<n` if ``sort=True``.\n            Default: ``False``.\n\n    Yields:\n        Pair (tuple) of indices.\n\n    Examples:\n        .. code-block:: python\n\n            >>> for m, n in combination_pair_selector(4):\n            ...     
print(m, n)\n            0 1\n            0 2\n            0 3\n            1 2\n            1 3\n            2 3\n    \"\"\"\n    for m, n in itertools.combinations(range(n_sources), 2):\n        if sort:\n            m, n = (n, m) if m > n else (m, n)\n\n        yield m, n\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "# conftest.py is based on\n# https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option\n# and\n# https://docs.pytest.org/en/latest/deprecations.html#pytest-namespace\n\nimport pytest\n\n\ndef pytest_addoption(parser):\n    parser.addoption(\n        \"--run-redundant\", action=\"store_true\", default=False, help=\"Run redandant tests.\"\n    )\n\n\ndef pytest_configure():\n    pytest.run_redundant = False\n\n\ndef pytest_collection_modifyitems(config, items):\n    if config.getoption(\"--run-redundant\"):\n        pytest.run_redundant = True\n"
  },
  {
    "path": "tests/dummy/callback.py",
    "content": "def dummy_function(_) -> None:\n    pass\n\n\nclass DummyCallback:\n    def __init__(self) -> None:\n        pass\n\n    def __call__(self, _) -> None:\n        pass\n"
  },
  {
    "path": "tests/dummy/io.py",
    "content": "import os\nimport struct\n\nimport numpy as np\n\n\ndef save_invalid_wavfile(\n    path: str,\n    invalid_riff: bool = False,\n    invalid_ftype: bool = False,\n    invalid_fmt_chunk_marker: bool = False,\n    invalid_fmt_chunk_size: bool = False,\n    invalid_fmt: bool = False,\n    invalid_byte_rate: bool = False,\n    invalid_data_chunk_marker: bool = False,\n) -> None:\n    os.makedirs(os.path.dirname(path), exist_ok=True)\n\n    n_channels = 1\n    sample_rate = 16000\n    bits_per_sample = 16\n    duration = 5\n    byte_rate = (bits_per_sample * sample_rate * n_channels) // 8\n    block_align = byte_rate // sample_rate\n    total_file_size = byte_rate * duration + 44\n\n    rng = np.random.default_rng(42)\n    num_frames = sample_rate * duration\n    bytes_per_sample = block_align // n_channels\n    vmax = 2 ** (bits_per_sample - 1)\n\n    valid_file_size = 0\n\n    with open(path, mode=\"wb\") as f:\n        if invalid_riff:\n            data = b\"RIFX\"\n        else:\n            data = b\"RIFF\"\n\n        f.write(data)\n        valid_file_size += 4\n\n        data = struct.pack(\"<I\", total_file_size - 4 - 4)\n        f.write(data)\n        valid_file_size += 4\n\n        if invalid_ftype:\n            data = b\"wave\"\n        else:\n            data = b\"WAVE\"\n\n        f.write(data)\n        valid_file_size += 4\n\n        if invalid_fmt_chunk_marker:\n            data = b\"FMT \"\n        else:\n            data = b\"fmt \"\n\n        f.write(data)\n        valid_file_size += 4\n\n        if invalid_fmt_chunk_size:\n            data = struct.pack(\"<I\", 15)\n        else:\n            data = struct.pack(\"<I\", 16)\n\n        f.write(data)\n        valid_file_size += 4\n\n        if invalid_fmt:\n            data = struct.pack(\"<H\", 0)\n        else:\n            data = struct.pack(\"<H\", 1)\n\n        f.write(data)\n        valid_file_size += 2\n\n        if invalid_byte_rate:\n            data = struct.pack(\n                
\"<HIIHH\", n_channels, sample_rate, byte_rate + 1, block_align, bits_per_sample\n            )\n        else:\n            data = struct.pack(\n                \"<HIIHH\", n_channels, sample_rate, byte_rate, block_align, bits_per_sample\n            )\n\n        f.write(data)\n        valid_file_size += 2 + 4 + 4 + 2 + 2\n\n        if invalid_data_chunk_marker:\n            data = b\"DATA\"\n        else:\n            data = b\"data\"\n\n        f.write(data)\n        valid_file_size += 4\n\n        data_chunk_size = num_frames * n_channels * block_align\n        data = struct.pack(\"<I\", data_chunk_size)\n        f.write(data)\n        valid_file_size += 4\n\n        waveform = rng.integers(-vmax, vmax, size=(num_frames,), dtype=f\"<i{bytes_per_sample}\")\n        data = waveform.view(\"b\").data\n        f.write(data)\n        valid_file_size += byte_rate * duration\n\n        assert valid_file_size == total_file_size\n"
  },
  {
    "path": "tests/dummy/utils/dataset/__init__.py",
    "content": "import hashlib\nimport json\nimport os\nimport urllib.request\nimport warnings\nfrom typing import List, Optional, Tuple\n\nimport numpy as np\n\nfrom ssspy.utils.dataset import download_sample_speech_data as _download\n\n\ndef download_sample_speech_data(\n    sisec2010_root: str = \"./tests/.data/SiSEC2010\",\n    mird_root: str = \"./tests/.data/MIRD\",\n    n_sources: int = 3,\n    sisec2010_tag: str = \"dev1_female3\",\n    max_duration: float = 10,\n    conv: bool = True,\n    cache_dir: str = \".tests/.data/.cache\",\n) -> Tuple[np.ndarray, int]:\n    hash = hashlib.sha256(sisec2010_root.encode(\"utf-8\")).hexdigest()\n    hash += hashlib.sha256(mird_root.encode(\"utf-8\")).hexdigest()\n    hash += hashlib.sha256(str(n_sources).encode(\"utf-8\")).hexdigest()\n    hash += hashlib.sha256(sisec2010_tag.encode(\"utf-8\")).hexdigest()\n    hash += hashlib.sha256(str(max_duration).encode(\"utf-8\")).hexdigest()\n    hash += hashlib.sha256(str(conv).encode(\"utf-8\")).hexdigest()\n\n    # because concatenated hash is too long\n    hash = hashlib.sha256(hash.encode(\"utf-8\")).hexdigest()\n\n    npz_path = os.path.join(cache_dir, \"{}.npz\".format(hash))\n\n    if os.path.exists(npz_path):\n        npz = np.load(npz_path)\n        waveform_src_img, sample_rate = npz[\"waveform_src_img\"], npz[\"sample_rate\"]\n        sample_rate = sample_rate.item()\n    else:\n        waveform_src_img, sample_rate = _download(\n            sisec2010_root=sisec2010_root,\n            mird_root=mird_root,\n            n_sources=n_sources,\n            sisec2010_tag=sisec2010_tag,\n            max_duration=max_duration,\n            conv=conv,\n        )\n        os.makedirs(cache_dir, exist_ok=True)\n        np.savez(npz_path, waveform_src_img=waveform_src_img, sample_rate=sample_rate)\n\n    return waveform_src_img, sample_rate\n\n\ndef download_ssspy_data(path: str, filename: Optional[str] = None, branch: str = \"main\") -> None:\n    \"\"\"Download file from 
https://github.com/tky823/ssspy-data.\n\n    Args:\n        path (str): Path to file in https://github.com/tky823/ssspy-data.\n        filename (str, optional): File name to save data. If ``None``,\n            base name of ``path`` is used.\n        branch (str, optional): Branch name of https://github.com/tky823/ssspy-data.\n            Default: ``main``.\n\n    .. note::\n        Nothing is downloaded if ``filename`` already exists.\n\n    \"\"\"\n    url = f\"https://github.com/tky823/ssspy-data/raw/{branch}/{path}\"\n\n    if filename is None:\n        filename = os.path.basename(url)\n\n    root = os.path.dirname(filename)\n\n    if root:\n        os.makedirs(root, exist_ok=True)\n\n    if not os.path.exists(filename):\n        urllib.request.urlretrieve(url, filename)\n\n\ndef load_regression_data(root: str, filenames: Optional[List[str]] = None) -> Tuple:\n    \"\"\"Load regression data.\n\n    Args:\n        root (str): Root to save regression data, where url.json is placed.\n        filenames (list of str, optional): Filenames to download, matched against\n            the ``filename`` entries in url.json. If ``None``, every entry in url.json\n            is used and a ``UserWarning`` is emitted.\n\n    Returns:\n        tuple: Data loaded from ``filenames``, in the same order as ``filenames``.\n\n    \"\"\"\n    url_json_path = os.path.join(root, \"url.json\")\n\n    with open(url_json_path) as f:\n        urls = json.load(f)\n\n    if filenames is None:\n        warnings.warn(\"It is recommended to specify filenames to ensure order.\", UserWarning)\n\n        filenames = []\n\n        for file in urls[\"files\"]:\n            filename = file[\"filename\"]\n            filenames.append(filename)\n\n    npz = {}\n\n    for file in urls[\"files\"]:\n        filename = file[\"filename\"]\n        location = file[\"location\"]\n\n        if filename not in filenames:\n            continue\n\n        path = os.path.join(root, filename)\n\n        download_ssspy_data(location, path)\n\n        npz[filename] = np.load(path)\n\n    sorted_npz = []\n\n    for filename in filenames:\n        sorted_npz.append(npz[filename])\n\n    return tuple(sorted_npz)\n"
  },
  {
    "path": "tests/mock/regression/bss/cacgmm/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/cacgmm/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/fdica/aux_laplace_fdica/IP1/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/fdica/aux_laplace_fdica/IP1/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/fdica/aux_laplace_fdica/IP2/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/fdica/aux_laplace_fdica/IP2/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/fdica/grad_laplace_fdica/holonomic/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/fdica/grad_laplace_fdica/holonomic/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/fdica/grad_laplace_fdica/nonholonomic/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/fdica/grad_laplace_fdica/nonholonomic/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/fdica/natural_grad_laplace_fdica/holonomic/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/fdica/natural_grad_laplace_fdica/holonomic/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/fdica/natural_grad_laplace_fdica/nonholonomic/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/fdica/natural_grad_laplace_fdica/nonholonomic/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/IP1/ME/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/IP1/ME/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/IP1/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/IP1/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/IP2/ME/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/IP2/ME/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/IP2/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/IP2/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/IPA/ME/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/IPA/ME/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/IPA/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/IPA/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/ISS1/ME/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/ISS1/ME/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/ISS1/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/ISS1/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/ISS2/ME/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/ISS2/ME/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/gauss_ilrma/ISS2/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/gauss_ilrma/ISS2/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/ggd_ilrma/IP1/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/ggd_ilrma/IP1/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/ggd_ilrma/IP2/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/ggd_ilrma/IP2/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/ggd_ilrma/ISS1/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/ggd_ilrma/ISS1/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/ggd_ilrma/ISS2/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/ggd_ilrma/ISS2/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/t_ilrma/IP1/ME/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/t_ilrma/IP1/ME/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/t_ilrma/IP1/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/t_ilrma/IP1/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/t_ilrma/IP2/ME/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/t_ilrma/IP2/ME/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/t_ilrma/IP2/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/t_ilrma/IP2/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/t_ilrma/ISS1/ME/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/t_ilrma/ISS1/ME/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/t_ilrma/ISS1/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/t_ilrma/ISS1/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/t_ilrma/ISS2/ME/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/t_ilrma/ISS2/ME/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ilrma/t_ilrma/ISS2/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ilrma/t_ilrma/ISS2/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ipsdta/gauss_ipsdta/VCD/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ipsdta/gauss_ipsdta/VCD/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/ipsdta/t_ipsdta/VCD/MM/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/ipsdta/t_ipsdta/VCD/MM/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/aux_iva/IP1/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/aux_iva/IP1/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/aux_iva/IP2/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/aux_iva/IP2/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/aux_iva/IPA/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/aux_iva/IPA/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/aux_iva/ISS1/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/aux_iva/ISS1/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/aux_iva/ISS2/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/aux_iva/ISS2/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/fast_iva/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/fast_iva/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/grad_iva/holonomic/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/grad_iva/holonomic/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/grad_iva/nonholonomic/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/grad_iva/nonholonomic/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/natural_grad_iva/holonomic/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/natural_grad_iva/holonomic/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/iva/natural_grad_iva/nonholonomic/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/iva/natural_grad_iva/nonholonomic/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/mnmf/fast_gauss_mnmf/IP1/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/mnmf/fast_gauss_mnmf/IP1/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/mnmf/fast_gauss_mnmf/IP2/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/mnmf/fast_gauss_mnmf/IP2/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/mock/regression/bss/mnmf/gauss_mnmf/url.json",
    "content": "{\n    \"files\": [\n        {\n            \"filename\": \"input.npz\",\n            \"location\": \"npz/canon_8k_reverbed.npz\"\n        },\n        {\n            \"filename\": \"target.npz\",\n            \"location\": \"npz/bss/mnmf/gauss_mnmf/target.npz\"\n        }\n    ]\n}\n"
  },
  {
    "path": "tests/package/algorithm/test_minimal_distortion_principle.py",
    "content": "from typing import Optional\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.algorithm import minimal_distortion_principle\n\nparameters = [(2, 0), (3, 2), (2, None)]\n\n\n@pytest.mark.parametrize(\"n_sources, reference_id\", parameters)\ndef test_minimal_distortion_principle(n_sources: int, reference_id: Optional[int]):\n    \"\"\"Check output shapes of ``minimal_distortion_principle``.\n\n    With ``reference_id=None`` every slice along the first axis of the output must\n    match the mixture's shape; otherwise the scaled estimate itself must.\n    \"\"\"\n    rng = np.random.default_rng(0)\n\n    n_channels = n_sources\n    n_bins, n_frames = 5, 8\n\n    spectrogram_mix = rng.standard_normal(\n        (n_channels, n_bins, n_frames)\n    ) + 1j * rng.standard_normal((n_channels, n_bins, n_frames))\n    demix_filter = rng.standard_normal((n_bins, n_sources, n_channels)) + 1j * rng.standard_normal(\n        (n_bins, n_sources, n_channels)\n    )\n    spectrogram_est = demix_filter @ spectrogram_mix.transpose(1, 0, 2)\n    spectrogram_est = spectrogram_est.transpose(1, 0, 2)\n\n    spectrogram_est_scaled = minimal_distortion_principle(\n        spectrogram_est, spectrogram_mix, reference_id=reference_id\n    )\n\n    if reference_id is None:\n        for _spectrogram_est_scaled in spectrogram_est_scaled:\n            assert spectrogram_mix.shape == _spectrogram_est_scaled.shape\n    else:\n        # check the scaled output, not the input estimate (which trivially matches)\n        assert spectrogram_mix.shape == spectrogram_est_scaled.shape\n"
  },
  {
    "path": "tests/package/algorithm/test_permutation_alignment.py",
    "content": "import numpy as np\nimport pytest\n\nfrom ssspy.algorithm.permutation_alignment import (\n    correlation_based_permutation_solver,\n    score_based_permutation_solver,\n)\n\nrng = np.random.default_rng(0)\n\nparameters_give_demixing_filter = [True, False]\n\n\n@pytest.mark.parametrize(\"give_demixing_filter\", parameters_give_demixing_filter)\ndef test_correlation_based_permutation_solver(give_demixing_filter: bool):\n    \"\"\"Check output shapes of ``correlation_based_permutation_solver``.\"\"\"\n    n_sources = 3\n    n_channels = n_sources\n    n_bins, n_frames = 4, 16\n\n    shape = (n_channels, n_bins, n_frames)\n    mixture = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n    shape = (n_bins, n_sources, n_channels)\n    demix_filter = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n    separated = demix_filter @ mixture.transpose(1, 0, 2)\n\n    if give_demixing_filter:\n        separated, demix_filter = correlation_based_permutation_solver(separated, demix_filter)\n\n        assert demix_filter.shape == (n_bins, n_sources, n_channels)\n    else:\n        separated = correlation_based_permutation_solver(separated)\n\n    assert separated.shape == (n_bins, n_sources, n_frames)\n\n\n@pytest.mark.parametrize(\"give_demixing_filter\", parameters_give_demixing_filter)\ndef test_score_based_permutation_solver(give_demixing_filter: bool):\n    \"\"\"Check output shapes of ``score_based_permutation_solver``.\"\"\"\n    n_sources = 3\n    n_channels = n_sources\n    n_bins, n_frames = 4, 16\n\n    shape = (n_channels, n_bins, n_frames)\n    mixture = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n    shape = (n_bins, n_sources, n_channels)\n    demix_filter = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n    separated = demix_filter @ mixture.transpose(1, 0, 2)\n\n    if give_demixing_filter:\n        separated, demix_filter = score_based_permutation_solver(separated, demix_filter)\n\n        assert demix_filter.shape == (n_bins, n_sources, n_channels)\n    else:\n        separated = score_based_permutation_solver(separated)\n\n    assert separated.shape == (n_bins, n_sources, n_frames)\n"
  },
  {
    "path": "tests/package/algorithm/test_projection_back.py",
    "content": "from typing import Optional\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.algorithm import projection_back\n\nparameters = [(2, 0), (3, 2), (2, None)]\n\n\n@pytest.mark.parametrize(\"n_sources, reference_id\", parameters)\ndef test_projection_back_demix_filter(n_sources: int, reference_id: Optional[int]):\n    \"\"\"Check shapes when ``projection_back`` rescales a demixing filter.\"\"\"\n    np.random.seed(111)\n\n    n_channels = n_sources\n    n_bins, n_frames = 17, 10\n\n    spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) + 1j * np.random.randn(\n        n_channels, n_bins, n_frames\n    )\n    demix_filter = np.random.randn(n_bins, n_sources, n_channels) + 1j * np.random.randn(\n        n_bins, n_sources, n_channels\n    )\n\n    demix_filter_scaled = projection_back(demix_filter, reference_id=reference_id)\n\n    spectrogram_est = demix_filter_scaled @ spectrogram_mix.transpose(1, 0, 2)\n\n    if reference_id is None:\n        spectrogram_est = spectrogram_est.transpose(0, 2, 1, 3)\n\n        for _spectrogram_est in spectrogram_est:\n            assert spectrogram_mix.shape == _spectrogram_est.shape\n    else:\n        spectrogram_est = spectrogram_est.transpose(1, 0, 2)\n\n        assert spectrogram_mix.shape == spectrogram_est.shape\n\n\n@pytest.mark.parametrize(\"n_sources, reference_id\", parameters)\ndef test_projection_back_output(n_sources: int, reference_id: Optional[int]):\n    \"\"\"Check shapes when ``projection_back`` rescales separated spectrograms.\"\"\"\n    np.random.seed(111)\n\n    n_channels = n_sources\n    n_bins, n_frames = 17, 10\n\n    spectrogram_mix = np.random.randn(n_channels, n_bins, n_frames) + 1j * np.random.randn(\n        n_channels, n_bins, n_frames\n    )\n    demix_filter = np.random.randn(n_bins, n_sources, n_channels) + 1j * np.random.randn(\n        n_bins, n_sources, n_channels\n    )\n    spectrogram_est = demix_filter @ spectrogram_mix.transpose(1, 0, 2)\n    spectrogram_est = spectrogram_est.transpose(1, 0, 2)\n\n    spectrogram_est_scaled = projection_back(\n        spectrogram_est, reference=spectrogram_mix, reference_id=reference_id\n    )\n\n    if reference_id is None:\n        for _spectrogram_est_scaled in spectrogram_est_scaled:\n            assert spectrogram_mix.shape == _spectrogram_est_scaled.shape\n    else:\n        # check the scaled output, not the input estimate (which trivially matches)\n        assert spectrogram_mix.shape == spectrogram_est_scaled.shape\n"
  },
  {
    "path": "tests/package/bss/test_admmbss.py",
    "content": "import math\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nimport scipy.signal as ss\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.admmbss import ADMMBSS, ADMMBSSBase, MaskingADMMBSS\n\nmax_duration = 0.5\nn_fft = 2048\nhop_length = 1024\nn_bins = n_fft // 2 + 1\nn_iter = 5\n\nparameters_admmbss = [\n    (2, None, {}),\n    (\n        3,\n        dummy_function,\n        {\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n    (2, [DummyCallback(), dummy_function], {}),\n    (\n        2,\n        None,\n        {\n            # n_frames=9\n            \"auxiliary1\": np.ones((n_bins, 2, 2), dtype=np.complex128),\n            \"auxiliary2\": np.zeros((1, 2, n_bins, 9), dtype=np.complex128),\n            \"dual1\": np.ones((n_bins, 2, 2), dtype=np.complex128),\n            \"dual2\": np.zeros((1, 2, n_bins, 9), dtype=np.complex128),\n        },\n    ),\n]\n\n\ndef contrast_fn(y: np.ndarray) -> np.ndarray:\n    r\"\"\"Contrast function.\n\n    Args:\n        y (np.ndarray):\n            The shape is (n_sources, n_bins, n_frames).\n\n    Returns:\n        np.ndarray of the shape is (n_sources, n_frames).\n    \"\"\"\n    return 2 * np.linalg.norm(y, axis=1)\n\n\ndef penalty_fn(y: np.ndarray) -> float:\n    \"\"\"Penalty: sum over sources of the frame-averaged contrast.\"\"\"\n    loss = contrast_fn(y)\n    loss = np.sum(loss.mean(axis=-1))\n    return loss\n\n\ndef prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray:\n    r\"\"\"Proximal operator of penalty function.\n\n    Args:\n        y (np.ndarray):\n            The shape is (n_sources, n_bins, n_frames).\n        step_size (float):\n            Step size. Default: 1.\n\n    Returns:\n        np.ndarray of the shape is (n_sources, n_bins, n_frames).\n    \"\"\"\n    norm = np.linalg.norm(y, axis=1, keepdims=True)\n\n    # to suppress warning RuntimeWarning\n    norm = np.where(norm < step_size, step_size, norm)\n\n    return y * np.maximum(1 - step_size / norm, 0)\n\n\ndef test_admmbss_base():\n    \"\"\"Check that ``ADMMBSSBase`` can be constructed and printed.\"\"\"\n    admmbss = ADMMBSSBase(penalty_fn=penalty_fn, prox_penalty=prox_penalty)\n\n    print(admmbss)\n\n\n@pytest.mark.parametrize(\"n_sources, callbacks, reset_kwargs\", parameters_admmbss)\ndef test_admmbss(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[ADMMBSS], None], List[Callable[[ADMMBSS], None]]]],\n    reset_kwargs: Dict[Any, Any],\n):\n    \"\"\"Run ``ADMMBSS`` for a few iterations and check the output shape.\"\"\"\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        sisec2010_root=\"./tests/.data/SiSEC2010\",\n        mird_root=\"./tests/.data/MIRD\",\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    admmbss = ADMMBSS(penalty_fn=penalty_fn, prox_penalty=prox_penalty, callbacks=callbacks)\n    spectrogram_mix_normalized = admmbss.normalize_by_spectral_norm(spectrogram_mix)\n    spectrogram_est = admmbss(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n\n    print(admmbss)\n\n\n@pytest.mark.parametrize(\"n_sources, callbacks, reset_kwargs\", parameters_admmbss)\ndef test_masking_admmbss(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[ADMMBSS], None], List[Callable[[ADMMBSS], None]]]],\n    reset_kwargs: Dict[Any, Any],\n) -> None:\n    \"\"\"Run ``MaskingADMMBSS`` with a harmonic mask and check the output shape.\"\"\"\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        sisec2010_root=\"./tests/.data/SiSEC2010\",\n        mird_root=\"./tests/.data/MIRD\",\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    def hva_mask_fn(y: np.ndarray, mask_iter: int = 2) -> np.ndarray:\n        \"\"\"Masking function to emphasize harmonic components.\n\n        Args:\n            y (np.ndarray): The shape is (n_sources, n_bins, n_frames).\n            mask_iter (int): Number of times the ``(1 - cos(pi * x)) / 2`` map is applied.\n                Default: 2.\n\n        Returns:\n            np.ndarray of mask. The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        n_sources, n_bins, _ = y.shape\n\n        gamma = 1 / n_sources\n\n        y = np.maximum(np.abs(y), 1e-10)\n        zeta = np.log(y)\n        zeta_mean = zeta.mean(axis=1, keepdims=True)\n        rho = zeta - zeta_mean\n        nu = np.fft.irfft(rho, axis=1, norm=\"backward\")\n        nu = nu[:, :n_bins]\n        varsigma = np.minimum(1, nu)\n\n        for _ in range(mask_iter):\n            varsigma = (1 - np.cos(math.pi * varsigma)) / 2\n\n        xi = np.fft.irfft(varsigma * nu, axis=1, norm=\"forward\")\n        xi = xi[:, :n_bins]\n        varrho = xi + zeta_mean\n        v = np.exp(2 * varrho)\n        mask = (v / v.sum(axis=0)) ** gamma\n\n        return mask\n\n    admmbss = MaskingADMMBSS(mask_fn=hva_mask_fn, callbacks=callbacks)\n    spectrogram_mix_normalized = admmbss.normalize_by_spectral_norm(spectrogram_mix)\n\n    # ``reset_kwargs`` is the same dict object that ``parameters_admmbss`` passes to\n    # ``test_admmbss``; copy it so the squeezing below does not leak into other tests.\n    reset_kwargs = dict(reset_kwargs)\n\n    # drop the leading singleton axis of 4D ``auxiliary2`` / ``dual2`` for MaskingADMMBSS\n    for key in (\"auxiliary2\", \"dual2\"):\n        if key in reset_kwargs and reset_kwargs[key].ndim == 4:\n            reset_kwargs[key] = reset_kwargs[key].squeeze(axis=0)\n\n    spectrogram_est = admmbss(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n\n    print(admmbss)\n"
  },
  {
    "path": "tests/package/bss/test_base.py",
    "content": "from typing import Callable, List, Optional, Union\n\nimport pytest\nfrom dummy.callback import DummyCallback, dummy_function\n\nfrom ssspy.bss.base import IterativeMethodBase\n\nn_iter = 3\n\nparameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]]\nparameters_record_loss = [True, False]\n\n# A single callback or a list of callbacks, each receiving the method instance.\nCallbackOrCallbacks = Union[\n    Callable[[IterativeMethodBase], None], List[Callable[[IterativeMethodBase], None]]\n]\n\n\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"record_loss\", parameters_record_loss)\ndef test_iterative_method_base(\n    callbacks: Optional[CallbackOrCallbacks],\n    record_loss: bool,\n):\n    \"\"\"Calling the bare base class must raise ``NotImplementedError``.\n\n    This holds for every combination of ``callbacks`` and ``record_loss``.\n    \"\"\"\n    base_method = IterativeMethodBase(record_loss=record_loss, callbacks=callbacks)\n\n    with pytest.raises(NotImplementedError) as excinfo:\n        base_method(n_iter=n_iter)\n\n    assert excinfo.type is NotImplementedError\n"
  },
  {
    "path": "tests/package/bss/test_cacgmm.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nimport scipy.signal as ss\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.cacgmm import CACGMM\n\nmax_duration = 0.5\nwindow = \"hann\"\nn_fft = 512\nhop_length = 256\nn_bins = n_fft // 2 + 1\nn_iter = 3\nrng = np.random.default_rng(42)\n\nparameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]]\nparameters_permutation_alignment = [\n    \"posterior_score\",\n    \"amplitude_score\",\n    \"amplitude_correlation\",\n]\nparameters_cacgmm = [(2, 2, {}), (3, 2, {})]\n\n\n@pytest.mark.parametrize(\"n_sources, n_channels, reset_kwargs\", parameters_cacgmm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"permutation_alignment\", parameters_permutation_alignment)\ndef test_cacgmm(\n    n_sources: int,\n    n_channels: int,\n    callbacks: Optional[Union[Callable[[CACGMM], None], List[Callable[[CACGMM], None]]]],\n    permutation_alignment: str,\n    reset_kwargs: Dict[str, Any],\n):\n    \"\"\"Run ``CACGMM`` and check that ``separate`` reproduces the returned estimate.\"\"\"\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n    waveform_mix = waveform_mix[:n_channels]\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    cacgmm = CACGMM(\n        n_sources=n_sources,\n        callbacks=callbacks,\n        permutation_alignment=permutation_alignment,\n        rng=rng,\n    )\n\n    spectrogram_est = cacgmm(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_est.shape == (n_sources,) + spectrogram_mix.shape[-2:]\n    assert type(cacgmm.loss[-1]) is float\n\n    # when posterior is not given\n    _spectrogram_est = cacgmm.separate(spectrogram_mix)\n\n    assert np.allclose(_spectrogram_est, spectrogram_est)\n\n    print(cacgmm)\n\n\ndef test_cacgmm_zero_norm() -> None:\n    \"\"\"Test input with zero norm.\"\"\"\n    n_channels, n_sources, n_samples = 2, 3, 10 * 8000\n    waveform_src_img = rng.standard_normal((n_channels, n_sources, n_samples))\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n    waveform_mix = waveform_mix[:n_channels]\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n    # set 0 at most grids in 0th frequency bin\n    spectrogram_mix[:, 0, 1:-1] = 0\n\n    # precondition: some time-frequency grids have zero norm, but not all of them\n    norm = np.linalg.norm(spectrogram_mix, axis=0)\n    assert not norm.all()\n    assert norm.any()\n\n    cacgmm = CACGMM(n_sources=n_sources, rng=rng)\n    spectrogram_est = cacgmm(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_est.shape == (n_sources,) + spectrogram_mix.shape[-2:]\n    assert type(cacgmm.loss[-1]) is float\n\n    # when posterior is not given\n    _spectrogram_est = cacgmm.separate(spectrogram_mix)\n\n    assert np.allclose(_spectrogram_est, spectrogram_est)\n"
  },
  {
    "path": "tests/package/bss/test_fdica.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nimport scipy.signal as ss\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.fdica import (\n    AuxFDICA,\n    AuxLaplaceFDICA,\n    GradFDICA,\n    GradFDICABase,\n    GradLaplaceFDICA,\n    NaturalGradFDICA,\n    NaturalGradLaplaceFDICA,\n)\n\nmax_duration = 0.5\nn_fft = 512\nhop_length = 256\nn_bins = n_fft // 2 + 1\nn_iter = 3\n\nparameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]]\nparameters_is_holonomic = [True, False]\nparameters_scale_restoration = [True, False, \"projection_back\", \"minimal_distortion_principle\"]\nparameters_spatial_algorithm = [\"IP\", \"IP1\", \"IP2\"]\nparameters_grad_fdica = [\n    (2, {}),\n    (\n        3,\n        {\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n]\nparameters_aux_fdica = [\n    (2, {}),\n    (\n        3,\n        {\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n]\n\n\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\ndef test_grad_fdica_base(\n    callbacks: Optional[Union[Callable[[GradFDICA], None], List[Callable[[GradFDICA], None]]]],\n):\n    np.random.seed(111)\n\n    def contrast_fn(y):\n        return 2 * np.abs(y)\n\n    def score_fn(y):\n        denominator = np.maximum(np.abs(y), 1e-10)\n        return y / denominator\n\n    fdica = GradFDICABase(contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks)\n\n    print(fdica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_fdica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_fdica(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[GradFDICA], None], List[Callable[[GradFDICA], 
None]]]],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    def contrast_fn(y):\n        return 2 * np.abs(y)\n\n    def score_fn(y):\n        denominator = np.maximum(np.abs(y), 1e-10)\n        return y / denominator\n\n    fdica = GradFDICA(\n        contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic\n    )\n    spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(fdica.loss[-1]) is float\n\n    print(fdica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_fdica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_natural_grad_fdica(\n    n_sources: int,\n    callbacks: Optional[\n        Union[Callable[[NaturalGradFDICA], None], List[Callable[[NaturalGradFDICA], None]]]\n    ],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    def contrast_fn(y):\n        return 2 * np.abs(y)\n\n    def score_fn(y):\n        denominator = np.maximum(np.abs(y), 
1e-10)\n        return y / denominator\n\n    fdica = NaturalGradFDICA(\n        contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic\n    )\n    spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(fdica.loss[-1]) is float\n\n    print(fdica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_aux_fdica)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_aux_fdica(\n    n_sources: int,\n    spatial_algorithm: str,\n    callbacks: Optional[Union[Callable[[AuxFDICA], None], List[Callable[[AuxFDICA], None]]]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[Any, Any],\n):\n    if spatial_algorithm in [\"IP\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    def contrast_fn(y):\n        return 2 * np.abs(y)\n\n    def d_contrast_fn(y):\n        return 2 * np.ones_like(y)\n\n    fdica = AuxFDICA(\n        spatial_algorithm=spatial_algorithm,\n        contrast_fn=contrast_fn,\n        d_contrast_fn=d_contrast_fn,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n    )\n    spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(fdica.loss[-1]) is float\n\n    
print(fdica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_fdica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_laplace_fdica(\n    n_sources: int,\n    callbacks: Optional[\n        Union[Callable[[GradLaplaceFDICA], None], List[Callable[[GradLaplaceFDICA], None]]]\n    ],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    fdica = GradLaplaceFDICA(callbacks=callbacks, is_holonomic=is_holonomic)\n    spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(fdica.loss[-1]) is float\n\n    print(fdica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_fdica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_natural_grad_laplace_fdica(\n    n_sources: int,\n    callbacks: Optional[\n        Union[\n            Callable[[NaturalGradLaplaceFDICA], None],\n            List[Callable[[NaturalGradLaplaceFDICA], None]],\n        ]\n    ],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, 
spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    fdica = NaturalGradLaplaceFDICA(callbacks=callbacks, is_holonomic=is_holonomic)\n    spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(fdica.loss[-1]) is float\n\n    print(fdica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_aux_fdica)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_aux_laplace_fdica(\n    n_sources: int,\n    spatial_algorithm: str,\n    callbacks: Optional[\n        Union[Callable[[AuxLaplaceFDICA], None], List[Callable[[AuxLaplaceFDICA], None]]]\n    ],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[Any, Any],\n):\n    if spatial_algorithm in [\"IP\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    fdica = AuxLaplaceFDICA(\n        spatial_algorithm=spatial_algorithm,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n    )\n    spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(fdica.loss[-1]) is float\n\n    print(fdica)\n"
  },
  {
    "path": "tests/package/bss/test_hva.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nimport scipy.signal as ss\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.hva import HVA, MaskingADMMHVA, MaskingPDSHVA\n\nmax_duration = 0.5\nn_fft = 2048\nhop_length = 1024\nn_bins = n_fft // 2 + 1\nn_iter = 5\n\nparameters_hva = [\n    (2, None, {}),\n    (\n        3,\n        dummy_function,\n        {\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n    (2, [DummyCallback(), dummy_function], {}),\n]\n\n\n@pytest.mark.parametrize(\"n_sources, callbacks, reset_kwargs\", parameters_hva)\ndef test_masking_pdshva(\n    n_sources: int,\n    callbacks: Optional[\n        Union[Callable[[MaskingPDSHVA], None], List[Callable[[MaskingPDSHVA], None]]]\n    ],\n    reset_kwargs: Dict[Any, Any],\n):\n    waveform_src_img, _ = download_sample_speech_data(\n        sisec2010_root=\"./tests/.data/SiSEC2010\",\n        mird_root=\"./tests/.data/MIRD\",\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    hva = MaskingPDSHVA(callbacks=callbacks)\n\n    spectrogram_mix_normalized = hva.normalize_by_spectral_norm(spectrogram_mix)\n    spectrogram_est = hva(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n\n    print(hva)\n\n\n@pytest.mark.parametrize(\"n_sources, callbacks, reset_kwargs\", parameters_hva)\ndef test_masking_admmhva(\n    n_sources: int,\n    callbacks: Optional[\n        Union[Callable[[MaskingADMMHVA], None], List[Callable[[MaskingADMMHVA], None]]]\n    ],\n 
   reset_kwargs: Dict[Any, Any],\n):\n    waveform_src_img, _ = download_sample_speech_data(\n        sisec2010_root=\"./tests/.data/SiSEC2010\",\n        mird_root=\"./tests/.data/MIRD\",\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    hva = MaskingADMMHVA(callbacks=callbacks)\n\n    spectrogram_mix_normalized = hva.normalize_by_spectral_norm(spectrogram_mix)\n    spectrogram_est = hva(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n\n    print(hva)\n\n\n@pytest.mark.parametrize(\"n_sources, callbacks, reset_kwargs\", parameters_hva)\ndef test_hva(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[HVA], None], List[Callable[[HVA], None]]]],\n    reset_kwargs: Dict[Any, Any],\n):\n    waveform_src_img, _ = download_sample_speech_data(\n        sisec2010_root=\"./tests/.data/SiSEC2010\",\n        mird_root=\"./tests/.data/MIRD\",\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    hva = HVA(callbacks=callbacks)\n\n    spectrogram_mix_normalized = hva.normalize_by_spectral_norm(spectrogram_mix)\n    spectrogram_est = hva(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n\n    print(hva)\n"
  },
  {
    "path": "tests/package/bss/test_ica.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.ica import (\n    FastICA,\n    GradICA,\n    GradICABase,\n    GradLaplaceICA,\n    NaturalGradICA,\n    NaturalGradLaplaceICA,\n)\n\nmax_duration = 0.5\nn_iter = 3\n\nparameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]]\nparameters_is_holonomic = [True, False]\nparameters_grad_ica = [\n    (2, {}),\n    (3, {\"demix_filter\": -np.eye(3)}),\n]\nparameters_fast_ica = [\n    (2, {}),\n    (3, {\"demix_filter\": -np.eye(3)}),\n]\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_ica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_ica_base(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    def contrast_fn(x):\n        return np.log(1 + np.exp(x))\n\n    def score_fn(x):\n        return 1 / (1 + np.exp(-x))\n\n    ica = GradICABase(contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks)\n\n    print(ica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_ica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_ica(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=False,\n    )\n    waveform_mix = 
np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    def contrast_fn(x):\n        return np.log(1 + np.exp(x))\n\n    def score_fn(x):\n        return 1 / (1 + np.exp(-x))\n\n    ica = GradICA(\n        contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic\n    )\n    waveform_est = ica(waveform_mix, n_iter=n_iter)\n\n    assert waveform_mix.shape == waveform_est.shape\n    assert type(ica.loss[-1]) is float\n\n    print(ica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_ica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_natural_grad_ica(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=False,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    def contrast_fn(x):\n        return np.log(1 + np.exp(x))\n\n    def score_fn(x):\n        return 1 / (1 + np.exp(-x))\n\n    ica = NaturalGradICA(\n        contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic\n    )\n    waveform_est = ica(waveform_mix, n_iter=n_iter)\n\n    assert waveform_mix.shape == waveform_est.shape\n    assert type(ica.loss[-1]) is float\n\n    print(ica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_ica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_laplace_ica(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]],\n    is_holonomic: bool,\n    reset_kwargs: 
Dict[Any, Any],\n):\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=False,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    ica = GradLaplaceICA(callbacks=callbacks, is_holonomic=is_holonomic)\n    waveform_est = ica(waveform_mix, n_iter=n_iter)\n\n    assert waveform_mix.shape == waveform_est.shape\n    assert type(ica.loss[-1]) is float\n\n    print(ica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_ica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_natural_grad_laplace_ica(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[GradICA], None], List[Callable[[GradICA], None]]]],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=False,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    ica = NaturalGradLaplaceICA(callbacks=callbacks, is_holonomic=is_holonomic)\n    waveform_est = ica(waveform_mix, n_iter=n_iter)\n\n    assert waveform_mix.shape == waveform_est.shape\n    assert type(ica.loss[-1]) is float\n\n    print(ica)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_fast_ica)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\ndef test_fast_ica(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[FastICA], None], List[Callable[[FastICA], None]]]],\n    reset_kwargs: Dict[Any, Any],\n):\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=False,\n    )\n    
waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    def contrast_fn(x):\n        return np.log(1 + np.exp(x))\n\n    def score_fn(x):\n        return 1 / (1 + np.exp(-x))\n\n    def d_score_fn(x):\n        s = 1 / (1 + np.exp(-x))\n        return s * (1 - s)\n\n    ica = FastICA(\n        contrast_fn=contrast_fn, score_fn=score_fn, d_score_fn=d_score_fn, callbacks=callbacks\n    )\n    waveform_est = ica(waveform_mix, n_iter=n_iter)\n\n    assert waveform_mix.shape == waveform_est.shape\n    assert type(ica.loss[-1]) is float\n\n    print(ica)\n"
  },
  {
    "path": "tests/package/bss/test_ilrma.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nimport scipy.signal as ss\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.ilrma import GGDILRMA, TILRMA, GaussILRMA, ILRMABase\n\nmax_duration = 0.5\nn_fft = 512\nhop_length = 256\nn_bins = n_fft // 2 + 1\nn_iter = 3\nrng = np.random.default_rng(42)\n\nparameters_dof = [100]\nparameters_beta = [0.5, 1.5]\nparameters_spatial_algorithm = [\"IP\", \"IP1\", \"IP2\", \"ISS\", \"ISS1\", \"ISS2\", \"IPA\"]\nparameters_source_algorithm = [\"MM\", \"ME\"]\nparameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]]\nparameters_scale_restoration = [True, False, \"projection_back\", \"minimal_distortion_principle\"]\nparameters_ilrma_base = [2]\nparameters_ilrma_latent = [\n    (\n        2,\n        4,\n        2,\n        {\n            \"demix_filter\": np.tile(np.eye(2, dtype=np.complex128), (n_bins, 1, 1)),\n            \"latent\": rng.random((2, 4)),\n            \"basis\": rng.random((n_bins, 4)),\n        },\n    ),\n    (3, 3, 1, {}),\n]\nparameters_ilrma_wo_latent = [\n    (\n        2,\n        2,\n        2,\n        {\n            \"demix_filter\": np.tile(np.eye(2, dtype=np.complex128), (n_bins, 1, 1)),\n            \"basis\": rng.random((2, n_bins, 2)),\n        },\n    ),\n    (\n        3,\n        1,\n        1,\n        {},\n    ),\n]\nparameters_normalization_latent = [True, False, \"power\"]\nparameters_normalization_wo_latent = [True, False, \"power\", \"projection_back\"]\n\n\n@pytest.mark.parametrize(\n    \"n_basis\",\n    parameters_ilrma_base,\n)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_ilrma_base(\n    n_basis: int,\n    callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]],\n    
scale_restoration: Union[str, bool],\n):\n    ilrma = ILRMABase(\n        n_basis,\n        partitioning=True,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n        rng=np.random.default_rng(42),\n    )\n\n    print(ilrma)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_basis, domain, reset_kwargs\",\n    parameters_ilrma_latent,\n)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"normalization\", parameters_normalization_latent)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_gauss_ilrma_latent(\n    n_sources: int,\n    n_basis: int,\n    spatial_algorithm: str,\n    source_algorithm: str,\n    domain: float,\n    callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]],\n    normalization: Optional[Union[str, bool]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[str, Any],\n):\n    if spatial_algorithm in [\"IP\", \"ISS\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    kwargs = {\n        \"spatial_algorithm\": spatial_algorithm,\n        \"source_algorithm\": source_algorithm,\n      
  \"domain\": domain,\n        \"partitioning\": True,\n        \"callbacks\": callbacks,\n        \"normalization\": normalization,\n        \"scale_restoration\": scale_restoration,\n        \"rng\": np.random.default_rng(42),\n    }\n\n    if source_algorithm == \"ME\" and domain != 2:\n        with pytest.raises(AssertionError) as e:\n            ilrma = GaussILRMA(n_basis, **kwargs)\n\n        assert str(e.value) == \"domain parameter should be 2 when you specify ME algorithm.\"\n    else:\n        ilrma = GaussILRMA(n_basis, **kwargs)\n        spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert spectrogram_mix.shape == spectrogram_est.shape\n        assert type(ilrma.loss[-1]) is float\n\n        if spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n            assert ilrma.demix_filter is None\n\n        print(ilrma)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_basis, domain, reset_kwargs\",\n    parameters_ilrma_wo_latent,\n)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"normalization\", parameters_normalization_wo_latent)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_gauss_ilrma_wo_latent(\n    n_sources: int,\n    n_basis: int,\n    spatial_algorithm: str,\n    source_algorithm: str,\n    domain: float,\n    callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]],\n    normalization: Optional[Union[str, bool]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[str, Any],\n):\n    if spatial_algorithm in [\"IP\", \"ISS\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        
sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    kwargs = {\n        \"spatial_algorithm\": spatial_algorithm,\n        \"source_algorithm\": source_algorithm,\n        \"domain\": domain,\n        \"partitioning\": False,\n        \"callbacks\": callbacks,\n        \"normalization\": normalization,\n        \"scale_restoration\": scale_restoration,\n        \"rng\": np.random.default_rng(42),\n    }\n\n    if source_algorithm == \"ME\" and domain != 2:\n        with pytest.raises(AssertionError) as e:\n            ilrma = GaussILRMA(n_basis, **kwargs)\n\n        assert str(e.value) == \"domain parameter should be 2 when you specify ME algorithm.\"\n    else:\n        ilrma = GaussILRMA(n_basis, **kwargs)\n        spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert spectrogram_mix.shape == spectrogram_est.shape\n        assert type(ilrma.loss[-1]) is float\n\n        if spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n            assert ilrma.demix_filter is None\n\n        print(ilrma)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_basis, domain, reset_kwargs\",\n    parameters_ilrma_latent,\n)\n@pytest.mark.parametrize(\"dof\", parameters_dof)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"normalization\", 
parameters_normalization_latent)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_t_ilrma_latent(\n    n_sources: int,\n    n_basis: int,\n    dof: float,\n    spatial_algorithm: str,\n    source_algorithm: str,\n    domain: float,\n    callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]],\n    normalization: Optional[Union[str, bool]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[str, Any],\n):\n    if spatial_algorithm in [\"IP\", \"ISS\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    kwargs = {\n        \"dof\": dof,\n        \"spatial_algorithm\": spatial_algorithm,\n        \"source_algorithm\": source_algorithm,\n        \"domain\": domain,\n        \"partitioning\": True,\n        \"callbacks\": callbacks,\n        \"normalization\": normalization,\n        \"scale_restoration\": scale_restoration,\n        \"rng\": np.random.default_rng(42),\n    }\n\n    if spatial_algorithm == \"IPA\":\n        with pytest.raises(ValueError) as e:\n            ilrma = TILRMA(n_basis, **kwargs)\n\n        assert str(e.value) == \"IPA is not supported for t-ILRMA.\"\n    elif source_algorithm == \"ME\" and domain != 2:\n        with pytest.raises(AssertionError) as e:\n            ilrma = TILRMA(n_basis, **kwargs)\n\n  
      assert str(e.value) == \"domain parameter should be 2 when you specify ME algorithm.\"\n    else:\n        ilrma = TILRMA(n_basis, **kwargs)\n        spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert spectrogram_mix.shape == spectrogram_est.shape\n        assert type(ilrma.loss[-1]) is float\n\n        if spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n            assert ilrma.demix_filter is None\n\n        print(ilrma)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_basis, domain, reset_kwargs\",\n    parameters_ilrma_wo_latent,\n)\n@pytest.mark.parametrize(\"dof\", parameters_dof)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"normalization\", parameters_normalization_wo_latent)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_t_ilrma_wo_latent(\n    n_sources: int,\n    n_basis: int,\n    dof: float,\n    spatial_algorithm: str,\n    source_algorithm: str,\n    domain: float,\n    callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]],\n    normalization: Optional[Union[str, bool]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[str, Any],\n):\n    if spatial_algorithm in [\"IP\", \"ISS\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = 
np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    kwargs = {\n        \"dof\": dof,\n        \"spatial_algorithm\": spatial_algorithm,\n        \"source_algorithm\": source_algorithm,\n        \"domain\": domain,\n        \"partitioning\": False,\n        \"callbacks\": callbacks,\n        \"normalization\": normalization,\n        \"scale_restoration\": scale_restoration,\n        \"rng\": np.random.default_rng(42),\n    }\n\n    if spatial_algorithm == \"IPA\":\n        with pytest.raises(ValueError) as e:\n            ilrma = TILRMA(n_basis, **kwargs)\n\n        assert str(e.value) == \"IPA is not supported for t-ILRMA.\"\n    elif source_algorithm == \"ME\" and domain != 2:\n        with pytest.raises(AssertionError) as e:\n            ilrma = TILRMA(n_basis, **kwargs)\n\n        assert str(e.value) == \"domain parameter should be 2 when you specify ME algorithm.\"\n    else:\n        ilrma = TILRMA(n_basis, **kwargs)\n        spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert spectrogram_mix.shape == spectrogram_est.shape\n        assert type(ilrma.loss[-1]) is float\n\n        if spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n            assert ilrma.demix_filter is None\n\n        print(ilrma)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_basis, domain, reset_kwargs\",\n    parameters_ilrma_latent,\n)\n@pytest.mark.parametrize(\"beta\", parameters_beta)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"normalization\", parameters_normalization_latent)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_ggd_ilrma_latent(\n    n_sources: 
int,\n    n_basis: int,\n    beta: float,\n    spatial_algorithm: str,\n    source_algorithm: str,\n    domain: float,\n    callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]],\n    normalization: Optional[Union[str, bool]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[str, Any],\n):\n    if spatial_algorithm in [\"IP\", \"ISS\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    kwargs = {\n        \"beta\": beta,\n        \"spatial_algorithm\": spatial_algorithm,\n        \"source_algorithm\": source_algorithm,\n        \"domain\": domain,\n        \"partitioning\": True,\n        \"callbacks\": callbacks,\n        \"normalization\": normalization,\n        \"scale_restoration\": scale_restoration,\n        \"rng\": np.random.default_rng(42),\n    }\n\n    if source_algorithm == \"ME\":\n        with pytest.raises(AssertionError) as e:\n            ilrma = GGDILRMA(n_basis, **kwargs)\n\n        assert str(e.value) == \"Not support {}.\".format(source_algorithm)\n    elif spatial_algorithm == \"IPA\":\n        with pytest.raises(ValueError) as e:\n            ilrma = GGDILRMA(n_basis, **kwargs)\n\n        assert str(e.value) == \"IPA is not supported for GGD-ILRMA.\"\n    else:\n        ilrma = GGDILRMA(n_basis, **kwargs)\n        spectrogram_est = 
ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert spectrogram_mix.shape == spectrogram_est.shape\n        assert type(ilrma.loss[-1]) is float\n\n        if spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n            assert ilrma.demix_filter is None\n\n        print(ilrma)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_basis, domain, reset_kwargs\",\n    parameters_ilrma_wo_latent,\n)\n@pytest.mark.parametrize(\"beta\", parameters_beta)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"normalization\", parameters_normalization_wo_latent)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_ggd_ilrma_wo_latent(\n    n_sources: int,\n    n_basis: int,\n    beta: float,\n    spatial_algorithm: str,\n    source_algorithm: str,\n    domain: float,\n    callbacks: Optional[Union[Callable[[GaussILRMA], None], List[Callable[[GaussILRMA], None]]]],\n    normalization: Optional[Union[str, bool]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[str, Any],\n):\n    if spatial_algorithm in [\"IP\", \"ISS\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    
)\n\n    kwargs = {\n        \"beta\": beta,\n        \"spatial_algorithm\": spatial_algorithm,\n        \"source_algorithm\": source_algorithm,\n        \"domain\": domain,\n        \"partitioning\": False,\n        \"callbacks\": callbacks,\n        \"normalization\": normalization,\n        \"scale_restoration\": scale_restoration,\n        \"rng\": np.random.default_rng(42),\n    }\n\n    if source_algorithm == \"ME\":\n        with pytest.raises(AssertionError) as e:\n            ilrma = GGDILRMA(n_basis, **kwargs)\n\n        assert str(e.value) == \"Not support {}.\".format(source_algorithm)\n    elif spatial_algorithm == \"IPA\":\n        with pytest.raises(ValueError) as e:\n            ilrma = GGDILRMA(n_basis, **kwargs)\n\n        assert str(e.value) == \"IPA is not supported for GGD-ILRMA.\"\n    else:\n        ilrma = GGDILRMA(n_basis, **kwargs)\n        spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert spectrogram_mix.shape == spectrogram_est.shape\n        assert type(ilrma.loss[-1]) is float\n\n        if spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n            assert ilrma.demix_filter is None\n\n        print(ilrma)\n"
  },
  {
    "path": "tests/package/bss/test_ipsdta.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nimport scipy.signal as ss\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.ipsdta import TIPSDTA, BlockDecompositionIPSDTABase, GaussIPSDTA, IPSDTABase\n\nmax_duration = 0.1\nn_fft = 256\nhop_length = 128\nwindow = \"hann\"\nn_bins = n_fft // 2 + 1\nn_iter = 3\nrng = np.random.default_rng(42)\n\nparameters_dof = [100]\nparameters_spatial_algorithm = [\"FPI\", \"VCD\"]\nparameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]]\nparameters_source_normalization = [True, False]\nparameters_scale_restoration = [True, False, \"projection_back\", \"minimal_distortion_principle\"]\nparameters_ipsdta_base = [2]\nparameters_block_decomposition_ipsdta_base = [4]\nparameters_ipsdta = [\n    (\n        2,\n        2,\n        43,\n        {\n            \"demix_filter\": np.tile(np.eye(2, dtype=np.complex128), (n_bins, 1, 1)),\n        },\n    ),\n    (3, 2, 64, {}),\n]\n\n\n@pytest.mark.parametrize(\n    \"n_basis\",\n    parameters_ipsdta_base,\n)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_ipsdta_base(\n    n_basis: int,\n    callbacks: Optional[Union[Callable[[IPSDTABase], None], List[Callable[[IPSDTABase], None]]]],\n    scale_restoration: Union[str, bool],\n):\n    ipsdta = IPSDTABase(\n        n_basis,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n        record_loss=False,\n        rng=rng,\n    )\n\n    print(ipsdta)\n\n\n@pytest.mark.parametrize(\n    \"n_basis\",\n    parameters_ipsdta_base,\n)\n@pytest.mark.parametrize(\n    \"n_blocks\",\n    parameters_block_decomposition_ipsdta_base,\n)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", 
parameters_scale_restoration)\ndef test_block_decomposition_ipsdta_base(\n    n_basis: int,\n    n_blocks: int,\n    callbacks: Optional[\n        Union[\n            Callable[[BlockDecompositionIPSDTABase], None],\n            List[Callable[[BlockDecompositionIPSDTABase], None]],\n        ]\n    ],\n    scale_restoration: Union[str, bool],\n):\n    ipsdta = BlockDecompositionIPSDTABase(\n        n_basis,\n        n_blocks,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n        record_loss=False,\n        rng=rng,\n    )\n\n    print(ipsdta)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_basis, n_blocks, reset_kwargs\",\n    parameters_ipsdta,\n)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"source_normalization\", parameters_source_normalization)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_gauss_ipsdta(\n    n_sources: int,\n    n_basis: int,\n    n_blocks: int,\n    spatial_algorithm: str,\n    callbacks: Optional[Union[Callable[[GaussIPSDTA], None], List[Callable[[GaussIPSDTA], None]]]],\n    source_normalization: Optional[Union[str, bool]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[str, Any],\n):\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    ipsdta = GaussIPSDTA(\n        n_basis,\n  
      n_blocks,\n        spatial_algorithm=spatial_algorithm,\n        callbacks=callbacks,\n        source_normalization=source_normalization,\n        scale_restoration=scale_restoration,\n        rng=rng,\n    )\n\n    if spatial_algorithm == \"FPI\":\n        with pytest.raises(NotImplementedError) as e:\n            spectrogram_est = ipsdta(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert str(e.value) == \"IPSDTA with fixed-point iteration is not supported.\"\n    else:\n        spectrogram_est = ipsdta(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert spectrogram_mix.shape == spectrogram_est.shape\n        assert type(ipsdta.loss[-1]) is float\n\n    print(ipsdta)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_basis, n_blocks, reset_kwargs\",\n    parameters_ipsdta,\n)\n@pytest.mark.parametrize(\"dof\", parameters_dof)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"source_normalization\", parameters_source_normalization)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_t_ipsdta(\n    n_sources: int,\n    n_basis: int,\n    n_blocks: int,\n    dof: float,\n    spatial_algorithm: str,\n    callbacks: Optional[Union[Callable[[GaussIPSDTA], None], List[Callable[[GaussIPSDTA], None]]]],\n    source_normalization: Optional[Union[str, bool]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[str, Any],\n):\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # 
(n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    ipsdta = TIPSDTA(\n        n_basis,\n        n_blocks,\n        dof=dof,\n        spatial_algorithm=spatial_algorithm,\n        callbacks=callbacks,\n        source_normalization=source_normalization,\n        scale_restoration=scale_restoration,\n        rng=rng,\n    )\n\n    if spatial_algorithm != \"VCD\":\n        with pytest.raises(NotImplementedError) as e:\n            spectrogram_est = ipsdta(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert str(e.value) == \"Not support {}.\".format(spatial_algorithm)\n    else:\n        spectrogram_est = ipsdta(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert spectrogram_mix.shape == spectrogram_est.shape\n        assert type(ipsdta.loss[-1]) is float\n\n    print(ipsdta)\n"
  },
  {
    "path": "tests/package/bss/test_iterative_methods.py",
    "content": "import numpy as np\n\nfrom ssspy.bss.base import IterativeMethodBase\nfrom ssspy.bss.cacgmm import CACGMM\nfrom ssspy.bss.fdica import (\n    AuxFDICA,\n    AuxLaplaceFDICA,\n    GradFDICA,\n    GradLaplaceFDICA,\n    NaturalGradFDICA,\n    NaturalGradLaplaceFDICA,\n)\nfrom ssspy.bss.ica import FastICA, GradICA, GradLaplaceICA, NaturalGradICA, NaturalGradLaplaceICA\nfrom ssspy.bss.ilrma import GGDILRMA, TILRMA, GaussILRMA\nfrom ssspy.bss.ipsdta import TIPSDTA, GaussIPSDTA\nfrom ssspy.bss.iva import (\n    PDSIVA,\n    AuxGaussIVA,\n    AuxIVA,\n    AuxLaplaceIVA,\n    FasterIVA,\n    FastIVA,\n    GradGaussIVA,\n    GradIVA,\n    GradLaplaceIVA,\n    NaturalGradGaussIVA,\n    NaturalGradIVA,\n    NaturalGradLaplaceIVA,\n)\nfrom ssspy.bss.mnmf import FastGaussMNMF, GaussMNMF\nfrom ssspy.bss.pdsbss import PDSBSS\n\n\ndef test_grad_ica_inheritance() -> None:\n    def contrast_fn(x):\n        return np.log(1 + np.exp(x))\n\n    def score_fn(x):\n        return 1 / (1 + np.exp(-x))\n\n    ica = GradICA(contrast_fn=contrast_fn, score_fn=score_fn)\n\n    assert isinstance(ica, IterativeMethodBase)\n\n    ica = GradLaplaceICA()\n\n    assert isinstance(ica, IterativeMethodBase)\n\n\ndef test_natural_grad_ica_inheritance() -> None:\n    def contrast_fn(x):\n        return np.log(1 + np.exp(x))\n\n    def score_fn(x):\n        return 1 / (1 + np.exp(-x))\n\n    ica = NaturalGradICA(contrast_fn=contrast_fn, score_fn=score_fn)\n\n    assert isinstance(ica, IterativeMethodBase)\n\n    ica = NaturalGradLaplaceICA()\n\n    assert isinstance(ica, IterativeMethodBase)\n\n\ndef test_fast_ica_inheritance() -> None:\n    def contrast_fn(x):\n        return np.log(1 + np.exp(x))\n\n    def score_fn(x):\n        return 1 / (1 + np.exp(-x))\n\n    def d_score_fn(x):\n        s = 1 / (1 + np.exp(-x))\n        return s * (1 - s)\n\n    ica = FastICA(contrast_fn=contrast_fn, score_fn=score_fn, d_score_fn=d_score_fn)\n\n    assert isinstance(ica, IterativeMethodBase)\n\n\ndef 
test_grad_fdica_inheritance() -> None:\n    def contrast_fn(y):\n        return 2 * np.abs(y)\n\n    def score_fn(y):\n        denominator = np.maximum(np.abs(y), 1e-10)\n        return y / denominator\n\n    fdica = GradFDICA(contrast_fn=contrast_fn, score_fn=score_fn)\n\n    assert isinstance(fdica, IterativeMethodBase)\n\n    fdica = GradLaplaceFDICA()\n\n    assert isinstance(fdica, IterativeMethodBase)\n\n\ndef test_natural_grad_fdica_inheritance() -> None:\n    def contrast_fn(y):\n        return 2 * np.abs(y)\n\n    def score_fn(y):\n        denominator = np.maximum(np.abs(y), 1e-10)\n        return y / denominator\n\n    fdica = NaturalGradFDICA(contrast_fn=contrast_fn, score_fn=score_fn)\n\n    assert isinstance(fdica, IterativeMethodBase)\n\n    fdica = NaturalGradLaplaceFDICA()\n\n    assert isinstance(fdica, IterativeMethodBase)\n\n\ndef test_aux_fdica_inheritance() -> None:\n    def contrast_fn(y):\n        return 2 * np.abs(y)\n\n    def d_contrast_fn(y):\n        return 2 * np.ones_like(y)\n\n    fdica = AuxFDICA(\n        contrast_fn=contrast_fn,\n        d_contrast_fn=d_contrast_fn,\n    )\n\n    assert isinstance(fdica, IterativeMethodBase)\n\n    fdica = AuxLaplaceFDICA()\n\n    assert isinstance(fdica, IterativeMethodBase)\n\n\ndef test_grad_iva_inheritance() -> None:\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def score_fn(y) -> np.ndarray:\n        r\"\"\"Score function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        norm = np.linalg.norm(y, 
axis=1, keepdims=True)\n        norm = np.maximum(norm, 1e-10)\n        return y / norm\n\n    iva = GradIVA(contrast_fn=contrast_fn, score_fn=score_fn)\n\n    assert isinstance(iva, IterativeMethodBase)\n\n    iva = GradLaplaceIVA()\n\n    assert isinstance(iva, IterativeMethodBase)\n\n    iva = GradGaussIVA()\n\n    assert isinstance(iva, IterativeMethodBase)\n\n\ndef test_natural_grad_iva_inheritance() -> None:\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def score_fn(y):\n        r\"\"\"Score function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        norm = np.linalg.norm(y, axis=1, keepdims=True)\n        norm = np.maximum(norm, 1e-10)\n        return y / norm\n\n    iva = NaturalGradIVA(contrast_fn=contrast_fn, score_fn=score_fn)\n\n    assert isinstance(iva, IterativeMethodBase)\n\n    iva = NaturalGradLaplaceIVA()\n\n    assert isinstance(iva, IterativeMethodBase)\n\n    iva = NaturalGradGaussIVA()\n\n    assert isinstance(iva, IterativeMethodBase)\n\n\ndef test_fast_iva_inheritance() -> None:\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def d_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Derivative of contrast function.\n\n        Args:\n            y 
(np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.ones_like(y)\n\n    def dd_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Second order derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.zeros_like(y)\n\n    iva = FastIVA(\n        contrast_fn=contrast_fn,\n        d_contrast_fn=d_contrast_fn,\n        dd_contrast_fn=dd_contrast_fn,\n    )\n\n    assert isinstance(iva, IterativeMethodBase)\n\n\ndef test_faster_iva_inheritance() -> None:\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def d_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.ones_like(y)\n\n    iva = FasterIVA(contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn)\n\n    assert isinstance(iva, IterativeMethodBase)\n\n\ndef test_aux_iva_inheritance() -> None:\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        
return 2 * np.linalg.norm(y, axis=1)\n\n    def d_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.ones_like(y)\n\n    iva = AuxIVA(\n        contrast_fn=contrast_fn,\n        d_contrast_fn=d_contrast_fn,\n    )\n\n    assert isinstance(iva, IterativeMethodBase)\n\n    iva = AuxLaplaceIVA()\n\n    assert isinstance(iva, IterativeMethodBase)\n\n    iva = AuxGaussIVA()\n\n    assert isinstance(iva, IterativeMethodBase)\n\n\ndef test_pds_iva_inheritance() -> None:\n    iva = PDSIVA(\n        contrast_fn=None,\n        prox_penalty=None,\n    )\n\n    assert isinstance(iva, IterativeMethodBase)\n\n\ndef test_ilrma_inheritance() -> None:\n    n_basis = 2\n\n    ilrma = GaussILRMA(n_basis=n_basis)\n\n    assert isinstance(ilrma, IterativeMethodBase)\n\n    ilrma = TILRMA(n_basis=n_basis, dof=1000)\n\n    assert isinstance(ilrma, IterativeMethodBase)\n\n    ilrma = GGDILRMA(n_basis=n_basis, beta=1.95)\n\n    assert isinstance(ilrma, IterativeMethodBase)\n\n\ndef test_ipsdta_inheritance() -> None:\n    n_basis = 2\n    n_blocks = 2\n\n    ipsdta = GaussIPSDTA(n_basis=n_basis, n_blocks=n_blocks)\n\n    assert isinstance(ipsdta, IterativeMethodBase)\n\n    ipsdta = TIPSDTA(n_basis=n_basis, n_blocks=n_blocks, dof=1000)\n\n    assert isinstance(ipsdta, IterativeMethodBase)\n\n\ndef test_mnmf_inheritance() -> None:\n    n_basis = 2\n\n    mnmf = GaussMNMF(n_basis=n_basis)\n\n    assert isinstance(mnmf, IterativeMethodBase)\n\n    mnmf = FastGaussMNMF(n_basis=n_basis)\n\n    assert isinstance(mnmf, IterativeMethodBase)\n\n\ndef test_pdsbss_inheritance() -> None:\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape 
is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray of the shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def penalty_fn(y: np.ndarray) -> float:\n        loss = contrast_fn(y)\n        loss = np.sum(loss.mean(axis=-1))\n        return loss\n\n    def prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray:\n        r\"\"\"Proximal operator of penalty function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n            step_size (float):\n                Step size. Default: 1.\n\n        Returns:\n            np.ndarray of the shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        norm = np.linalg.norm(y, axis=1, keepdims=True)\n        return y * np.maximum(1 - step_size / norm, 0)\n\n    pdsbss = PDSBSS(penalty_fn=penalty_fn, prox_penalty=prox_penalty)\n\n    assert isinstance(pdsbss, IterativeMethodBase)\n\n\ndef test_cacgmm_inheritance() -> None:\n    cacgmm = CACGMM()\n\n    assert isinstance(cacgmm, IterativeMethodBase)\n"
  },
  {
    "path": "tests/package/bss/test_iva.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nimport scipy.signal as ss\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.iva import (\n    PDSIVA,\n    AuxGaussIVA,\n    AuxIVA,\n    AuxIVABase,\n    AuxLaplaceIVA,\n    FasterIVA,\n    FastIVA,\n    FastIVABase,\n    GradGaussIVA,\n    GradIVA,\n    GradIVABase,\n    GradLaplaceIVA,\n    IVABase,\n    NaturalGradGaussIVA,\n    NaturalGradIVA,\n    NaturalGradLaplaceIVA,\n)\n\nmax_duration = 0.5\nn_fft = 512\nhop_length = 256\nn_bins = n_fft // 2 + 1\nn_iter = 3\n\nparameters_spatial_algorithm = [\"IP\", \"IP1\", \"IP2\", \"ISS\", \"ISS1\", \"ISS2\", \"IPA\"]\nparameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]]\nparameters_is_holonomic = [True, False]\nparameters_scale_restoration = [True, False, \"projection_back\", \"minimal_distortion_principle\"]\nparameters_grad_iva = [\n    (2, {}),\n    (\n        3,\n        {\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n]\nparameters_fast_iva = [\n    (2, \"dev1_female3\", {}),\n    (\n        3,\n        \"dev1_female3\",\n        {\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n    (2, \"dev1_female3\", {\"demix_filter\": None}),\n]\nparameters_aux_iva = [\n    (2, \"dev1_female3\", {}),\n    (\n        3,\n        \"dev1_female3\",\n        {\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n    (2, \"dev1_female3\", {\"demix_filter\": None}),\n    (\n        3,\n        \"dev1_female3\",\n        {\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n    (4, \"dev1_female4\", {\"demix_filter\": None}),\n]\nparameters_pds_iva = [\n    (2, \"dev1_female3\", {}),\n    (\n        3,\n        \"dev1_female3\",\n        
{\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n    (4, \"dev1_female4\", {}),\n]\n\n\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\ndef test_iva_base(\n    callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]],\n):\n    iva = IVABase(callbacks=callbacks)\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\ndef test_fast_iva_base(\n    callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]],\n):\n    np.random.seed(111)\n\n    iva = FastIVABase(callbacks=callbacks)\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_iva_base(\n    callbacks: Optional[Union[Callable[[GradIVA], None], List[Callable[[GradIVA], None]]]],\n    is_holonomic: bool,\n):\n    np.random.seed(111)\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def score_fn(y) -> np.ndarray:\n        r\"\"\"Score function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        norm = np.linalg.norm(y, axis=1, keepdims=True)\n        norm = np.maximum(norm, 1e-10)\n        return y / norm\n\n    iva = GradIVABase(\n        contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic\n    )\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_iva)\n@pytest.mark.parametrize(\"callbacks\", 
parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_iva(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[GradIVA], None], List[Callable[[GradIVA], None]]]],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def score_fn(y) -> np.ndarray:\n        r\"\"\"Score function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        norm = np.linalg.norm(y, axis=1, keepdims=True)\n        norm = np.maximum(norm, 1e-10)\n        return y / norm\n\n    iva = GradIVA(\n        contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_iva)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", 
parameters_is_holonomic)\ndef test_natural_grad_iva(\n    n_sources: int,\n    callbacks: Optional[\n        Union[Callable[[NaturalGradIVA], None], List[Callable[[NaturalGradIVA], None]]]\n    ],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def score_fn(y):\n        r\"\"\"Score function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        norm = np.linalg.norm(y, axis=1, keepdims=True)\n        norm = np.maximum(norm, 1e-10)\n        return y / norm\n\n    iva = NaturalGradIVA(\n        contrast_fn=contrast_fn, score_fn=score_fn, callbacks=callbacks, is_holonomic=is_holonomic\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, sisec2010_tag, reset_kwargs\", parameters_fast_iva)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\ndef test_fast_iva(\n    n_sources: int,\n    sisec2010_tag: str,\n    
callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]],\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def d_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.ones_like(y)\n\n    def dd_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Second order derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.zeros_like(y)\n\n    iva = FastIVA(\n        contrast_fn=contrast_fn,\n        d_contrast_fn=d_contrast_fn,\n        dd_contrast_fn=dd_contrast_fn,\n        callbacks=callbacks,\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, 
sisec2010_tag, reset_kwargs\", parameters_fast_iva)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\ndef test_faster_iva(\n    n_sources: int,\n    sisec2010_tag: str,\n    callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]],\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def d_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.ones_like(y)\n\n    iva = FasterIVA(contrast_fn=contrast_fn, d_contrast_fn=d_contrast_fn, callbacks=callbacks)\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_aux_iva_base(\n    callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]],\n    scale_restoration: 
Union[str, bool],\n):\n    np.random.seed(111)\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def d_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.ones_like(y)\n\n    iva = AuxIVABase(\n        contrast_fn=contrast_fn,\n        d_contrast_fn=d_contrast_fn,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n    )\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, sisec2010_tag, reset_kwargs\", parameters_aux_iva)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_aux_iva(\n    n_sources: int,\n    sisec2010_tag: str,\n    spatial_algorithm: str,\n    callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[Any, Any],\n):\n    if spatial_algorithm in [\"IP\", \"ISS\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = 
ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def d_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.ones_like(y)\n\n    iva = AuxIVA(\n        spatial_algorithm=spatial_algorithm,\n        contrast_fn=contrast_fn,\n        d_contrast_fn=d_contrast_fn,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    if spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n        assert iva.demix_filter is None\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, sisec2010_tag, reset_kwargs\", parameters_pds_iva)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_pds_iva(\n    n_sources: int,\n    sisec2010_tag: str,\n    callbacks: Optional[Union[Callable[[AuxIVA], None], List[Callable[[AuxIVA], None]]]],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    
waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    iva = PDSIVA(\n        contrast_fn=None,\n        prox_penalty=None,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"specify_contrast_fn\", [True, False])\ndef test_iva_insufficient_fn(specify_contrast_fn: bool):\n    def _contrast_fn(y: np.ndarray) -> np.ndarray:\n        return np.linalg.norm(y, axis=1)\n\n    def _prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray:\n        norm = np.linalg.norm(y, axis=1, keepdims=True)\n        return y * np.maximum(1 - step_size / norm, 0)\n\n    if specify_contrast_fn:\n        contrast_fn = _contrast_fn\n        prox_penalty = None\n    else:\n        contrast_fn = None\n        prox_penalty = _prox_penalty\n\n    with pytest.raises(ValueError) as e:\n        _ = PDSIVA(\n            contrast_fn=contrast_fn,\n            prox_penalty=prox_penalty,\n        )\n\n    if specify_contrast_fn:\n        assert str(e.value) == \"Set prox_penalty.\"\n    else:\n        assert str(e.value) == \"Set contrast_fn.\"\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_iva)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_laplace_iva(\n    n_sources: int,\n    callbacks: Optional[\n        Union[Callable[[GradLaplaceIVA], None], List[Callable[[GradLaplaceIVA], None]]]\n    ],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        
n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    iva = GradLaplaceIVA(callbacks=callbacks, is_holonomic=is_holonomic)\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_iva)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_gauss_iva(\n    n_sources: int,\n    callbacks: Optional[\n        Union[Callable[[GradGaussIVA], None], List[Callable[[GradGaussIVA], None]]]\n    ],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    iva = GradGaussIVA(callbacks=callbacks, is_holonomic=is_holonomic)\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_iva)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_natural_grad_laplace_iva(\n    n_sources: int,\n    callbacks: 
Optional[\n        Union[\n            Callable[[NaturalGradLaplaceIVA], None], List[Callable[[NaturalGradLaplaceIVA], None]]\n        ]\n    ],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    iva = NaturalGradLaplaceIVA(callbacks=callbacks, is_holonomic=is_holonomic)\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, reset_kwargs\", parameters_grad_iva)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_natural_grad_gauss_iva(\n    n_sources: int,\n    callbacks: Optional[\n        Union[Callable[[NaturalGradGaussIVA], None], List[Callable[[NaturalGradGaussIVA], None]]]\n    ],\n    is_holonomic: bool,\n    reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    iva = NaturalGradGaussIVA(callbacks=callbacks, is_holonomic=is_holonomic)\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    
assert type(iva.loss[-1]) is float\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, sisec2010_tag, reset_kwargs\", parameters_aux_iva)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_aux_laplace_iva(\n    n_sources: int,\n    sisec2010_tag: str,\n    spatial_algorithm: str,\n    callbacks: Optional[\n        Union[\n            Callable[[NaturalGradLaplaceIVA], None], List[Callable[[NaturalGradLaplaceIVA], None]]\n        ]\n    ],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[Any, Any],\n):\n    if spatial_algorithm in [\"IP\", \"ISS\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    iva = AuxLaplaceIVA(\n        spatial_algorithm=spatial_algorithm,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    if spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n        assert iva.demix_filter is None\n\n    print(iva)\n\n\n@pytest.mark.parametrize(\"n_sources, sisec2010_tag, reset_kwargs\", parameters_aux_iva)\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"callbacks\", 
parameters_callbacks)\n@pytest.mark.parametrize(\"scale_restoration\", parameters_scale_restoration)\ndef test_aux_gauss_iva(\n    n_sources: int,\n    sisec2010_tag: str,\n    spatial_algorithm: str,\n    callbacks: Optional[\n        Union[Callable[[NaturalGradGaussIVA], None], List[Callable[[NaturalGradGaussIVA], None]]]\n    ],\n    scale_restoration: Union[str, bool],\n    reset_kwargs: Dict[Any, Any],\n):\n    if spatial_algorithm in [\"IP\", \"ISS\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    iva = AuxGaussIVA(\n        spatial_algorithm=spatial_algorithm,\n        callbacks=callbacks,\n        scale_restoration=scale_restoration,\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n    assert type(iva.loss[-1]) is float\n\n    if spatial_algorithm in [\"ISS\", \"ISS1\", \"ISS2\"]:\n        assert iva.demix_filter is None\n\n    print(iva)\n"
  },
  {
    "path": "tests/package/bss/test_mnmf.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nimport scipy.signal as ss\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.mnmf import FastGaussMNMF, FastMNMFBase, GaussMNMF, MNMFBase\n\nmax_duration = 0.1\nn_fft = 256\nhop_length = 128\nwindow = \"hann\"\nn_bins = n_fft // 2 + 1\nn_iter = 3\nrng = np.random.default_rng(42)\n\n\nparameters_diagonalizer_algorithm = [\"IP\", \"IP1\", \"IP2\"]\nparameters_partitioning = [True, False]\nparameters_callbacks = [None, dummy_function, [DummyCallback(), dummy_function]]\nparameters_normalization = [True, False]\nparameters_mnmf_base = [2]\nparameters_mnmf = [\n    (2, 2, 2, {}),\n    (3, 2, 3, {}),\n]\n\n\n@pytest.mark.parametrize(\n    \"n_basis\",\n    parameters_mnmf_base,\n)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\ndef test_mnmf_base(\n    n_basis: int,\n    callbacks: Optional[Union[Callable[[MNMFBase], None], List[Callable[[MNMFBase], None]]]],\n):\n    ipsdta = MNMFBase(\n        n_basis,\n        callbacks=callbacks,\n        record_loss=False,\n        rng=rng,\n    )\n\n    print(ipsdta)\n\n\n@pytest.mark.parametrize(\n    \"n_basis\",\n    parameters_mnmf_base,\n)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\ndef test_fast_mnmf_base(\n    n_basis: int,\n    callbacks: Optional[\n        Union[Callable[[FastMNMFBase], None], List[Callable[[FastMNMFBase], None]]]\n    ],\n):\n    ipsdta = FastMNMFBase(\n        n_basis,\n        callbacks=callbacks,\n        record_loss=False,\n        rng=rng,\n    )\n\n    print(ipsdta)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_channels, n_basis, reset_kwargs\",\n    parameters_mnmf,\n)\n@pytest.mark.parametrize(\n    \"partitioning\",\n    parameters_partitioning,\n)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"normalization\", 
parameters_normalization)\ndef test_gauss_mnmf(\n    n_sources: int,\n    n_channels: int,\n    n_basis: int,\n    partitioning: bool,\n    callbacks: Optional[Union[Callable[[GaussMNMF], None], List[Callable[[GaussMNMF], None]]]],\n    normalization: Optional[Union[str, bool]],\n    reset_kwargs: Dict[str, Any],\n):\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img[:n_channels], axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    mnmf = GaussMNMF(\n        n_basis,\n        n_sources=n_sources,\n        partitioning=partitioning,\n        callbacks=callbacks,\n        normalization=normalization,\n        rng=rng,\n    )\n\n    spectrogram_est = mnmf(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_est.shape == (n_sources,) + spectrogram_mix.shape[1:]\n    assert type(mnmf.loss[-1]) is float\n\n    print(mnmf)\n\n\n@pytest.mark.parametrize(\n    \"n_sources, n_channels, n_basis, reset_kwargs\",\n    parameters_mnmf,\n)\n@pytest.mark.parametrize(\"diagonalizer_algorithm\", parameters_diagonalizer_algorithm)\n@pytest.mark.parametrize(\n    \"partitioning\",\n    parameters_partitioning,\n)\n@pytest.mark.parametrize(\"callbacks\", parameters_callbacks)\n@pytest.mark.parametrize(\"normalization\", parameters_normalization)\ndef test_fast_gauss_mnmf(\n    n_sources: int,\n    n_channels: int,\n    n_basis: int,\n    diagonalizer_algorithm: str,\n    partitioning: bool,\n    callbacks: Optional[Union[Callable[[GaussMNMF], None], 
List[Callable[[GaussMNMF], None]]]],\n    normalization: Optional[Union[str, bool]],\n    reset_kwargs: Dict[str, Any],\n):\n    if diagonalizer_algorithm in [\"IP\"] and not pytest.run_redundant:\n        pytest.skip(reason=\"Need --run-redundant option to run.\")\n\n    if n_sources < 4:\n        sisec2010_tag = \"dev1_female3\"\n    elif n_sources == 4:\n        sisec2010_tag = \"dev1_female4\"\n    else:\n        raise ValueError(\"n_sources should be less than 5.\")\n\n    waveform_src_img, _ = download_sample_speech_data(\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img[:n_channels], axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    if partitioning:\n        with pytest.raises(AssertionError) as e:\n            mnmf = FastGaussMNMF(\n                n_basis,\n                n_sources=n_sources,\n                diagonalizer_algorithm=diagonalizer_algorithm,\n                partitioning=partitioning,\n                callbacks=callbacks,\n                normalization=normalization,\n                rng=rng,\n            )\n\n        assert str(e.value) == \"partitioning function is not supported.\"\n    else:\n        mnmf = FastGaussMNMF(\n            n_basis,\n            n_sources=n_sources,\n            diagonalizer_algorithm=diagonalizer_algorithm,\n            partitioning=partitioning,\n            callbacks=callbacks,\n            normalization=normalization,\n            rng=rng,\n        )\n\n        spectrogram_est = mnmf(spectrogram_mix, n_iter=n_iter, **reset_kwargs)\n\n        assert spectrogram_est.shape == (n_sources,) + spectrogram_mix.shape[1:]\n        assert type(mnmf.loss[-1]) is float\n\n        print(mnmf)\n"
  },
  {
    "path": "tests/package/bss/test_pair_selector.py",
    "content": "import pytest\n\nfrom ssspy.bss._select_pair import combination_pair_selector, sequential_pair_selector\n\nparameters_n_sources = [4]\nparameters_step = [1, 2]\nparameters_ascend = [True, False]\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"step\", parameters_step)\n@pytest.mark.parametrize(\"ascend\", parameters_ascend)\ndef test_sequential_pair_selector(n_sources: int, step: int, ascend: bool):\n    with pytest.warns(UserWarning) as record:\n        for m, n in sequential_pair_selector(n_sources, step=step, sort=ascend):\n            if ascend:\n                assert m < n\n\n    assert len(record) == 1\n    assert str(record[0].message) == \"Use ssspy.utils.select_pair.sequential_pair_selector instead.\"\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"ascend\", parameters_ascend)\ndef test_combination_pair_selector(n_sources: int, ascend: bool):\n    with pytest.warns(UserWarning) as record:\n        for m, n in combination_pair_selector(n_sources, sort=ascend):\n            if ascend:\n                assert m < n\n\n    assert len(record) == 1\n    assert (\n        str(record[0].message) == \"Use ssspy.utils.select_pair.combination_pair_selector instead.\"\n    )\n"
  },
  {
    "path": "tests/package/bss/test_pdsbss.py",
    "content": "import functools\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pytest\nimport scipy.signal as ss\nfrom dummy.callback import DummyCallback, dummy_function\nfrom dummy.utils.dataset import download_sample_speech_data\n\nfrom ssspy.bss.pdsbss import PDSBSS, MaskingPDSBSS, PDSBSSBase\n\nmax_duration = 0.5\nn_fft = 2048\nhop_length = 1024\nn_bins = n_fft // 2 + 1\nn_iter = 5\n\nparameters_pdsbss = [\n    (2, None, {}),\n    (\n        3,\n        dummy_function,\n        {\"demix_filter\": np.tile(-np.eye(3, dtype=np.complex128), reps=(n_bins, 1, 1))},\n    ),\n    (2, [DummyCallback(), dummy_function], {}),\n]\nparameters_set_panalty_fn = [True, False]\n\n\ndef contrast_fn(y: np.ndarray) -> np.ndarray:\n    r\"\"\"Contrast function.\n\n    Args:\n        y (np.ndarray):\n            The shape is (n_sources, n_bins, n_frames).\n\n    Returns:\n        np.ndarray of the shape is (n_sources, n_frames).\n    \"\"\"\n    return 2 * np.linalg.norm(y, axis=1)\n\n\ndef penalty_fn(y: np.ndarray) -> float:\n    loss = contrast_fn(y)\n    loss = np.sum(loss.mean(axis=-1))\n    return loss\n\n\ndef prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray:\n    r\"\"\"Proximal operator of penalty function.\n\n    Args:\n        y (np.ndarray):\n            The shape is (n_sources, n_bins, n_frames).\n        step_size (float):\n            Step size. Default: 1.\n\n    Returns:\n        np.ndarray of the shape is (n_sources, n_bins, n_frames).\n    \"\"\"\n    norm = np.linalg.norm(y, axis=1, keepdims=True)\n    return y * np.maximum(1 - step_size / norm, 0)\n\n\ndef mask_fn(y: np.ndarray, step_size: float = 1) -> np.ndarray:\n    r\"\"\"Masking function.\n\n    Args:\n        y (np.ndarray):\n            The shape is (n_sources, n_bins, n_frames).\n        step_size (float):\n            Step size. 
Default: 1.\n\n    Returns:\n        np.ndarray of the shape is (n_sources, n_bins, n_frames).\n    \"\"\"\n    norm = np.linalg.norm(y, axis=1, keepdims=True)\n    mask = np.maximum(1 - step_size / norm, 0)\n    mask = np.tile(mask, (1, y.shape[1], 1))\n\n    return mask\n\n\ndef test_pds_base():\n    pdsbss = PDSBSSBase(penalty_fn=penalty_fn, prox_penalty=prox_penalty)\n\n    print(pdsbss)\n\n\n@pytest.mark.parametrize(\"n_sources, callbacks, reset_kwargs\", parameters_pdsbss)\n@pytest.mark.parametrize(\"set_panalty_fn\", parameters_set_panalty_fn)\ndef test_pdsbss(\n    n_sources: int,\n    callbacks: Optional[Union[Callable[[PDSBSS], None], List[Callable[[PDSBSS], None]]]],\n    reset_kwargs: Dict[Any, Any],\n    set_panalty_fn: bool,\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        sisec2010_root=\"./tests/.data/SiSEC2010\",\n        mird_root=\"./tests/.data/MIRD\",\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    if set_panalty_fn:\n        pdsbss = PDSBSS(penalty_fn=penalty_fn, prox_penalty=prox_penalty, callbacks=callbacks)\n    else:\n        pdsbss = PDSBSS(prox_penalty=prox_penalty, callbacks=callbacks)\n\n    spectrogram_mix_normalized = pdsbss.normalize_by_spectral_norm(spectrogram_mix)\n    spectrogram_est = pdsbss(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n\n    print(pdsbss)\n\n\n@pytest.mark.parametrize(\"n_sources, callbacks, reset_kwargs\", parameters_pdsbss)\ndef test_masking_pdsbss(\n    n_sources: int,\n    callbacks: Optional[\n        Union[Callable[[MaskingPDSBSS], None], List[Callable[[MaskingPDSBSS], None]]]\n    ],\n  
  reset_kwargs: Dict[Any, Any],\n):\n    np.random.seed(111)\n\n    waveform_src_img, _ = download_sample_speech_data(\n        sisec2010_root=\"./tests/.data/SiSEC2010\",\n        mird_root=\"./tests/.data/MIRD\",\n        n_sources=n_sources,\n        sisec2010_tag=\"dev1_female3\",\n        max_duration=max_duration,\n        conv=True,\n    )\n    waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)\n\n    _, _, spectrogram_mix = ss.stft(\n        waveform_mix, window=\"hann\", nperseg=n_fft, noverlap=n_fft - hop_length\n    )\n\n    pdsbss = MaskingPDSBSS(mask_fn=functools.partial(mask_fn, step_size=1), callbacks=callbacks)\n\n    spectrogram_mix_normalized = pdsbss.normalize_by_spectral_norm(spectrogram_mix)\n    spectrogram_est = pdsbss(spectrogram_mix_normalized, n_iter=n_iter, **reset_kwargs)\n\n    assert spectrogram_mix.shape == spectrogram_est.shape\n\n    print(pdsbss)\n"
  },
  {
    "path": "tests/package/bss/test_proxbss.py",
    "content": "import numpy as np\n\nfrom ssspy.bss.proxbss import ProxBSSBase\n\n\ndef contrast_fn(y: np.ndarray) -> np.ndarray:\n    r\"\"\"Contrast function.\n\n    Args:\n        y (np.ndarray):\n            The shape is (n_sources, n_bins, n_frames).\n\n    Returns:\n        np.ndarray:\n            The shape is (n_sources, n_frames).\n    \"\"\"\n    return 2 * np.linalg.norm(y, axis=1)\n\n\ndef penalty_fn(y: np.ndarray) -> float:\n    loss = contrast_fn(y)\n    loss = np.sum(loss.mean(axis=-1))\n    return loss\n\n\ndef prox_penalty(y: np.ndarray, step_size: float = 1) -> np.ndarray:\n    r\"\"\"Proximal operator of the penalty function.\n\n    Args:\n        y (np.ndarray):\n            The shape is (n_sources, n_bins, n_frames).\n        step_size (float):\n            Step size. Default: 1.\n\n    Returns:\n        np.ndarray:\n            The shape is (n_sources, n_bins, n_frames).\n    \"\"\"\n    norm = np.linalg.norm(y, axis=1, keepdims=True)\n    return y * np.maximum(1 - step_size / norm, 0)\n\n\ndef test_proxbss_base() -> None:\n    proxbss = ProxBSSBase(penalty_fn=penalty_fn, prox_penalty=prox_penalty)\n\n    print(proxbss)\n"
  },
  {
    "path": "tests/package/bss/test_psd_legacy.py",
    "content": "from typing import Tuple\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.bss._psd import to_psd\nfrom ssspy.special import add_flooring\n\nrng = np.random.default_rng(42)\n\nparameters_shape = [(5, 2, 2), (3, 3)]\nparameters_kwargs = [{}, {\"flooring_fn\": None}, {\"flooring_fn\": add_flooring}]\n\n\n@pytest.mark.parametrize(\"shape\", parameters_shape)\n@pytest.mark.parametrize(\"kwargs\", parameters_kwargs)\ndef test_to_psd_real(shape: Tuple[int], kwargs):\n    X = rng.standard_normal(shape)\n    X = X @ X.swapaxes(-1, -2)\n    X = to_psd(X, **kwargs)\n    eigvals = np.linalg.eigvalsh(X)\n\n    assert np.all(X == X.swapaxes(-1, -2))\n    assert np.min(eigvals) > 0\n\n\n@pytest.mark.parametrize(\"shape\", parameters_shape)\n@pytest.mark.parametrize(\"kwargs\", parameters_kwargs)\ndef test_to_psd_complex(shape: Tuple[int], kwargs):\n    X = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n    X = X @ X.swapaxes(-1, -2).conj()\n    X = to_psd(X, **kwargs)\n    eigvals = np.linalg.eigvalsh(X)\n\n    assert np.all(X == X.swapaxes(-1, -2).conj())\n    assert np.min(eigvals) > 0\n"
  },
  {
    "path": "tests/package/bss/test_solve_permutation.py",
    "content": "import numpy as np\nimport pytest\n\nfrom ssspy.bss._solve_permutation import correlation_based_permutation_solver\n\nrng = np.random.default_rng(0)\n\nparameters_give_demixing_filter = [True, False]\n\n\n@pytest.mark.parametrize(\"give_demixing_filter\", parameters_give_demixing_filter)\ndef test_correlation_based_permutation_solver(give_demixing_filter: bool):\n    n_sources = 3\n    n_channels = n_sources\n    n_bins, n_frames = 4, 16\n\n    shape = (n_channels, n_bins, n_frames)\n    mixture = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n    shape = (n_bins, n_sources, n_channels)\n    demix_filter = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n    separated = demix_filter @ mixture.transpose(1, 0, 2)\n\n    with pytest.warns(UserWarning) as record:\n        if give_demixing_filter:\n            separated, demix_filter = correlation_based_permutation_solver(separated, demix_filter)\n            assert demix_filter.shape == (n_bins, n_sources, n_channels)\n        else:\n            separated = correlation_based_permutation_solver(separated)\n\n            assert separated.shape == (n_bins, n_sources, n_frames)\n\n    assert len(record) == 1\n    assert (\n        str(record[0].message)\n        == \"Use ssspy.algorithm.permutation_alignment.correlation_based_permutation_solver instead.\"\n    )\n"
  },
  {
    "path": "tests/package/bss/test_update_spatial_model.py",
    "content": "from typing import Callable, Iterable, Optional, Tuple\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.bss._update_spatial_model import (\n    _psd_inv,\n    update_by_block_decomposition_vcd,\n    update_by_ip1,\n    update_by_ip2,\n    update_by_ip2_one_pair,\n    update_by_iss1,\n    update_by_iss2,\n)\nfrom ssspy.special import add_flooring, max_flooring\nfrom ssspy.utils.select_pair import combination_pair_selector, sequential_pair_selector\n\n\ndef negative_pair_selector(n_sources):\n    for m in range(n_sources):\n        m, n = m % n_sources, (m + 1) % n_sources\n        m, n = m - n_sources, n - n_sources\n\n        yield m, n\n\n\nparameters = [(31, 20)]\nparameters_block_decomposition_vcd = [(15, 2, 20)]\nparameters_n_sources = [2, 3]\nparameters_flooring_fn = [max_flooring, add_flooring, None]\nparameters_overwrite = [True, False]\nparameters_singular_fn = [\n    lambda x: np.abs(x) < max_flooring(x),\n    lambda x: np.abs(x) < add_flooring(x),\n    None,\n]\nparameters_pair_selector = [\n    sequential_pair_selector,\n    combination_pair_selector,\n    negative_pair_selector,\n    None,\n]\n\n\n@pytest.mark.parametrize(\"n_bins, n_frames\", parameters)\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"flooring_fn\", parameters_flooring_fn)\ndef test_update_by_ip1(\n    n_bins: int,\n    n_frames: int,\n    n_sources: int,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]],\n):\n    n_channels = n_sources\n\n    rng = np.random.default_rng(42)\n\n    varphi = 1 / rng.random((n_sources, n_frames))\n    X = rng.standard_normal((n_channels, n_bins, n_frames))\n    real = rng.standard_normal((n_bins, n_sources, n_sources))\n    imag = rng.standard_normal((n_bins, n_sources, n_sources))\n    W = real + 1j * imag\n\n    XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n    XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n    GXX = varphi[:, np.newaxis, np.newaxis, :] * 
XX_Hermite[:, np.newaxis, :, :, :]\n    U = np.mean(GXX, axis=-1)\n    W_updated = update_by_ip1(W, U, flooring_fn=flooring_fn)\n\n    assert W_updated.shape == W.shape\n\n\n@pytest.mark.parametrize(\"n_bins, n_frames\", parameters)\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"flooring_fn\", parameters_flooring_fn)\n@pytest.mark.parametrize(\"pair_selector\", parameters_pair_selector)\ndef test_update_by_ip2(\n    n_bins: int,\n    n_frames: int,\n    n_sources: int,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]],\n    pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]],\n):\n    n_channels = n_sources\n\n    rng = np.random.default_rng(42)\n\n    varphi = 1 / rng.random((n_sources, n_frames))\n    X = rng.standard_normal((n_channels, n_bins, n_frames))\n    real = rng.standard_normal((n_bins, n_sources, n_sources))\n    imag = rng.standard_normal((n_bins, n_sources, n_sources))\n    W = real + 1j * imag\n\n    XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n    XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n    GXX = varphi[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n    U = np.mean(GXX, axis=-1)\n    W_updated = update_by_ip2(W, U, flooring_fn=flooring_fn, pair_selector=pair_selector)\n\n    assert W_updated.shape == W.shape\n\n\n@pytest.mark.parametrize(\"n_bins, n_frames\", parameters)\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"flooring_fn\", parameters_flooring_fn)\ndef test_update_by_ip2_one_pair(\n    n_bins: int,\n    n_frames: int,\n    n_sources: int,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]],\n):\n    n_channels = n_sources\n\n    rng = np.random.default_rng(42)\n\n    varphi = 1 / rng.random((2, n_bins, n_frames))\n    X = rng.standard_normal((n_channels, n_bins, n_frames))\n    real = rng.standard_normal((n_bins, n_sources, n_channels))\n    imag = 
rng.standard_normal((n_bins, n_sources, n_channels))\n    W = real + 1j * imag\n    XX = X[:, np.newaxis] * X[np.newaxis, :].conj()\n    GXX = np.mean(varphi[:, np.newaxis, np.newaxis, :, :] * XX[np.newaxis, :, :, :, :], axis=-1)\n    GXX = GXX.transpose(3, 0, 1, 2)\n\n    W_updated = update_by_ip2_one_pair(W, GXX, pair=(1, 0), flooring_fn=flooring_fn)\n\n    assert W_updated.shape == (n_bins, 2, n_channels)\n\n\n@pytest.mark.parametrize(\"n_bins, n_frames\", parameters)\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"flooring_fn\", parameters_flooring_fn)\ndef test_update_by_iss1(\n    n_bins: int,\n    n_frames: int,\n    n_sources: int,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]],\n):\n    rng = np.random.default_rng(42)\n\n    varphi = 1 / rng.random((n_sources, n_bins, n_frames))\n    real = rng.standard_normal((n_sources, n_bins, n_frames))\n    imag = rng.standard_normal((n_sources, n_bins, n_frames))\n    Y = real + 1j * imag\n\n    Y_updated = update_by_iss1(Y, varphi, flooring_fn=flooring_fn)\n\n    assert Y_updated.shape == Y.shape\n\n\n@pytest.mark.parametrize(\"n_bins, n_frames\", parameters)\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"flooring_fn\", parameters_flooring_fn)\n@pytest.mark.parametrize(\"pair_selector\", parameters_pair_selector)\ndef test_update_by_iss2(\n    n_bins: int,\n    n_frames: int,\n    n_sources: int,\n    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]],\n    pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]],\n):\n    rng = np.random.default_rng(42)\n\n    varphi = 1 / rng.random((n_sources, n_bins, n_frames))\n    real = rng.standard_normal((n_sources, n_bins, n_frames))\n    imag = rng.standard_normal((n_sources, n_bins, n_frames))\n    Y = real + 1j * imag\n\n    Y_updated = update_by_iss2(Y, varphi, flooring_fn=flooring_fn, pair_selector=pair_selector)\n\n    assert Y_updated.shape == 
Y.shape\n\n\n@pytest.mark.parametrize(\"n_blocks, n_neighbors, n_frames\", parameters_block_decomposition_vcd)\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"singular_fn\", parameters_singular_fn)\n@pytest.mark.parametrize(\"overwrite\", parameters_overwrite)\ndef test_update_by_block_decomposition_vcd(\n    n_blocks: int,\n    n_neighbors: int,\n    n_frames: int,\n    n_sources: int,\n    singular_fn: Optional[Callable[[np.ndarray], np.ndarray]],\n    overwrite: bool,\n):\n    na = np.newaxis\n    n_channels = n_sources\n\n    rng = np.random.default_rng(42)\n\n    R = rng.random((n_blocks, n_neighbors, n_sources, n_channels, n_frames))\n    X = rng.standard_normal((n_channels, n_blocks, n_neighbors, n_frames))\n    real = rng.standard_normal((n_blocks, n_neighbors, n_sources, n_channels))\n    imag = rng.standard_normal((n_blocks, n_neighbors, n_sources, n_channels))\n    W = real + 1j * imag\n\n    R = R[:, :, na, :, :, :] * np.eye(n_neighbors)[:, :, na, na, na]\n    R = R[:, :, :, :, :, na, :] * np.eye(n_channels)[:, :, na]\n\n    XX_Hermite = X[:, na, :, :, na] * X[na, :, :, na, :].conj()\n    XX_Hermite = XX_Hermite.transpose(2, 3, 4, 0, 1, 5)\n\n    RXX = np.mean(R * XX_Hermite[:, :, :, na], axis=-1)\n\n    W_updated = update_by_block_decomposition_vcd(\n        W, weighted_covariance=RXX, singular_fn=singular_fn, overwrite=overwrite\n    )\n\n    assert W_updated.shape == W.shape\n\n\ndef test_psd_inv() -> None:\n    rng = np.random.default_rng(42)\n\n    n_bins, n_frames = 129, 100\n    n_sources = n_channels = 4\n\n    varphi = 1 / rng.random((n_sources, n_frames))\n    X = rng.standard_normal((n_channels, n_bins, n_frames))\n\n    XX_Hermite = X[:, np.newaxis, :, :] * X[np.newaxis, :, :, :].conj()\n    XX_Hermite = XX_Hermite.transpose(2, 0, 1, 3)\n    GXX = varphi[:, np.newaxis, np.newaxis, :] * XX_Hermite[:, np.newaxis, :, :, :]\n    U = np.mean(GXX, axis=-1)\n\n    U_inv = _psd_inv(U)\n    eye = 
np.eye(n_sources)\n\n    assert np.allclose(U @ U_inv, eye)\n"
  },
  {
    "path": "tests/package/io/test_wavread.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\nimport pytest\nfrom dummy.io import save_invalid_wavfile\nfrom dummy.utils.dataset import download_ssspy_data\nfrom scipy.io import wavfile\n\nfrom ssspy import wavread, wavwrite\n\nparameters_frame_offset = [0, 10]\nparameters_num_frames = [None, 100]\nparameters_channels_first = [True, False, None]\nparameters_float = [True, False]\nparameters_channels = [1, 2]\n\n\n@pytest.mark.parametrize(\"frame_offset\", parameters_frame_offset)\n@pytest.mark.parametrize(\"num_frames\", parameters_num_frames)\n@pytest.mark.parametrize(\"channels_first\", parameters_channels_first)\ndef test_wavread_monoral(frame_offset: int, num_frames: int, channels_first: bool):\n    path = \"audio/monoral_16k_5sec.wav\"\n    filename = \"./tests/mock/{}\".format(path)\n\n    download_ssspy_data(path, filename=filename)\n\n    if channels_first is not None:\n        return_2d = True\n    else:\n        return_2d = False\n\n    # load file using scipy\n    sample_rate_scipy, waveform_scipy = wavfile.read(filename)\n    waveform_scipy = waveform_scipy / 2**15\n\n    # load file using ssspy\n    waveform_ssspy, sample_rate_ssspy = wavread(\n        filename,\n        frame_offset=frame_offset,\n        num_frames=num_frames,\n        return_2d=return_2d,\n        channels_first=channels_first,\n    )\n\n    assert sample_rate_scipy == sample_rate_ssspy\n\n    if return_2d:\n        if channels_first:\n            waveform_ssspy = waveform_ssspy.squeeze(axis=0)\n        else:\n            waveform_ssspy = waveform_ssspy.squeeze(axis=1)\n\n    if num_frames is not None:\n        assert np.all(waveform_scipy[frame_offset : frame_offset + num_frames] == waveform_ssspy)\n    else:\n        assert np.all(waveform_scipy[frame_offset:] == waveform_ssspy)\n\n\n@pytest.mark.parametrize(\"frame_offset\", parameters_frame_offset)\n@pytest.mark.parametrize(\"num_frames\", parameters_num_frames)\n@pytest.mark.parametrize(\"channels_first\", 
parameters_channels_first)\ndef test_wavread_stereo(frame_offset: int, num_frames: int, channels_first: bool):\n    path = \"audio/stereo_16k_5sec.wav\"\n    filename = \"./tests/mock/{}\".format(path)\n\n    download_ssspy_data(path, filename=filename)\n\n    # load file using scipy\n    sample_rate_scipy, waveform_scipy = wavfile.read(filename)\n    waveform_scipy = waveform_scipy / 2**15\n\n    # load file using ssspy\n    waveform_ssspy, sample_rate_ssspy = wavread(\n        filename, frame_offset=frame_offset, num_frames=num_frames, channels_first=channels_first\n    )\n\n    assert sample_rate_scipy == sample_rate_ssspy\n\n    if channels_first:\n        # same order as that of scipy\n        waveform_ssspy = waveform_ssspy.transpose(1, 0)\n\n    if num_frames is not None:\n        assert np.all(waveform_scipy[frame_offset : frame_offset + num_frames] == waveform_ssspy)\n    else:\n        assert np.all(waveform_scipy[frame_offset:] == waveform_ssspy)\n\n\n@pytest.mark.parametrize(\"frame_offset\", parameters_frame_offset)\ndef test_wavread_invalid_monoral(frame_offset: int):\n    path = \"audio/monoral_16k_5sec.wav\"\n    filename = \"./tests/mock/{}\".format(path)\n\n    download_ssspy_data(path, filename=filename)\n\n    max_frame = 5 * 16000\n    valid_num_frames = max_frame - frame_offset\n\n    # valid data\n    wavread(filename, frame_offset=frame_offset, num_frames=valid_num_frames)\n\n    # invalid memory size\n    invalid_num_frames = valid_num_frames + 1\n\n    with pytest.raises(ValueError) as e:\n        wavread(filename, frame_offset=frame_offset, num_frames=invalid_num_frames)\n\n    assert str(e.value) == f\"num_frames={invalid_num_frames} exceeds maximum frame {max_frame}.\"\n\n\n@pytest.mark.parametrize(\"frame_offset\", parameters_frame_offset)\ndef test_wavread_invalid_stereo(frame_offset: int):\n    path = \"audio/stereo_16k_5sec.wav\"\n    filename = \"./tests/mock/{}\".format(path)\n\n    download_ssspy_data(path, filename=filename)\n\n 
   max_frame = 5 * 16000\n    valid_num_frames = max_frame - frame_offset\n\n    # valid data\n    wavread(filename, frame_offset=frame_offset, num_frames=valid_num_frames)\n\n    # invalid memory size\n    invalid_num_frames = valid_num_frames + 1\n\n    with pytest.raises(ValueError) as e:\n        wavread(filename, frame_offset=frame_offset, num_frames=invalid_num_frames)\n\n    assert str(e.value) == f\"num_frames={invalid_num_frames} exceeds maximum frame {max_frame}.\"\n\n\n@pytest.mark.parametrize(\"is_float\", parameters_float)\ndef test_wavio_1d(is_float: np.dtype):\n    rng = np.random.default_rng(0)\n\n    filename = \"valid.wav\"\n    sample_rate = 16000\n    duration = 5\n    bits_per_sample = 16\n    bytes_per_sample = bits_per_sample // 8\n    num_frames = sample_rate * duration\n    vmax = 2 ** (bits_per_sample - 1)\n\n    waveform = rng.integers(-vmax, vmax, size=(num_frames,), dtype=f\"<i{bytes_per_sample}\")\n\n    if is_float:\n        waveform = waveform / vmax\n        waveform_in = waveform.copy()\n    else:\n        waveform_in = waveform / vmax\n\n    with tempfile.TemporaryDirectory() as temp_dir:\n        path = os.path.join(temp_dir, filename)\n        wavwrite(path, waveform, sample_rate=sample_rate)\n        waveform_out, _ = wavread(path)\n\n    assert np.all(waveform_in == waveform_out)\n\n\n@pytest.mark.parametrize(\"is_float\", parameters_float)\n@pytest.mark.parametrize(\"n_channels\", parameters_channels)\n@pytest.mark.parametrize(\"channels_first\", parameters_channels_first)\ndef test_wavio_2d(is_float: np.dtype, n_channels: int, channels_first: bool):\n    rng = np.random.default_rng(0)\n\n    filename = \"valid.wav\"\n    sample_rate = 16000\n    duration = 5\n    bits_per_sample = 16\n    bytes_per_sample = bits_per_sample // 8\n    num_frames = sample_rate * duration\n    vmax = 2 ** (bits_per_sample - 1)\n\n    if channels_first:\n        shape = (n_channels, num_frames)\n    else:\n        shape = (num_frames, 
n_channels)\n\n    waveform = rng.integers(\n        -vmax,\n        vmax,\n        size=shape,\n        dtype=f\"<i{bytes_per_sample}\",\n    )\n\n    if is_float:\n        waveform = waveform / vmax\n        waveform_in = waveform.copy()\n    else:\n        waveform_in = waveform / vmax\n\n    with tempfile.TemporaryDirectory() as temp_dir:\n        path = os.path.join(temp_dir, filename)\n        wavwrite(path, waveform, sample_rate=sample_rate, channels_first=channels_first)\n        waveform_out, _ = wavread(path, return_2d=True, channels_first=channels_first)\n\n    assert np.all(waveform_in == waveform_out)\n\n\ndef test_waveread_invalid_metadata():\n    filename = \"./tests/mock/audio/monoral_16k_5sec_invalid.wav\"\n\n    # invalid riff\n    save_invalid_wavfile(filename, invalid_riff=True)\n\n    with pytest.raises(NotImplementedError) as e:\n        wavread(filename)\n\n    assert str(e.value) == f\"Not support {b'RIFX'}.\"\n\n    # invalid ftype\n    save_invalid_wavfile(filename, invalid_ftype=True)\n\n    with pytest.raises(NotImplementedError) as e:\n        wavread(filename)\n\n    assert str(e.value) == f\"Not support {b'wave'}.\"\n\n    # invalid fmt chunk marker\n    save_invalid_wavfile(filename, invalid_fmt_chunk_marker=True)\n\n    with pytest.raises(NotImplementedError) as e:\n        wavread(filename)\n\n    assert str(e.value) == f\"Not support {b'FMT '}.\"\n\n    # invalid fmt chunk size\n    save_invalid_wavfile(filename, invalid_fmt_chunk_size=True)\n\n    with pytest.raises(NotImplementedError) as e:\n        wavread(filename)\n\n    assert str(e.value) == \"Invalid header is detected.\"\n\n    # invalid fmt\n    save_invalid_wavfile(filename, invalid_fmt=True)\n\n    with pytest.raises(NotImplementedError) as e:\n        wavread(filename)\n\n    assert str(e.value) == \"Invalid header 0 is detected.\"\n\n    # invalid fmt byte rate\n    save_invalid_wavfile(filename, invalid_byte_rate=True)\n\n    with pytest.raises(ValueError) as e:\n  
      wavread(filename)\n\n    assert str(e.value) == \"Invalid header is detected.\"\n\n    # invalid data chunk marker\n    save_invalid_wavfile(filename, invalid_data_chunk_marker=True)\n\n    with pytest.raises(NotImplementedError) as e:\n        wavread(filename)\n\n    assert str(e.value) == f\"Not support {b'DATA'}.\"\n\n    os.remove(filename)\n"
  },
  {
    "path": "tests/package/linalg/test_cubic.py",
    "content": "import numpy as np\n\nfrom ssspy.linalg import cbrt\n\n\ndef test_cbrt():\n    rng = np.random.default_rng(0)\n    n_bins = 8\n    n_channels = 4\n\n    # real number\n    x = rng.standard_normal((n_bins, n_channels))\n    y = cbrt(x)\n\n    assert np.allclose(y**3, x)\n\n    # complex number\n    x = rng.standard_normal((n_bins, n_channels)) + 1j * rng.standard_normal((n_bins, n_channels))\n    y = cbrt(x)\n\n    assert np.allclose(y**3, x)\n"
  },
  {
    "path": "tests/package/linalg/test_eigh.py",
    "content": "import numpy as np\nimport pytest\n\nfrom ssspy.linalg import eigh, eigh2\n\nparameters_sources = [2, 5]\nparameters_channels = [4, 3]\nparameters_frames = [32, 16]\nparameters_is_complex = [True, False]\nparameters_type = [1, 2, 3]\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_sources)\n@pytest.mark.parametrize(\"n_channels\", parameters_channels)\n@pytest.mark.parametrize(\"n_frames\", parameters_frames)\n@pytest.mark.parametrize(\"is_complex\", parameters_is_complex)\ndef test_eigh(n_sources: int, n_channels: int, n_frames: int, is_complex: bool):\n    np.random.seed(111)\n\n    shape = (n_sources, n_channels, n_frames)\n\n    if is_complex:\n        a = np.random.randn(*shape) + 1j * np.random.randn(*shape)\n        A = np.mean(a[:, :, np.newaxis, :] * a[:, np.newaxis, :, :].conj(), axis=-1)\n    else:\n        a = np.random.randn(*shape)\n        A = np.mean(a[:, :, np.newaxis, :] * a[:, np.newaxis, :, :], axis=-1)\n\n    lamb, z = eigh(A)\n\n    assert lamb.shape == (n_sources, n_channels)\n    assert z.shape == (n_sources, n_channels, n_channels)\n\n    assert np.allclose(A @ z, lamb[:, np.newaxis, :] * z)\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_sources)\n@pytest.mark.parametrize(\"n_channels\", parameters_channels)\n@pytest.mark.parametrize(\"n_frames\", parameters_frames)\n@pytest.mark.parametrize(\"is_complex\", parameters_is_complex)\n@pytest.mark.parametrize(\"type\", parameters_type)\ndef test_generalized_eigh(\n    n_sources: int, n_channels: int, n_frames: int, is_complex: bool, type: int\n):\n    np.random.seed(111)\n\n    shape = (n_sources, n_channels, n_frames)\n\n    if is_complex:\n        a = np.random.randn(*shape) + 1j * np.random.randn(*shape)\n        b = np.random.randn(*shape) + 1j * np.random.randn(*shape)\n        A = np.mean(a[:, :, np.newaxis, :] * a[:, np.newaxis, :, :].conj(), axis=-1)\n        B = np.mean(b[:, :, np.newaxis, :] * b[:, np.newaxis, :, :].conj(), axis=-1)\n    else:\n        
a = np.random.randn(*shape)\n        b = np.random.randn(*shape)\n        A = np.mean(a[:, :, np.newaxis, :] * a[:, np.newaxis, :, :], axis=-1)\n        B = np.mean(b[:, :, np.newaxis, :] * b[:, np.newaxis, :, :], axis=-1)\n\n    lamb, z = eigh(A, B, type=type)\n\n    assert lamb.shape == (n_sources, n_channels)\n    assert z.shape == (n_sources, n_channels, n_channels)\n\n    if type == 1:\n        assert np.allclose(A @ z, lamb[:, np.newaxis, :] * (B @ z))\n    elif type == 2:\n        assert np.allclose(A @ B @ z, lamb[:, np.newaxis, :] * z)\n    elif type == 3:\n        assert np.allclose(B @ A @ z, lamb[:, np.newaxis, :] * z)\n    else:\n        raise ValueError(\"Invalid type={} is given.\".format(type))\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_sources)\n@pytest.mark.parametrize(\"n_frames\", parameters_frames)\n@pytest.mark.parametrize(\"is_complex\", parameters_is_complex)\ndef test_eigh2(n_sources: int, n_frames: int, is_complex: bool):\n    np.random.seed(111)\n\n    shape = (n_sources, 2, n_frames)\n\n    if is_complex:\n        a = np.random.randn(*shape) + 1j * np.random.randn(*shape)\n        A = np.mean(a[:, :, np.newaxis, :] * a[:, np.newaxis, :, :].conj(), axis=-1)\n    else:\n        a = np.random.randn(*shape)\n        A = np.mean(a[:, :, np.newaxis, :] * a[:, np.newaxis, :, :], axis=-1)\n\n    lamb, z = eigh2(A)\n\n    assert lamb.shape == (n_sources, 2)\n    assert z.shape == (n_sources, 2, 2)\n\n    assert np.allclose(A @ z, lamb[:, np.newaxis, :] * z)\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_sources)\n@pytest.mark.parametrize(\"n_frames\", parameters_frames)\n@pytest.mark.parametrize(\"is_complex\", parameters_is_complex)\n@pytest.mark.parametrize(\"type\", parameters_type)\ndef test_generalized_eigh2(n_sources: int, n_frames: int, is_complex: bool, type: int):\n    np.random.seed(111)\n\n    shape = (n_sources, 2, n_frames)\n\n    if is_complex:\n        a = np.random.randn(*shape) + 1j * 
np.random.randn(*shape)\n        b = np.random.randn(*shape) + 1j * np.random.randn(*shape)\n        A = np.mean(a[:, :, np.newaxis, :] * a[:, np.newaxis, :, :].conj(), axis=-1)\n        B = np.mean(b[:, :, np.newaxis, :] * b[:, np.newaxis, :, :].conj(), axis=-1)\n    else:\n        a = np.random.randn(*shape)\n        b = np.random.randn(*shape)\n        A = np.mean(a[:, :, np.newaxis, :] * a[:, np.newaxis, :, :], axis=-1)\n        B = np.mean(b[:, :, np.newaxis, :] * b[:, np.newaxis, :, :], axis=-1)\n\n    lamb, z = eigh2(A, B, type=type)\n\n    assert lamb.shape == (n_sources, 2)\n    assert z.shape == (n_sources, 2, 2)\n\n    if type == 1:\n        assert np.allclose(A @ z, lamb[:, np.newaxis, :] * (B @ z))\n    elif type == 2:\n        assert np.allclose(A @ B @ z, lamb[:, np.newaxis, :] * z)\n    elif type == 3:\n        assert np.allclose(B @ A @ z, lamb[:, np.newaxis, :] * z)\n    else:\n        raise ValueError(\"Invalid type={} is given.\".format(type))\n"
  },
  {
    "path": "tests/package/linalg/test_gmean.py",
    "content": "import numpy as np\nimport pytest\nfrom scipy.linalg import sqrtm\n\nfrom ssspy.linalg import gmeanmh\n\nparameters_type = [1, 2, 3]\n\n\ndef gmeanmh_scipy(A: np.ndarray, B: np.ndarray, inverse=\"left\") -> np.ndarray:\n    def _sqrtm(X) -> np.ndarray:\n        return np.stack([sqrtm(x) for x in X], axis=0)\n\n    if inverse == \"left\":\n        AB = np.linalg.solve(A, B)\n        G = A @ _sqrtm(AB)\n    elif inverse == \"right\":\n        AB = np.linalg.solve(B, A)\n        AB = AB.swapaxes(-2, -1).conj()\n        G = _sqrtm(AB) @ B\n    else:\n        raise ValueError(f\"Invalid inverse={inverse} is given.\")\n\n    return G\n\n\n@pytest.mark.parametrize(\"type\", parameters_type)\ndef test_gmean(type: int):\n    rng = np.random.default_rng(0)\n    size = (16, 32, 4, 1)\n\n    def create_psd():\n        x = rng.random(size) + 1j * rng.random(size)\n        XX = x * x.transpose(0, 1, 3, 2).conj()\n\n        return np.mean(XX, axis=0)\n\n    A = create_psd()\n    B = create_psd()\n\n    G1 = gmeanmh(A, B, type=type)\n\n    if type == 1:\n        assert np.allclose(G1 @ np.linalg.inv(A) @ G1, B)\n    elif type == 2:\n        assert np.allclose(G1 @ A @ G1, B)\n    elif type == 3:\n        assert np.allclose(G1 @ np.linalg.inv(A) @ G1, np.linalg.inv(B))\n    else:\n        raise ValueError(\"Invalid type={} is given.\".format(type))\n\n    if type == 2:\n        A = np.linalg.inv(A)\n    elif type == 3:\n        B = np.linalg.inv(B)\n\n    G2 = gmeanmh_scipy(A, B, inverse=\"left\")\n    G3 = gmeanmh_scipy(A, B, inverse=\"right\")\n\n    assert np.allclose(G1, G2)\n    assert np.allclose(G1, G3)\n"
  },
  {
    "path": "tests/package/linalg/test_inv.py",
    "content": "import numpy as np\nimport pytest\n\nfrom ssspy.linalg import inv2\n\nparameters_sources = [2, 5]\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_sources)\ndef test_inv2(n_sources: int):\n    np.random.seed(111)\n\n    shape = (n_sources, 2, 2)\n\n    A = np.random.randn(*shape) + 1j * np.random.randn(*shape)\n    B = inv2(A)\n\n    assert np.allclose(A @ B, np.eye(2))\n"
  },
  {
    "path": "tests/package/linalg/test_lqpqm.py",
    "content": "import numpy as np\n\nfrom ssspy.linalg.lqpqm import _find_largest_root\n\n\ndef test_find_largest_root():\n    alpha = np.array([-1, 1, 1, -1 + 1j])\n    beta = np.array([0, 1, 1, -1 - 1j])\n    gamma = np.array([1, 1, 2, 1])\n\n    A = -np.real(alpha + beta + gamma)\n    B = np.real(alpha * beta + beta * gamma + gamma * alpha)\n    C = -np.real(alpha * beta * gamma)\n\n    X = _find_largest_root(A, B, C)\n\n    assert np.allclose(X, gamma)\n"
  },
  {
    "path": "tests/package/linalg/test_polynomial.py",
    "content": "import numpy as np\n\nfrom ssspy.linalg.polynomial import _find_cubic_roots, solve_cubic\n\n\ndef test_find_cubic_roots():\n    rng = np.random.default_rng(0)\n\n    n_bins, n_channels = 3, 2\n\n    P = rng.standard_normal((n_bins, n_channels))\n    Q = rng.standard_normal((n_bins, n_channels))\n    X = _find_cubic_roots(P, Q)\n    Y = X**3 + P * X + Q\n\n    assert np.allclose(Y, 0)\n\n\ndef test_solve_cubic():\n    rng = np.random.default_rng(0)\n\n    n_bins, n_channels = 3, 2\n\n    # real coefficients\n    A = rng.standard_normal((n_bins, n_channels))\n    B = rng.standard_normal((n_bins, n_channels))\n    C = rng.standard_normal((n_bins, n_channels))\n    D = rng.standard_normal((n_bins, n_channels))\n\n    X = solve_cubic(A, B, C)\n    Y = X**3 + A * X**2 + B * X + C\n\n    assert np.allclose(Y, 0)\n\n    X = solve_cubic(A, B, C, D)\n    Y = A * X**3 + B * X**2 + C * X + D\n\n    assert np.allclose(Y, 0)\n\n    # corner case\n    A = np.zeros_like(C)\n    B = np.zeros_like(C)\n\n    X = solve_cubic(A, B, C)\n    Y = X**3 + A * X**2 + B * X + C\n\n    assert np.allclose(Y, 0)\n\n    # complex coefficients\n    A = rng.standard_normal((n_bins, n_channels)) + 1j * rng.standard_normal((n_bins, n_channels))\n    B = rng.standard_normal((n_bins, n_channels)) + 1j * rng.standard_normal((n_bins, n_channels))\n    C = rng.standard_normal((n_bins, n_channels)) + 1j * rng.standard_normal((n_bins, n_channels))\n    D = rng.standard_normal((n_bins, n_channels)) + 1j * rng.standard_normal((n_bins, n_channels))\n\n    X = solve_cubic(A, B, C)\n    Y = X**3 + A * X**2 + B * X + C\n\n    assert np.allclose(Y, 0)\n\n    X = solve_cubic(A, B, C, D)\n    Y = A * X**3 + B * X**2 + C * X + D\n\n    assert np.allclose(Y, 0)\n\n    # corner case\n    A = np.zeros_like(C)\n    B = np.zeros_like(C)\n\n    X = solve_cubic(A, B, C)\n    Y = X**3 + A * X**2 + B * X + C\n\n    assert np.allclose(Y, 0)\n"
  },
  {
    "path": "tests/package/linalg/test_sqrtm.py",
    "content": "import numpy as np\nimport pytest\n\nfrom ssspy.linalg import invsqrtmh, sqrtmh\n\nparameters_sources = [2]\nparameters_channels = [3, 4]\nparameters_frames = [32]\nparameters_is_complex = [True, False]\nparameters_is_flooring = [True, False]\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_sources)\n@pytest.mark.parametrize(\"n_channels\", parameters_channels)\n@pytest.mark.parametrize(\"n_frames\", parameters_frames)\n@pytest.mark.parametrize(\"is_complex\", parameters_is_complex)\ndef test_sqrtmh(n_sources: int, n_channels: int, n_frames: int, is_complex: bool):\n    rng = np.random.default_rng(0)\n\n    shape = (n_sources, n_channels, n_frames)\n\n    if is_complex:\n        x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n        X = np.mean(x[:, :, np.newaxis, :] * x[:, np.newaxis, :, :].conj(), axis=-1)\n    else:\n        x = rng.standard_normal(shape)\n        X = np.mean(x[:, :, np.newaxis, :] * x[:, np.newaxis, :, :], axis=-1)\n\n    X_sqrt = sqrtmh(X)\n\n    assert np.allclose(X, X_sqrt @ X_sqrt)\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_sources)\n@pytest.mark.parametrize(\"n_channels\", parameters_channels)\n@pytest.mark.parametrize(\"n_frames\", parameters_frames)\n@pytest.mark.parametrize(\"is_complex\", parameters_is_complex)\n@pytest.mark.parametrize(\"is_flooring\", parameters_is_flooring)\ndef test_invsqrtmh(\n    n_sources: int, n_channels: int, n_frames: int, is_complex: bool, is_flooring: bool\n):\n    rng = np.random.default_rng(0)\n\n    shape = (n_sources, n_channels, n_frames)\n\n    if is_complex:\n        x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n        X = np.mean(x[:, :, np.newaxis, :] * x[:, np.newaxis, :, :].conj(), axis=-1)\n    else:\n        x = rng.standard_normal(shape)\n        X = np.mean(x[:, :, np.newaxis, :] * x[:, np.newaxis, :, :], axis=-1)\n\n    if is_flooring:\n        X_invsqrt = invsqrtmh(X, flooring_fn=lambda x: np.maximum(x, 1e-10))\n    
else:\n        X_invsqrt = invsqrtmh(X)\n\n    X_sqrt = np.linalg.inv(X_invsqrt)\n\n    assert np.allclose(X, X_sqrt @ X_sqrt)\n"
  },
  {
    "path": "tests/package/special/test_logsumexp.py",
    "content": "from typing import Optional\n\nimport numpy as np\nimport pytest\nimport scipy.special\n\nfrom ssspy.special import logsumexp\n\nparameters_axis = [0, 1, (0, 2), None]\nparameters_keepdims = [True, False]\n\n\n@pytest.mark.parametrize(\"axis\", parameters_axis)\n@pytest.mark.parametrize(\"keepdims\", parameters_keepdims)\ndef test_logsumexp(axis: Optional[int], keepdims: bool):\n    rng = np.random.default_rng(0)\n\n    n_sources, n_channels = 4, 3\n    n_frames = 8\n    shape = (n_sources, n_frames, n_channels, n_channels)\n\n    X = rng.random(shape)\n\n    Y = logsumexp(X, axis=axis, keepdims=keepdims)\n    Y_scipy = scipy.special.logsumexp(X, axis=axis, keepdims=keepdims)\n\n    assert np.allclose(Y, Y_scipy)\n"
  },
  {
    "path": "tests/package/special/test_psd.py",
    "content": "from typing import Tuple\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.special import add_flooring, to_psd\n\nrng = np.random.default_rng(42)\n\nparameters_shape = [(5, 2, 2), (3, 3)]\nparameters_kwargs = [{}, {\"flooring_fn\": None}, {\"flooring_fn\": add_flooring}]\n\n\n@pytest.mark.parametrize(\"shape\", parameters_shape)\n@pytest.mark.parametrize(\"kwargs\", parameters_kwargs)\ndef test_to_psd_real(shape: Tuple[int], kwargs):\n    X = rng.standard_normal(shape)\n    X = X @ X.swapaxes(-1, -2)\n    X = to_psd(X, **kwargs)\n    eigvals = np.linalg.eigvalsh(X)\n\n    assert np.all(X == X.swapaxes(-1, -2))\n    assert np.min(eigvals) > 0\n\n\n@pytest.mark.parametrize(\"shape\", parameters_shape)\n@pytest.mark.parametrize(\"kwargs\", parameters_kwargs)\ndef test_to_psd_complex(shape: Tuple[int], kwargs):\n    X = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)\n    X = X @ X.swapaxes(-1, -2).conj()\n    X = to_psd(X, **kwargs)\n    eigvals = np.linalg.eigvalsh(X)\n\n    assert np.all(X == X.swapaxes(-1, -2).conj())\n    assert np.min(eigvals) > 0\n"
  },
  {
    "path": "tests/package/special/test_softmax.py",
    "content": "from typing import Optional\n\nimport numpy as np\nimport pytest\nimport scipy.special\n\nfrom ssspy.special import softmax\n\nparameters_axis = [0, 1, (0, 2), None]\n\n\n@pytest.mark.parametrize(\"axis\", parameters_axis)\ndef test_logsumexp(axis: Optional[int]):\n    rng = np.random.default_rng(0)\n\n    n_sources, n_channels = 4, 3\n    n_frames = 8\n    shape = (n_sources, n_frames, n_channels, n_channels)\n\n    X = rng.random(shape)\n\n    Y = softmax(X, axis=axis)\n    Y_scipy = scipy.special.softmax(X, axis=axis)\n\n    assert np.allclose(Y, Y_scipy)\n"
  },
  {
    "path": "tests/package/transform/test_pca.py",
    "content": "import numpy as np\nimport pytest\n\nfrom ssspy.transform import pca\n\nparameters_ascend = [True, False]\nparameters_batch_size = [1, 4]\nparameters_n_channels = [2, 3]\nparameters_pca_real = [10, 20]\nparameters_pca_complex = [(257, 8), (65, 12)]\n\n\n@pytest.mark.parametrize(\"ascend\", parameters_ascend)\n@pytest.mark.parametrize(\"n_channels\", parameters_n_channels)\n@pytest.mark.parametrize(\"n_samples\", parameters_pca_real)\ndef test_pca_real_2d(ascend: bool, n_channels: int, n_samples: int):\n    np.random.seed(111)\n\n    input = np.random.randn(n_channels, n_samples)\n    output = pca(input, ascend=ascend)\n\n    assert input.shape == output.shape\n\n    covariance = output[:, np.newaxis, :] * output[np.newaxis, :, :]\n    covariance = np.mean(covariance, axis=-1)\n    mask = 1 - np.eye(n_channels)\n    zero = np.zeros((n_channels, n_channels))\n\n    assert np.allclose(mask * covariance, zero)\n\n\n@pytest.mark.parametrize(\"ascend\", parameters_ascend)\n@pytest.mark.parametrize(\"batch_size\", parameters_batch_size)\n@pytest.mark.parametrize(\"n_channels\", parameters_n_channels)\n@pytest.mark.parametrize(\"n_samples\", parameters_pca_real)\ndef test_pca_real_3d(ascend: bool, batch_size: int, n_channels: int, n_samples: int):\n    np.random.seed(111)\n\n    input = np.random.randn(batch_size, n_channels, n_samples)\n    output = pca(input, ascend=ascend)\n\n    assert input.shape == output.shape\n\n    covariance = output[:, :, np.newaxis, :] * output[:, np.newaxis, :, :]\n    covariance = np.mean(covariance, axis=-1)\n    mask = 1 - np.eye(n_channels)\n    zero = np.zeros((batch_size, n_channels, n_channels))\n\n    assert np.allclose(mask * covariance, zero)\n\n\n@pytest.mark.parametrize(\"ascend\", parameters_ascend)\n@pytest.mark.parametrize(\"n_channels\", parameters_n_channels)\n@pytest.mark.parametrize(\"n_bins, n_frames\", parameters_pca_complex)\ndef test_pca_complex_3d(ascend: bool, n_channels: int, n_bins: int, n_frames: 
int):\n    np.random.seed(111)\n\n    real = np.random.randn(n_channels, n_bins, n_frames)\n    imag = np.random.randn(n_channels, n_bins, n_frames)\n    input = real + 1j * imag\n    output = pca(input, ascend=ascend)\n\n    assert input.shape == output.shape\n\n    covariance = output[:, np.newaxis, :, :] * output[np.newaxis, :, :, :].conj()\n    covariance = np.mean(covariance, axis=-1)\n    covariance = covariance.transpose(2, 0, 1)\n    mask = 1 - np.eye(n_channels)\n    zero = np.zeros((n_bins, n_channels, n_channels))\n\n    assert np.allclose(mask * covariance, zero)\n\n\n@pytest.mark.parametrize(\"ascend\", parameters_ascend)\n@pytest.mark.parametrize(\"batch_size\", parameters_batch_size)\n@pytest.mark.parametrize(\"n_channels\", parameters_n_channels)\n@pytest.mark.parametrize(\"n_bins, n_frames\", parameters_pca_complex)\ndef test_pca_complex_4d(ascend: bool, batch_size: int, n_channels: int, n_bins: int, n_frames: int):\n    np.random.seed(111)\n\n    real = np.random.randn(batch_size, n_channels, n_bins, n_frames)\n    imag = np.random.randn(batch_size, n_channels, n_bins, n_frames)\n    input = real + 1j * imag\n    output = pca(input, ascend=ascend)\n\n    assert input.shape == output.shape\n\n    covariance = output[:, :, np.newaxis, :, :] * output[:, np.newaxis, :, :, :].conj()\n    covariance = np.mean(covariance, axis=-1)\n    covariance = covariance.transpose(0, 3, 1, 2)\n    mask = 1 - np.eye(n_channels)\n    zero = np.zeros((batch_size, n_bins, n_channels, n_channels))\n\n    assert np.allclose(mask * covariance, zero)\n"
  },
  {
    "path": "tests/package/transform/test_whiten.py",
    "content": "import numpy as np\nimport pytest\n\nfrom ssspy.transform import whiten\n\nparameters_batch_size = [1, 4]\nparameters_n_channels = [2, 3]\nparameters_whiten_real = [10, 20]\nparameters_whiten_complex = [(2049, 8), (513, 12)]\n\n\n@pytest.mark.parametrize(\"n_channels\", parameters_n_channels)\n@pytest.mark.parametrize(\"n_samples\", parameters_whiten_real)\ndef test_whiten_real_2d(n_channels: int, n_samples: int):\n    np.random.seed(111)\n\n    input = np.random.randn(n_channels, n_samples)\n    output = whiten(input)\n\n    assert input.shape == output.shape\n\n    covariance = output[:, np.newaxis, :] * output[np.newaxis, :, :]\n    covariance = np.mean(covariance, axis=-1)\n    eye = np.eye(n_channels)\n\n    assert np.allclose(covariance, eye)\n\n\n@pytest.mark.parametrize(\"batch_size\", parameters_batch_size)\n@pytest.mark.parametrize(\"n_channels\", parameters_n_channels)\n@pytest.mark.parametrize(\"n_samples\", parameters_whiten_real)\ndef test_whiten_real_3d(batch_size: int, n_channels: int, n_samples: int):\n    np.random.seed(111)\n\n    input = np.random.randn(batch_size, n_channels, n_samples)\n    output = whiten(input)\n\n    assert input.shape == output.shape\n\n    covariance = output[:, :, np.newaxis, :] * output[:, np.newaxis, :, :]\n    covariance = np.mean(covariance, axis=-1)\n    eye = np.eye(n_channels)\n\n    assert np.allclose(covariance, eye)\n\n\n@pytest.mark.parametrize(\"n_channels\", parameters_n_channels)\n@pytest.mark.parametrize(\"n_bins, n_frames\", parameters_whiten_complex)\ndef test_whiten_complex_3d(n_channels: int, n_bins: int, n_frames: int):\n    np.random.seed(111)\n\n    real = np.random.randn(n_channels, n_bins, n_frames)\n    imag = np.random.randn(n_channels, n_bins, n_frames)\n    input = real + 1j * imag\n    output = whiten(input)\n\n    assert input.shape == output.shape\n\n    covariance = output[:, np.newaxis, :, :] * output[np.newaxis, :, :, :].conj()\n    covariance = np.mean(covariance, 
axis=-1)\n    covariance = covariance.transpose(2, 0, 1)\n    eye = np.eye(n_channels)\n    eye = np.tile(eye, reps=(n_bins, 1, 1))\n\n    assert np.allclose(covariance, eye)\n\n\n@pytest.mark.parametrize(\"batch_size\", parameters_batch_size)\n@pytest.mark.parametrize(\"n_channels\", parameters_n_channels)\n@pytest.mark.parametrize(\"n_bins, n_frames\", parameters_whiten_complex)\ndef test_whiten_complex_4d(batch_size: int, n_channels: int, n_bins: int, n_frames: int):\n    np.random.seed(111)\n\n    real = np.random.randn(batch_size, n_channels, n_bins, n_frames)\n    imag = np.random.randn(batch_size, n_channels, n_bins, n_frames)\n    input = real + 1j * imag\n    output = whiten(input)\n\n    assert input.shape == output.shape\n\n    covariance = output[:, :, np.newaxis, :, :] * output[:, np.newaxis, :, :, :].conj()\n    covariance = np.mean(covariance, axis=-1)\n    covariance = covariance.transpose(0, 3, 1, 2)\n    eye = np.eye(n_channels)\n    eye = np.tile(eye, reps=(batch_size, n_bins, 1, 1))\n\n    assert np.allclose(covariance, eye)\n"
  },
  {
    "path": "tests/package/utils/test_dataset.py",
    "content": "import pytest\n\nfrom ssspy.utils.dataset import download_sample_speech_data\n\nparameters_dataset = [\n    (2, \"dev1_female3\"),\n    (3, \"dev1_female3\"),\n    (4, \"dev1_female4\"),\n]\nparameters_max_duration = [1.2]\nparameters_conv = [True, False]\n\n\n@pytest.mark.parametrize(\"n_sources, sisec2010_tag\", parameters_dataset)\n@pytest.mark.parametrize(\"max_duration\", parameters_max_duration)\n@pytest.mark.parametrize(\"conv\", parameters_conv)\ndef test_conv_dataset(n_sources: int, sisec2010_tag: str, max_duration: int, conv: bool):\n    waveform_src_img, sample_rate = download_sample_speech_data(\n        sisec2010_root=\"./tests/.data/SiSEC2010\",\n        mird_root=\"./tests/.data/MIRD\",\n        n_sources=n_sources,\n        sisec2010_tag=sisec2010_tag,\n        max_duration=max_duration,\n        conv=conv,\n    )\n\n    n_channels = n_sources\n\n    assert waveform_src_img.shape == (n_channels, n_sources, int(sample_rate * max_duration))\n"
  },
  {
    "path": "tests/package/utils/test_select_pair.py",
    "content": "import pytest\n\nfrom ssspy.utils.select_pair import combination_pair_selector, sequential_pair_selector\n\nparameters_n_sources = [2, 3, 4]\nparameters_step = [1, 2]\nparameters_ascend = [True, False]\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"step\", parameters_step)\n@pytest.mark.parametrize(\"ascend\", parameters_ascend)\ndef test_sequential_pair_selector(n_sources: int, step: int, ascend: bool):\n    for m, n in sequential_pair_selector(n_sources, step=step, sort=ascend):\n        if ascend:\n            assert m < n\n\n\n@pytest.mark.parametrize(\"n_sources\", parameters_n_sources)\n@pytest.mark.parametrize(\"ascend\", parameters_ascend)\ndef test_combination_pair_selector(n_sources: int, ascend: bool):\n    for m, n in combination_pair_selector(n_sources, sort=ascend):\n        if ascend:\n            assert m < n\n"
  },
  {
    "path": "tests/regression/bss/test_cacgmm.py",
    "content": "import sys\nfrom os import makedirs\nfrom os.path import dirname, join, realpath\n\nimport numpy as np\n\nfrom ssspy.bss.cacgmm import CACGMM\n\nssspy_tests_dir = dirname(dirname(dirname(realpath(__file__))))\nsys.path.append(ssspy_tests_dir)\n\nfrom dummy.utils.dataset import load_regression_data  # noqa: E402\n\ncacgmm_root = join(ssspy_tests_dir, \"mock\", \"regression\", \"bss\", \"cacgmm\")\n\n\ndef test_cacgmm(save_feature: bool = False):\n    rng = np.random.default_rng(0)\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=cacgmm_root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=cacgmm_root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    cacgmm = CACGMM(rng=rng)\n    spectrogram_est = cacgmm(spectrogram_mix, n_iter=n_iter)\n\n    if save_feature:\n        makedirs(cacgmm_root, exist_ok=True)\n        np.savez(\n            join(cacgmm_root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\ndef save_all_features() -> None:\n    test_cacgmm(save_feature=True)\n\n\nif __name__ == \"__main__\":\n    save_all_features()\n"
  },
  {
    "path": "tests/regression/bss/test_fdica.py",
    "content": "import sys\nfrom os import makedirs\nfrom os.path import dirname, join, realpath\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.bss.fdica import AuxLaplaceFDICA, GradLaplaceFDICA, NaturalGradLaplaceFDICA\n\nssspy_tests_dir = dirname(dirname(dirname(realpath(__file__))))\nsys.path.append(ssspy_tests_dir)\n\n\nfrom dummy.utils.dataset import load_regression_data  # noqa: E402\n\nfdica_root = join(ssspy_tests_dir, \"mock\", \"regression\", \"bss\", \"fdica\")\nn_sources = 2\n\nparameters_is_holonomic = [True, False]\nparameters_spatial_algorithm = [\"IP1\", \"IP2\"]\n\n\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_laplace_fdica(is_holonomic: bool, save_feature: bool = False):\n    if is_holonomic:\n        root = join(fdica_root, \"grad_laplace_fdica\", \"holonomic\")\n    else:\n        root = join(fdica_root, \"grad_laplace_fdica\", \"nonholonomic\")\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    fdica = GradLaplaceFDICA(is_holonomic=is_holonomic)\n    spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter)\n\n    if save_feature:\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_natural_grad_laplace_fdica(is_holonomic: bool, save_feature: bool = False):\n    if 
is_holonomic:\n        root = join(fdica_root, \"natural_grad_laplace_fdica\", \"holonomic\")\n    else:\n        root = join(fdica_root, \"natural_grad_laplace_fdica\", \"nonholonomic\")\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    fdica = NaturalGradLaplaceFDICA(is_holonomic=is_holonomic)\n    spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter)\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\ndef test_aux_laplace_fdica(spatial_algorithm: str, save_feature: bool = False):\n    root = join(fdica_root, \"aux_laplace_fdica\", spatial_algorithm)\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    fdica = AuxLaplaceFDICA(spatial_algorithm=spatial_algorithm)\n    spectrogram_est = fdica(spectrogram_mix, n_iter=n_iter)\n\n    if save_feature:\n        
makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\ndef save_all_features() -> None:\n    for is_holonomic in parameters_is_holonomic:\n        test_grad_laplace_fdica(is_holonomic=is_holonomic, save_feature=True)\n\n    for is_holonomic in parameters_is_holonomic:\n        test_natural_grad_laplace_fdica(is_holonomic=is_holonomic, save_feature=True)\n\n    for spatial_algorithm in parameters_spatial_algorithm:\n        test_aux_laplace_fdica(spatial_algorithm=spatial_algorithm, save_feature=True)\n\n\nif __name__ == \"__main__\":\n    save_all_features()\n"
  },
  {
    "path": "tests/regression/bss/test_ilrma.py",
    "content": "import sys\nfrom os import makedirs\nfrom os.path import dirname, join, realpath\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.bss.ilrma import GGDILRMA, TILRMA, GaussILRMA\n\nssspy_tests_dir = dirname(dirname(dirname(realpath(__file__))))\nsys.path.append(ssspy_tests_dir)\n\nfrom dummy.utils.dataset import load_regression_data  # noqa: E402\n\nilrma_root = join(ssspy_tests_dir, \"mock\", \"regression\", \"bss\", \"ilrma\")\n\nparameters_spatial_algorithm = [\"IP1\", \"IP2\", \"ISS1\", \"ISS2\", \"IPA\"]\nparameters_source_algorithm = [\"MM\", \"ME\"]\n\n\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\ndef test_gauss_ilrma(spatial_algorithm: str, source_algorithm: str, save_feature: bool = False):\n    rng = np.random.default_rng(0)\n    root = join(ilrma_root, \"gauss_ilrma\", spatial_algorithm, source_algorithm)\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_basis = 2\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_basis = npz_target[\"n_basis\"].item()\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    if save_feature:\n        n_sources, n_bins, n_frames = spectrogram_mix.shape\n\n        basis = rng.random((n_sources, n_bins, n_basis))\n        activation = rng.random((n_sources, n_basis, n_frames))\n    else:\n        basis = npz_target[\"basis\"]\n        activation = npz_target[\"activation\"]\n\n    ilrma = GaussILRMA(\n        n_basis=n_basis,\n        spatial_algorithm=spatial_algorithm,\n        source_algorithm=source_algorithm,\n        rng=rng,\n    )\n    spectrogram_est = ilrma(\n       
 spectrogram_mix,\n        n_iter=n_iter,\n        basis=basis,\n        activation=activation,\n    )\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            basis=basis,\n            activation=activation,\n            n_basis=n_basis,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\ndef test_t_ilrma(spatial_algorithm: str, source_algorithm: str, save_feature: bool = False):\n    if spatial_algorithm == \"IPA\":\n        pytest.skip(reason=\"IPA is not supported for TILRMA.\")\n\n    rng = np.random.default_rng(0)\n    root = join(ilrma_root, \"t_ilrma\", spatial_algorithm, source_algorithm)\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_basis = 2\n        dof = 1000\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_basis = npz_target[\"n_basis\"].item()\n        dof = npz_target[\"dof\"].item()\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    if save_feature:\n        n_sources, n_bins, n_frames = spectrogram_mix.shape\n\n        basis = rng.random((n_sources, n_bins, n_basis))\n        activation = rng.random((n_sources, n_basis, n_frames))\n    else:\n        basis = npz_target[\"basis\"]\n        activation = npz_target[\"activation\"]\n\n    ilrma = TILRMA(\n        n_basis=n_basis,\n        
dof=dof,\n        spatial_algorithm=spatial_algorithm,\n        source_algorithm=source_algorithm,\n        rng=rng,\n    )\n    spectrogram_est = ilrma(\n        spectrogram_mix,\n        n_iter=n_iter,\n        basis=basis,\n        activation=activation,\n    )\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            basis=basis,\n            activation=activation,\n            n_basis=n_basis,\n            dof=dof,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\ndef test_ggd_ilrma(spatial_algorithm: str, source_algorithm: str, save_feature: bool = False):\n    if spatial_algorithm == \"IPA\":\n        pytest.skip(reason=\"IPA is not supported for GGDILRMA.\")\n\n    if source_algorithm == \"ME\":\n        pytest.skip(reason=\"ME is not supported for GGDILRMA.\")\n\n    rng = np.random.default_rng(0)\n    root = join(ilrma_root, \"ggd_ilrma\", spatial_algorithm, source_algorithm)\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_basis = 2\n        beta = 1.5\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_basis = npz_target[\"n_basis\"].item()\n        beta = npz_target[\"beta\"].item()\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    if save_feature:\n        n_sources, n_bins, n_frames = 
spectrogram_mix.shape\n\n        basis = rng.random((n_sources, n_bins, n_basis))\n        activation = rng.random((n_sources, n_basis, n_frames))\n    else:\n        basis = npz_target[\"basis\"]\n        activation = npz_target[\"activation\"]\n\n    ilrma = GGDILRMA(\n        n_basis=n_basis,\n        beta=beta,\n        spatial_algorithm=spatial_algorithm,\n        source_algorithm=source_algorithm,\n        rng=rng,\n    )\n    spectrogram_est = ilrma(\n        spectrogram_mix,\n        n_iter=n_iter,\n        basis=basis,\n        activation=activation,\n    )\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            basis=basis,\n            activation=activation,\n            n_basis=n_basis,\n            beta=beta,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\ndef save_all_features() -> None:\n    for spatial_algorithm in parameters_spatial_algorithm:\n        for source_algorithm in parameters_source_algorithm:\n            test_gauss_ilrma(\n                spatial_algorithm=spatial_algorithm,\n                source_algorithm=source_algorithm,\n                save_feature=True,\n            )\n\n    for spatial_algorithm in parameters_spatial_algorithm:\n        if spatial_algorithm == \"IPA\":\n            continue\n\n        for source_algorithm in parameters_source_algorithm:\n            test_t_ilrma(\n                spatial_algorithm=spatial_algorithm,\n                source_algorithm=source_algorithm,\n                save_feature=True,\n            )\n\n    for spatial_algorithm in parameters_spatial_algorithm:\n        if spatial_algorithm == \"IPA\":\n            continue\n\n        for source_algorithm in parameters_source_algorithm:\n            if source_algorithm == 
\"ME\":\n                continue\n\n            test_ggd_ilrma(\n                spatial_algorithm=spatial_algorithm,\n                source_algorithm=source_algorithm,\n                save_feature=True,\n            )\n\n\nif __name__ == \"__main__\":\n    save_all_features()\n"
  },
  {
    "path": "tests/regression/bss/test_ipsdta.py",
    "content": "import sys\nfrom os import makedirs\nfrom os.path import dirname, join, realpath\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.bss.ipsdta import TIPSDTA, GaussIPSDTA\n\nssspy_tests_dir = dirname(dirname(dirname(realpath(__file__))))\nsys.path.append(ssspy_tests_dir)\n\nfrom dummy.utils.dataset import load_regression_data  # noqa: E402\n\nipsdta_root = join(ssspy_tests_dir, \"mock\", \"regression\", \"bss\", \"ipsdta\")\n\nparameters_spatial_algorithm = [\"VCD\"]\nparameters_source_algorithm = [\"EM\", \"MM\"]\n\n\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\ndef test_gauss_ipsdta(spatial_algorithm: str, source_algorithm: str, save_feature: bool = False):\n    if source_algorithm == \"EM\":\n        pytest.skip(reason=\"EM is not supported for GaussIPSDTA.\")\n\n    rng = np.random.default_rng(0)\n    root = join(ipsdta_root, \"gauss_ipsdta\", spatial_algorithm, source_algorithm)\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_basis = 2\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_basis = npz_target[\"n_basis\"].item()\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    if save_feature:\n        n_blocks = spectrogram_mix.shape[1] // 2\n        n_sources, n_bins, n_frames = spectrogram_mix.shape\n\n        n_neighbors = n_bins // n_blocks\n        n_remains = n_bins % n_blocks\n\n        eye = np.eye(n_neighbors, dtype=np.complex128)\n        rand = rng.random((n_sources, n_basis, n_blocks - n_remains, n_neighbors))\n        T = rand[..., np.newaxis] * eye\n\n        if n_remains > 0:\n            eye 
= np.eye(n_neighbors + 1, dtype=np.complex128)\n            rand = rng.random((n_sources, n_basis, n_remains, n_neighbors + 1))\n            T_high = rand[..., np.newaxis] * eye\n\n            T = T, T_high\n\n        V = rng.random((n_sources, n_basis, n_frames))\n\n        basis = T\n        activation = V\n    else:\n        n_blocks = npz_target[\"n_blocks\"].item()\n\n        if \"basis\" in npz_target.keys():\n            basis = npz_target[\"basis\"]\n        else:\n            basis_low = npz_target[\"basis_low\"]\n            basis_high = npz_target[\"basis_high\"]\n            basis = basis_low, basis_high\n\n        activation = npz_target[\"activation\"]\n\n    ipsdta = GaussIPSDTA(\n        n_basis=n_basis,\n        n_blocks=n_blocks,\n        spatial_algorithm=spatial_algorithm,\n        source_algorithm=source_algorithm,\n        rng=rng,\n    )\n    spectrogram_est = ipsdta(\n        spectrogram_mix,\n        n_iter=n_iter,\n        basis=basis,\n        activation=activation,\n    )\n\n    if isinstance(basis, tuple):\n        basis_low, basis_high = basis\n        basis = {\n            \"basis_low\": basis_low,\n            \"basis_high\": basis_high,\n        }\n    else:\n        basis = {\n            \"basis\": basis,\n        }\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            **basis,\n            activation=activation,\n            n_basis=n_basis,\n            n_blocks=n_blocks,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\n@pytest.mark.parametrize(\"source_algorithm\", parameters_source_algorithm)\ndef test_t_ipsdta(spatial_algorithm: str, source_algorithm: str, save_feature: bool 
= False):\n    if source_algorithm == \"EM\":\n        pytest.skip(reason=\"EM is not supported for TIPSDTA.\")\n\n    rng = np.random.default_rng(0)\n    root = join(ipsdta_root, \"t_ipsdta\", spatial_algorithm, source_algorithm)\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_basis = 2\n        dof = 1000\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_basis = npz_target[\"n_basis\"].item()\n        dof = npz_target[\"dof\"].item()\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    if save_feature:\n        n_blocks = spectrogram_mix.shape[1] // 2\n        n_sources, n_bins, n_frames = spectrogram_mix.shape\n\n        n_neighbors = n_bins // n_blocks\n        n_remains = n_bins % n_blocks\n\n        eye = np.eye(n_neighbors, dtype=np.complex128)\n        rand = rng.random((n_sources, n_basis, n_blocks - n_remains, n_neighbors))\n        T = rand[..., np.newaxis] * eye\n\n        if n_remains > 0:\n            eye = np.eye(n_neighbors + 1, dtype=np.complex128)\n            rand = rng.random((n_sources, n_basis, n_remains, n_neighbors + 1))\n            T_high = rand[..., np.newaxis] * eye\n\n            T = T, T_high\n\n        V = rng.random((n_sources, n_basis, n_frames))\n\n        basis = T\n        activation = V\n    else:\n        n_blocks = npz_target[\"n_blocks\"].item()\n\n        if \"basis\" in npz_target.keys():\n            basis = npz_target[\"basis\"]\n        else:\n            basis_low = npz_target[\"basis_low\"]\n            basis_high = npz_target[\"basis_high\"]\n            basis = basis_low, basis_high\n\n        activation = npz_target[\"activation\"]\n\n    ipsdta = TIPSDTA(\n        n_basis=n_basis,\n    
    n_blocks=n_blocks,\n        dof=dof,\n        spatial_algorithm=spatial_algorithm,\n        source_algorithm=source_algorithm,\n        rng=rng,\n    )\n    spectrogram_est = ipsdta(\n        spectrogram_mix,\n        n_iter=n_iter,\n        basis=basis,\n        activation=activation,\n    )\n\n    if isinstance(basis, tuple):\n        basis_low, basis_high = basis\n        basis = {\n            \"basis_low\": basis_low,\n            \"basis_high\": basis_high,\n        }\n    else:\n        basis = {\n            \"basis\": basis,\n        }\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            **basis,\n            activation=activation,\n            n_basis=n_basis,\n            n_blocks=n_blocks,\n            dof=dof,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\ndef save_all_features() -> None:\n    for spatial_algorithm in parameters_spatial_algorithm:\n        for source_algorithm in parameters_source_algorithm:\n            if source_algorithm == \"EM\":\n                continue\n\n            test_gauss_ipsdta(\n                spatial_algorithm=spatial_algorithm,\n                source_algorithm=source_algorithm,\n                save_feature=True,\n            )\n\n    for spatial_algorithm in parameters_spatial_algorithm:\n        for source_algorithm in parameters_source_algorithm:\n            if source_algorithm == \"EM\":\n                continue\n\n            test_t_ipsdta(\n                spatial_algorithm=spatial_algorithm,\n                source_algorithm=source_algorithm,\n                save_feature=True,\n            )\n\n\nif __name__ == \"__main__\":\n    save_all_features()\n"
  },
  {
    "path": "tests/regression/bss/test_iva.py",
    "content": "import sys\nfrom os import makedirs\nfrom os.path import dirname, join, realpath\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.bss.iva import AuxIVA, FastIVA, GradIVA, NaturalGradIVA\n\nssspy_tests_dir = dirname(dirname(dirname(realpath(__file__))))\nsys.path.append(ssspy_tests_dir)\n\nfrom dummy.utils.dataset import load_regression_data  # noqa: E402\n\niva_root = join(ssspy_tests_dir, \"mock\", \"regression\", \"bss\", \"iva\")\n\nparameters_is_holonomic = [True, False]\nparameters_spatial_algorithm = [\"IP1\", \"IP2\", \"ISS1\", \"ISS2\", \"IPA\"]\n\n\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_grad_iva(is_holonomic: bool, save_feature: bool = False):\n    if is_holonomic:\n        root = join(iva_root, \"grad_iva\", \"holonomic\")\n    else:\n        root = join(iva_root, \"grad_iva\", \"nonholonomic\")\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def score_fn(y) -> np.ndarray:\n        r\"\"\"Score function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        
norm = np.linalg.norm(y, axis=1, keepdims=True)\n        norm = np.maximum(norm, 1e-10)\n        return y / norm\n\n    iva = GradIVA(\n        contrast_fn=contrast_fn,\n        score_fn=score_fn,\n        is_holonomic=is_holonomic,\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\n@pytest.mark.parametrize(\"is_holonomic\", parameters_is_holonomic)\ndef test_natural_grad_iva(is_holonomic: bool, save_feature: bool = False):\n    if is_holonomic:\n        root = join(iva_root, \"natural_grad_iva\", \"holonomic\")\n    else:\n        root = join(iva_root, \"natural_grad_iva\", \"nonholonomic\")\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def score_fn(y) -> np.ndarray:\n        r\"\"\"Score function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n      
      np.ndarray:\n                The shape is (n_sources, n_bins, n_frames).\n        \"\"\"\n        norm = np.linalg.norm(y, axis=1, keepdims=True)\n        norm = np.maximum(norm, 1e-10)\n        return y / norm\n\n    iva = NaturalGradIVA(\n        contrast_fn=contrast_fn,\n        score_fn=score_fn,\n        is_holonomic=is_holonomic,\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\n@pytest.mark.parametrize(\"spatial_algorithm\", parameters_spatial_algorithm)\ndef test_aux_iva(spatial_algorithm: str, save_feature: bool = False):\n    root = join(iva_root, \"aux_iva\", spatial_algorithm)\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_iter = 10\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def d_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n       
 Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.ones_like(y)\n\n    iva = AuxIVA(\n        spatial_algorithm=spatial_algorithm,\n        contrast_fn=contrast_fn,\n        d_contrast_fn=d_contrast_fn,\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\ndef test_fast_iva(save_feature: bool = False):\n    root = join(iva_root, \"fast_iva\")\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_iter = 5\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    def contrast_fn(y: np.ndarray) -> np.ndarray:\n        r\"\"\"Contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_bins, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.linalg.norm(y, axis=1)\n\n    def d_contrast_fn(y) -> np.ndarray:\n        r\"\"\"Derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.ones_like(y)\n\n    def dd_contrast_fn(y) -> np.ndarray:\n 
       r\"\"\"Second order derivative of contrast function.\n\n        Args:\n            y (np.ndarray):\n                The shape is (n_sources, n_frames).\n\n        Returns:\n            np.ndarray:\n                The shape is (n_sources, n_frames).\n        \"\"\"\n        return 2 * np.zeros_like(y)\n\n    iva = FastIVA(\n        contrast_fn=contrast_fn,\n        d_contrast_fn=d_contrast_fn,\n        dd_contrast_fn=dd_contrast_fn,\n    )\n    spectrogram_est = iva(spectrogram_mix, n_iter=n_iter)\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\ndef save_all_features() -> None:\n    for is_holonomic in parameters_is_holonomic:\n        test_grad_iva(is_holonomic=is_holonomic, save_feature=True)\n\n    for is_holonomic in parameters_is_holonomic:\n        test_natural_grad_iva(is_holonomic=is_holonomic, save_feature=True)\n\n    for spatial_algorithm in parameters_spatial_algorithm:\n        test_aux_iva(spatial_algorithm=spatial_algorithm, save_feature=True)\n\n    test_fast_iva(save_feature=True)\n\n\nif __name__ == \"__main__\":\n    save_all_features()\n"
  },
  {
    "path": "tests/regression/bss/test_mnmf.py",
    "content": "import sys\nfrom os import makedirs\nfrom os.path import dirname, join, realpath\n\nimport numpy as np\nimport pytest\n\nfrom ssspy.bss.mnmf import FastGaussMNMF, GaussMNMF\n\nssspy_tests_dir = dirname(dirname(dirname(realpath(__file__))))\nsys.path.append(ssspy_tests_dir)\n\nfrom dummy.utils.dataset import load_regression_data  # noqa: E402\n\nmnmf_root = join(ssspy_tests_dir, \"mock\", \"regression\", \"bss\", \"mnmf\")\n\nparameters_diagonalizer_algorithm = [\"IP1\", \"IP2\"]\n\n\ndef test_gauss_mnmf(save_feature: bool = False):\n    rng = np.random.default_rng(0)\n    root = join(mnmf_root, \"gauss_mnmf\")\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_basis = 2\n        n_iter = 5\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_basis = npz_target[\"n_basis\"].item()\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    n_channels, n_bins, n_frames = spectrogram_mix.shape\n    n_sources = n_channels\n\n    if save_feature:\n        basis = rng.random((n_sources, n_bins, n_basis))\n        activation = rng.random((n_sources, n_basis, n_frames))\n\n        spatial = np.eye(n_channels, dtype=spectrogram_mix.dtype)\n        trace = np.trace(spatial, axis1=-2, axis2=-1)\n        spatial = spatial / np.real(trace)\n        spatial = np.tile(spatial, reps=(n_sources, n_bins, 1, 1))\n    else:\n        basis = npz_target[\"basis\"]\n        activation = npz_target[\"activation\"]\n        spatial = npz_target[\"spatial\"]\n\n    mnmf = GaussMNMF(\n        n_basis=n_basis,\n        n_sources=n_sources,\n        rng=rng,\n    )\n    spectrogram_est = mnmf(\n        spectrogram_mix,\n        n_iter=n_iter,\n        basis=basis,\n        
activation=activation,\n        spatial=spatial,\n    )\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            basis=basis,\n            activation=activation,\n            n_basis=n_basis,\n            spatial=spatial,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\n@pytest.mark.parametrize(\"diagonalizer_algorithm\", parameters_diagonalizer_algorithm)\ndef test_fast_gauss_mnmf(diagonalizer_algorithm: str, save_feature: bool = False):\n    rng = np.random.default_rng(0)\n    root = join(mnmf_root, \"fast_gauss_mnmf\", diagonalizer_algorithm)\n\n    if save_feature:\n        (npz_input,) = load_regression_data(root=root, filenames=[\"input.npz\"])\n        spectrogram_tgt = None\n        n_basis = 2\n        n_iter = 5\n    else:\n        npz_input, npz_target = load_regression_data(\n            root=root, filenames=[\"input.npz\", \"target.npz\"]\n        )\n        spectrogram_tgt = npz_target[\"spectrogram\"]\n        n_basis = npz_target[\"n_basis\"].item()\n        n_iter = npz_target[\"n_iter\"].item()\n\n    spectrogram_mix = npz_input[\"spectrogram\"]\n\n    n_channels, n_bins, n_frames = spectrogram_mix.shape\n    n_sources = n_channels\n\n    if save_feature:\n        basis = rng.random((n_sources, n_bins, n_basis))\n        activation = rng.random((n_sources, n_basis, n_frames))\n        spatial = rng.random((n_bins, n_sources, n_channels))\n        diagonalizer = np.eye(n_channels, dtype=np.complex128)\n        diagonalizer = np.tile(diagonalizer, reps=(n_bins, 1, 1))\n    else:\n        basis = npz_target[\"basis\"]\n        activation = npz_target[\"activation\"]\n        spatial = npz_target[\"spatial\"]\n        diagonalizer = npz_target[\"diagonalizer\"]\n\n    mnmf = 
FastGaussMNMF(\n        n_basis=n_basis,\n        n_sources=n_sources,\n        diagonalizer_algorithm=diagonalizer_algorithm,\n        rng=rng,\n    )\n    spectrogram_est = mnmf(\n        spectrogram_mix,\n        n_iter=n_iter,\n        basis=basis,\n        activation=activation,\n        spatial=spatial,\n        diagonalizer=diagonalizer,\n    )\n\n    if save_feature:\n        makedirs(root, exist_ok=True)\n        np.savez(\n            join(root, \"target.npz\"),\n            spectrogram=spectrogram_est,\n            basis=basis,\n            activation=activation,\n            spatial=spatial,\n            diagonalizer=diagonalizer,\n            n_basis=n_basis,\n            n_iter=n_iter,\n        )\n    else:\n        assert np.allclose(spectrogram_est, spectrogram_tgt, atol=1e-7), np.max(\n            np.abs(spectrogram_est - spectrogram_tgt)\n        )\n\n\ndef save_all_features() -> None:\n    test_gauss_mnmf(save_feature=True)\n\n    for diagonalizer_algorithm in parameters_diagonalizer_algorithm:\n        test_fast_gauss_mnmf(diagonalizer_algorithm=diagonalizer_algorithm, save_feature=True)\n\n\nif __name__ == \"__main__\":\n    save_all_features()\n"
  },
  {
    "path": "tests/scripts/download_all.py",
    "content": "# It is expected to be run from the root ssspy directory\nimport sys\nfrom os.path import dirname, realpath\n\ntests_dir = dirname(dirname(realpath(__file__)))\nsys.path.append(tests_dir)\n\nfrom dummy.utils.dataset import download_sample_speech_data  # noqa: E402\nfrom dummy.utils.dataset import download_ssspy_data  # noqa: E402\n\n\ndef download_all() -> None:\n    # Download sample speech data\n    conditions = [\n        {\"n_sources\": 2, \"sisec2010_tag\": \"dev1_female3\"},\n        {\"n_sources\": 3, \"sisec2010_tag\": \"dev1_female3\"},\n        {\"n_sources\": 4, \"sisec2010_tag\": \"dev1_female4\"},\n    ]\n    max_durations = [0.1, 0.5]\n\n    for kwargs in conditions:\n        for max_duration in max_durations:\n            download_sample_speech_data(max_duration=max_duration, **kwargs)\n\n    # Download sample audio for tests of IO\n    paths = [\n        \"audio/monoral_16k_5sec.wav\",\n        \"audio/stereo_16k_5sec.wav\",\n    ]\n    template_filename = \"./tests/mock/{}\"\n\n    for path in paths:\n        filename = template_filename.format(path)\n        download_ssspy_data(path, filename=filename)\n\n\nif __name__ == \"__main__\":\n    download_all()\n"
  }
]