[
  {
    "path": ".flake8",
    "content": "[flake8]\nmax-line-length = 120\nignore = E305,E402,E721,E741,F403,F405,F821,F841,F999,W503,W504\nexclude = beta_*.py, .git/\n"
  },
  {
    "path": ".gitignore",
    "content": "# Intellij Idea\n.idea/\n*.iml\n\n# python\n__pycache__/\n*.pyc\n## pyenv\n.python-version\n## python virtual environments\nenv/\nvenv/\n.env/\n.venv\n\n# OSX\n.DS_Store\n._*\n\n# swap files for vim users\n[._]*.swp\n[._]swp\n\n# ctags\ntags\n"
  },
  {
    "path": ".travis.yml",
    "content": "language: python\npython:\n  - \"3.5\"\ninstall:\n  - pip install -e .[tests]\nscript:\n  - python -m pytest\n"
  },
  {
    "path": "README.md",
    "content": "# SUNSETTING torcnaudio-contrib\n\nWe made some progress in this repo and contributed to the original repo [pytorch/audio](https://github.com/pytorch/audio) which satisfied us. We happily stopped working here :) Please visit [pytorch/audio](https://github.com/pytorch/audio) for any issue/request!\n\n.\n\n.\n\n.\n\n.\n\n.\n\n.\n\n.\n\n.\n\n.\n\n.\n\n\n(We keep the existing content as below ⬇)\n\n\n# torchaudio-contrib\n\n\nGoal: To propose audio processing Pytorch codes with nice and easy-to-use APIs and functionality.\n\n:open_hands: This should be seen as a community based proposal and the basis for a discussion we should have inside the pytorch audio user community. Everyone should be welcome to join and discuss.\n\nOur motivation is:\n\n  - API design: Clear, readible names for class/functions/arguments, sensible default values, and shapes.\n      - Reference: [librosa](http://librosa.github.io/librosa/) (audio and MIR on Numpy), [kapre](https://github.com/keunwoochoi/kapre) (audio on Keras), [pytorch/audio](https://github.com/pytorch/audio) (audio on Pytorch)\n  - Fast processing on GPU\n  - Methodology: Both layer and functional\n    - Layers (`nn.Module`) for reusability and easier use\n    - and identical implementation with `Functionals`\n- Simple installation\n- Multi-channel support\n\n## Contribution\n\nMaking things quicker and open! We're `-contrib` repo, hence it's *easy to enter but hard to graduate*. \n\n 1. Make a new [Issue](https://github.com/keunwoochoi/torchaudio-contrib/issues) for a potential PR\n 2. Until it's in a good shape,\n    1. Make a PR with following the current conventions and unittest\n    2. Review-merge.\n 3. 
Based on it, make a PR to [pytorch/audio](https://github.com/pytorch/audio)\n\nDiscussion on how to contribute - https://github.com/keunwoochoi/torchaudio-contrib/issues/37\n\n## Current issues/future work\n- Better module/sub-module hierarchy\n- Complex number support\n- More time-frequency representations\n- Signal processing modules, e.g., vocoder\n- Augmentation\n\n# API suggestions\n\n## Notes\n  * Audio signals can be multi-channel\n  * `STFT`: short-time Fourier transform, outputting a complex-valued representation\n  * `Spectrogram`: magnitudes of the STFT\n  * `Melspectrogram`: mel filterbank applied to `Spectrogram`\n\n## Shapes\n  * audio signals: `(batch, channel, time)`\n      * E.g., `STFT` input shape\n      * Based on `torch.stft` input shape\n  * 2D representations: `(batch, channel, freq, time)`\n      * E.g., `STFT` output shape\n      * Channel-first, following the torch convention.\n      * Then, `(freq, time)`, following `torch.stft`\n\n\n## Overview\n### `STFT`\n```python\nclass STFT(fft_len=2048, hop_len=None, frame_len=None, window=None, pad=0, pad_mode=\"reflect\", **kwargs)\ndef stft(signal, fft_len, hop_len, window, pad=0, pad_mode=\"reflect\", **kwargs)\n```\n\n### `MelFilterbank`\n```python\nclass MelFilterbank(num_bands=128, sample_rate=16000, min_freq=0.0, max_freq=None, num_bins=1025, htk=False)\ndef create_mel_filter(num_bands, sample_rate, min_freq, max_freq, num_bins, to_hertz, from_hertz)\n```\n\n### `Spectrogram`\n```python\ndef Spectrogram(fft_len=2048, hop_len=None, frame_len=None, window=None, pad=0, pad_mode=\"reflect\", power=1., **kwargs)\n```\nCreates an `nn.Sequential`:\n```\n>>> Sequential(\n>>>  (0): STFT(fft_len=2048, hop_len=512, frame_len=2048)\n>>>  (1): ComplexNorm(power=1.0)\n)\n```\n\n### `Melspectrogram`\n```python\ndef Melspectrogram(num_bands=128, sample_rate=16000, min_freq=0.0, max_freq=None, num_bins=None, htk=False, mel_filterbank=None, **kwargs)\n```\nCreates an `nn.Sequential`:\n```\n>>> Sequential(\n>>> 
 (0): STFT(fft_len=2048, hop_len=512, frame_len=2048)\n>>>  (1): ComplexNorm(power=2.0)\n>>>  (2): ApplyFilterbank()\n)\n```\n\n### `AmplitudeToDb`/`amplitude_to_db`\n```python\nclass AmplitudeToDb(ref=1.0, amin=1e-7)\ndef amplitude_to_db(x, ref=1.0, amin=1e-7)\n```\nArgument names and the default value of `ref` follow librosa. The default value of `amin`, however, follows Keras's float32 epsilon, which seems sensible.\n\n### `DbToAmplitude`/`db_to_amplitude`\n```python\nclass DbToAmplitude(ref=1.0)\ndef db_to_amplitude(x, ref=1.0)\n```\n\n### `MuLawEncoding`/`mu_law_encoding`\n```python\nclass MuLawEncoding(n_quantize=256)\ndef mu_law_encoding(x, n_quantize=256)\n```\n\n### `MuLawDecoding`/`mu_law_decoding`\n```python\nclass MuLawDecoding(n_quantize=256)\ndef mu_law_decoding(x_mu, n_quantize=256)\n```\n\n----------\n\n# A Big Issue - Remove the SoX Dependency\n\nWe propose to remove the SoX dependency because:\n\n* Many audio ML tasks don’t require the functionality included in SoX (filtering, cutting, effects)\n* Many issues in torchaudio are related to the installation with respect to SoX. While this could be simplified by a [conda build or a wheel](https://github.com/pytorch/builder/issues/279), maintaining the repo would remain difficult.\n* SoX doesn’t support MP4 containers, which makes it unusable for multi-stream audio\n* Loading speed is good with torchaudio, but e.g. for __wav__, it's not faster than other libraries (including the cast to a torch tensor) -- as in the graph below. See more detailed benchmarks [here](https://github.com/faroit/python_audio_loading_benchmark).\n\n![](https://raw.githubusercontent.com/faroit/python_audio_loading_benchmark/master/results/benchmark_pytorch.png)\n\n## Proposal\n\nIntroduce I/O backends and move the functions that depend on `_torch_sox` to a `backend_sox.py`, which is *not* required to install. Additionally, we could then introduce more backends like scipy.io or pysoundfile. 
Each backend then imports its (optional) dependency within the backend file, and each backend implements a minimum spec such as:\n\n```python\nimport _torch_sox\n\ndef load(...):\n    # returns audio, rate\ndef save(...):\n    # writes the file\ndef info(...):\n    # returns metadata without reading the full file\n```\n\n### Backend proposals\n\n* `scipy.io` or `soundfile` as default for __wav__ files\n* `aubio` or `audioread` for __mp3__ and __mp4__\n\n\n### Installation\n\n```bash\npip install -e .\n```\n\n\n### Importing\n\n```python\nimport torchaudio_contrib\n```\n\n\n## Authors\nKeunwoo Choi, Faro Stöter, Kiran Sanjeevan, Jan Schlüter\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch"
  },
  {
    "path": "setup.cfg",
    "content": "[tool:pytest]\nxfail_strict = true\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup\n\nsetup(name='torchaudio_contrib',\n      version='0.1',\n      description='To propose audio processing Pytorch codes with nice and easy-to-use APIs and functionality',\n      url='https://github.com/keunwoochoi/torchaudio-contrib',\n      author='Keunwoo Choi, Faro Stöter, Kiran Sanjeevan, Jan Schlüter',\n      author_email='gnuchoi@gmail.com',\n      license='MIT',\n      install_requires=['torch'],\n      extras_require={'tests': ['pytest', 'librosa']},\n      packages=['torchaudio_contrib'],\n      zip_safe=False)\n"
  },
  {
    "path": "tests/test_functional.py",
    "content": "import pytest\nimport librosa\nimport numpy as np\nimport torch\n\nfrom torchaudio_contrib.functional import (stft, phase_vocoder, magphase, amplitude_to_db, db_to_amplitude,\n                                           complex_norm, mu_law_encoding, mu_law_decoding, apply_filterbank\n                                           )\n\n\nxfail = pytest.mark.xfail\n\n\ndef _num_stft_bins(signal_len, fft_len, hop_length, pad):\n    return (signal_len + 2 * pad - fft_len + hop_length) // hop_length\n\n\ndef _approx_all_equal(x, y, atol=1e-7):\n    return torch.all(torch.lt(torch.abs(torch.add(x, -y)), atol))\n\n\ndef _all_equal(x, y):\n    return torch.all(torch.eq(x, y))\n\n\n@pytest.mark.parametrize('fft_length', [512])\n@pytest.mark.parametrize('hop_length', [256])\n@pytest.mark.parametrize('waveform', [\n    (torch.randn(1, 100000)),\n    (torch.randn(1, 2, 100000)),\n    pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)),\n])\n@pytest.mark.parametrize('pad_mode', [\n    # 'constant',\n    'reflect',\n])\ndef test_stft(waveform, fft_length, hop_length, pad_mode):\n    \"\"\"\n    Test STFT for multi-channel signals.\n\n    Padding: Value in having padding outside of torch.stft?\n    \"\"\"\n    pad = fft_length // 2\n    window = torch.hann_window(fft_length)\n    complex_spec = stft(waveform, fft_length=fft_length, hop_length=hop_length, window=window, pad_mode=pad_mode)\n    mag_spec, phase_spec = magphase(complex_spec)\n\n    # == Test shape\n    expected_size = list(waveform.size()[:-1])\n    expected_size += [fft_length // 2 + 1, _num_stft_bins(\n        waveform.size(-1), fft_length, hop_length, pad), 2]\n    assert complex_spec.dim() == waveform.dim() + 2\n    assert complex_spec.size() == torch.Size(expected_size)\n\n    # == Test values\n    fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode)\n    # note that librosa *automatically* pad with fft_length // 2.\n    expected_complex_spec = 
np.apply_along_axis(librosa.stft, -1,\n                                                waveform.numpy(), **fft_config)\n    expected_mag_spec, _ = librosa.magphase(expected_complex_spec)\n    # Convert torch to np.complex\n    complex_spec = complex_spec.numpy()\n    complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1]\n\n    assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5)\n    assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5)\n\n\n@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])\n@pytest.mark.parametrize('complex_specgrams', [\n    torch.randn(1, 2, 1025, 400, 2),\n    torch.randn(1, 1025, 400, 2)\n])\n@pytest.mark.parametrize('hop_length', [256])\ndef test_phase_vocoder(complex_specgrams, rate, hop_length):\n\n    class use_double_precision:\n        def __enter__(self):\n            self.default_dtype = torch.get_default_dtype()\n            torch.set_default_dtype(torch.float64)\n\n        def __exit__(self, type, value, traceback):\n            torch.set_default_dtype(self.default_dtype)\n\n    # Due to the cumulative sum, numerical error with torch.float32 would\n    # cause the bottom-right values of the stretched spectrogram to not\n    # match librosa's\n    with use_double_precision():\n\n        complex_specgrams = complex_specgrams.type(torch.get_default_dtype())\n\n        phase_advance = torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-3])[..., None]\n        complex_specgrams_stretch = phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)\n\n        # == Test shape\n        expected_size = list(complex_specgrams.size())\n        expected_size[-2] = int(np.ceil(expected_size[-2] / rate))\n\n        assert complex_specgrams.dim() == complex_specgrams_stretch.dim()\n        assert complex_specgrams_stretch.size() == torch.Size(expected_size)\n\n        # == Test values\n        index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3\n        mono_complex_specgram 
= complex_specgrams[index].numpy()\n        mono_complex_specgram = mono_complex_specgram[..., 0] + \\\n            mono_complex_specgram[..., 1] * 1j\n        expected_complex_stretch = librosa.phase_vocoder(\n            mono_complex_specgram,\n            rate=rate,\n            hop_length=hop_length)\n\n        complex_stretch = complex_specgrams_stretch[index].numpy()\n        complex_stretch = complex_stretch[..., 0] + \\\n            1j * complex_stretch[..., 1]\n        assert np.allclose(complex_stretch,\n                           expected_complex_stretch, atol=1e-5)\n\n\n@pytest.mark.parametrize('complex_tensor', [\n    torch.randn(1, 2, 1025, 400, 2),\n    torch.randn(1025, 400, 2)\n])\n@pytest.mark.parametrize('power', [1, 2, 0.7])\ndef test_complex_norm(complex_tensor, power):\n    expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)\n    norm_tensor = complex_norm(complex_tensor, power)\n\n    assert _approx_all_equal(expected_norm_tensor, norm_tensor, atol=1e-5)\n\n\n@pytest.mark.parametrize('new_len', [120, 36])\n@pytest.mark.parametrize('mag_spec', [\n    torch.randn(1, 257, 391),\n    torch.randn(1, 2, 257, 391),\n])\ndef test_apply_filterbank(mag_spec, new_len):\n    filterbank = torch.randn(mag_spec.size(-2), new_len)\n    mag_spec_filterbanked = apply_filterbank(mag_spec, filterbank)\n    assert mag_spec.size(-1) == mag_spec_filterbanked.size(-1)\n    assert mag_spec_filterbanked.size(-2) == new_len\n    assert mag_spec.dim() == mag_spec_filterbanked.dim()\n\n\n@pytest.mark.parametrize('amplitude,db', [\n    (torch.Tensor([0.000001, 0.0001, 0.1, 1.0, 10.0, 1000000.0]),\n     torch.Tensor([-60.0, -40.0, -10.0, 0.0, 10.0, 60.0]))\n])\ndef test_amplitude_db(amplitude, db):\n    \"\"\"Test amplitude_to_db and db_to_amplitude.\"\"\"\n    amplitude = np.sqrt(amplitude)\n    assert _approx_all_equal(db, amplitude_to_db(amplitude, ref=1.0))\n    assert _approx_all_equal(amplitude, db_to_amplitude(db, ref=1.0))\n    # both ways\n    
assert _approx_all_equal(db_to_amplitude(amplitude_to_db(amplitude, ref=1.0), ref=1.0),\n                             amplitude)\n    assert _approx_all_equal(amplitude_to_db(db_to_amplitude(db, ref=1.0), ref=1.0),\n                             db,\n                             atol=1e-5)\n\n\n@pytest.mark.parametrize('waveform', [\n    torch.randn(1, 100000),\n    (torch.randn(1, 2, 100000)),\n])\n@pytest.mark.parametrize('n_quantize', [256])\ndef test_mu_law(waveform, n_quantize):\n    \"\"\"test mu-law encoding and decoding\"\"\"\n\n    def _test_mu_encoding(waveform, n_quantize):\n\n        waveform = 2 * (waveform - 0.5)\n        # manual computation\n        mu = torch.tensor(n_quantize - 1, dtype=waveform.dtype)\n        waveform_mu = waveform.sign() * torch.log1p(mu * waveform.abs()) / torch.log1p(mu)\n        waveform_mu = ((waveform_mu + 1) / 2 * mu + 0.5).long()\n\n        assert _all_equal(mu_law_encoding(waveform, n_quantize),\n                          waveform_mu)\n\n    def _test_mu_decoding(waveform, n_quantize):\n\n        waveform_mu = torch.randint(low=0, high=n_quantize - 1,\n                                    size=(1, 1024))\n\n        # manual computation\n        waveform_mu = waveform_mu.float()\n        mu = torch.tensor(n_quantize - 1, dtype=waveform_mu.dtype)  # confused about dtype here..\n\n        waveform = (waveform_mu / mu) * 2 - 1.\n        waveform = waveform.sign() * (torch.exp(waveform.abs() * torch.log1p(mu)) - 1.) 
/ mu\n\n        assert _all_equal(mu_law_decoding(waveform_mu, n_quantize),\n                          waveform)\n\n    def _test_both_ways(waveform, n_quantize):\n        waveform_mu = torch.randint(low=0, high=n_quantize - 1,\n                                    size=(1, 1024))\n        assert _all_equal(waveform_mu,\n                          mu_law_encoding(mu_law_decoding(waveform_mu, n_quantize), n_quantize))\n\n    _test_mu_encoding(waveform, n_quantize)\n    _test_mu_decoding(waveform, n_quantize)\n    _test_both_ways(waveform, n_quantize)\n"
  },
  {
    "path": "tests/test_layers.py",
    "content": "\"\"\"\nTest the layers. Currently only on cpu since travis doesn't have GPU.\n\"\"\"\nimport unittest\nimport pytest\nimport librosa\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torchaudio_contrib.layers import (STFT, Spectrogram, MelFilterbank, AmplitudeToDb, TimeStretch,\n                                       ComplexNorm, ApplyFilterbank)\n\nfrom test_functional import _num_stft_bins\n\nxfail = pytest.mark.xfail\n\n\n@pytest.mark.parametrize('fft_len', [512])\n@pytest.mark.parametrize('hop_length', [256])\n@pytest.mark.parametrize('waveform', [\n    torch.randn(1, 100000)\n])\n@pytest.mark.parametrize('pad_mode', [\n    # 'constant',\n    'reflect',\n])\ndef test_STFT(waveform, fft_len, hop_length, pad_mode):\n    \"\"\"\n    Test STFT for multi-channel signals.\n\n    Padding: Value in having padding outside of torch.stft?\n    \"\"\"\n    pad = fft_len // 2\n    layer = STFT(fft_length=fft_len, hop_length=hop_length, pad_mode=pad_mode)\n\n    assert torch.is_tensor(layer.window)\n    assert not layer.window.requires_grad\n    assert layer.window.size(0) <= layer.fft_length\n\n\n@pytest.mark.parametrize('rate', [0.7])\n@pytest.mark.parametrize('complex_specgrams', [\n    torch.randn(1, 2, 1025, 400, 2),\n    torch.randn(1, 1025, 400, 2)\n])\n@pytest.mark.parametrize('hop_length', [256])\ndef test_TimeStretch(complex_specgrams, rate, hop_length):\n\n    layer = TimeStretch(hop_length=hop_length, num_freqs=complex_specgrams.shape[-3])\n\n    assert torch.is_tensor(layer.phase_advance)\n    assert not layer.phase_advance.requires_grad\n\n\n@pytest.mark.parametrize('fft_length', [512])\n@pytest.mark.parametrize('hop_length', [256])\n@pytest.mark.parametrize('waveform', [\n    torch.randn(1, 100000),\n    torch.randn(1, 2, 100000)\n])\n@pytest.mark.parametrize('pad_mode', [\n    # 'constant',\n    'reflect',\n])\ndef test_SpectrogramDb(waveform, fft_length, hop_length, pad_mode):\n\n    ref, amin = 1.0, 1e-7\n    window = 
torch.hann_window(fft_length)\n    model = torch.nn.Sequential(*Spectrogram(fft_length, hop_length=hop_length, window=window, pad_mode=pad_mode),\n                                AmplitudeToDb(ref=ref, amin=amin))\n    db_spec = model(waveform).numpy()\n\n    fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode)\n    expected_db_spec = np.abs(np.apply_along_axis(librosa.stft, -1,\n                              waveform.numpy(), **fft_config))\n\n    db_config = dict(ref=ref, amin=amin, top_db=None)\n    expected_db_spec = np.apply_along_axis(librosa.power_to_db,\n                                           -1,\n                                           expected_db_spec**2,\n                                           **db_config)\n\n    assert np.allclose(db_spec, expected_db_spec, atol=1e-2), np.abs(expected_db_spec - db_spec).max()\n\n\n@pytest.mark.parametrize('fft_length', [512])\n@pytest.mark.parametrize('num_mels', [128])\n@pytest.mark.parametrize('hop_length', [256])\n@pytest.mark.parametrize('waveform', [\n    torch.randn(1, 2, 100000),\n    torch.randn(4, 100000)\n])\n@pytest.mark.parametrize('rate', [0.7])\ndef test_MelspectrogramStretch(waveform, fft_length, num_mels, hop_length, rate):\n\n    num_freqs = fft_length // 2 + 1\n    fb = MelFilterbank(num_freqs=num_freqs, num_mels=num_mels, max_freq=1.0).get_filterbank()\n    model = nn.Sequential(STFT(fft_length, hop_length=hop_length),\n                          TimeStretch(hop_length=hop_length, num_freqs=num_freqs, fixed_rate=rate),\n                          ComplexNorm(power=2.0),\n                          ApplyFilterbank(fb))\n    mel_spec = model(waveform)\n    num_bins = _num_stft_bins(waveform.size(-1), fft_length, hop_length, fft_length // 2)\n\n    assert mel_spec.size(-2) == num_mels\n    assert mel_spec.size(-1) == np.ceil(num_bins / rate)\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "torchaudio_contrib/__init__.py",
    "content": "from .functional import *  # noqa: F401\nfrom .layers import *  # noqa: F401\n"
  },
  {
    "path": "torchaudio_contrib/beta_hpss.py",
    "content": "\"\"\"This is a beta-version of harmonic-percussive source separation.\nCurrently it only returns the separated magnitude spectrograms. Once we have inverse-STFT,\nwe can extend it to get waveform results.\n\nTODO: add test\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass HPSS(nn.Module):\n    \"\"\"\n    Wrap hpss.\n\n    Args and Returns --> see `hpss`.\n    \"\"\"\n\n    def __init__(self, kernel_size=31, power=2.0, hard=False, mask_only=False):\n        super(HPSS, self).__init__()\n        self.kernel_size = kernel_size\n        self.power = power\n        self.hard = hard\n        self.mask_only = mask_only\n\n    def forward(self, mag_specgrams):\n        return hpss(mag_specgrams, self.kernel_size, self.power, self.hard, self.mask_only)\n\n    def __repr__(self):\n        return self.__class__.__name__ + \\\n               '(kernel_size={}, power={}, hard={}, mask_only={})'.format(\n                   self.kernel_size, self.power, self.hard, self.mask_only)\n\n\ndef hpss(mag_specgrams, kernel_size=31, power=2.0, hard=False, mask_only=False):\n    \"\"\"\n    A function that performs harmonic-percussive source separation.\n    Original method is by Derry Fitzgerald\n    (https://www.researchgate.net/publication/254583990_HarmonicPercussive_Separation_using_Median_Filtering).\n\n    Args:\n        mag_specgrams (Tensor): any magnitude spectrograms in batch, (not in a decibel scale!)\n            in a shape of (batch, ch, freq, time)\n\n        kernel_size (int or (int, int)): odd-numbered\n            if tuple,\n                1st: width of percussive-enhancing filter (one along freq axis)\n                2nd: width of harmonic-enhancing filter (one along time axis)\n            if int,\n                it's applied for both perc/harm filters\n\n        power (float): to which the enhanced spectrograms are used in computing soft masks.\n\n        hard (bool): whether the mask will be binarized (True) 
or not\n\n        mask_only (bool): if True, returns the masks only.\n\n    Returns:\n        ret (Tuple): A tuple of four\n\n            ret[0]: magnitude spectrograms - harmonic parts (Tensor, same size as `mag_specgrams`)\n            ret[1]: magnitude spectrograms - percussive parts (Tensor, same size as `mag_specgrams`)\n            ret[2]: harmonic mask (Tensor, same size as `mag_specgrams`)\n            ret[3]: percussive mask (Tensor, same size as `mag_specgrams`)\n    \"\"\"\n\n    def _enhance_either_hpss(mag_specgrams_padded, out, kernel_size, power, which, offset):\n        \"\"\"\n        A helper function for HPSS\n\n        Args:\n            mag_specgrams_padded (Tensor): a tensor to which median filtering can be directly applied\n\n            out (Tensor): The tensor to store the result\n\n            kernel_size (int): The kernel size of the median filter\n\n            power (float): the power to which the enhanced spectrograms are raised when computing soft masks.\n\n            which (str): either 'harm' or 'perc'\n\n            offset (int): the padded length\n\n        \"\"\"\n        if which == 'harm':\n            for t in range(out.shape[3]):\n                out[:, :, :, t] = torch.median(mag_specgrams_padded[:, :, offset:-offset, t:t + kernel_size], dim=3)[0]\n\n        elif which == 'perc':\n            for f in range(out.shape[2]):\n                out[:, :, f, :] = torch.median(mag_specgrams_padded[:, :, f:f + kernel_size, offset:-offset], dim=2)[0]\n        else:\n            raise NotImplementedError(\"which should be either 'harm' or 'perc', but was: {}\".format(which))\n\n        if power != 1.0:\n            out.pow_(power)\n        # end of the helper function\n\n    eps = 1e-6\n\n    if not (isinstance(kernel_size, tuple) or isinstance(kernel_size, int)):\n        raise TypeError('kernel_size is expected to be an int or a tuple, but it is: %s' % type(kernel_size))\n    if isinstance(kernel_size, int):\n        kernel_size = (kernel_size, 
kernel_size)\n\n    pad = (kernel_size[0] // 2, kernel_size[0] // 2,\n           kernel_size[1] // 2, kernel_size[1] // 2,)\n\n    harm, perc = torch.empty_like(mag_specgrams), torch.empty_like(mag_specgrams)\n    mag_specgrams_padded = F.pad(mag_specgrams, pad=pad, mode='reflect')\n\n    _enhance_either_hpss(mag_specgrams_padded, out=perc, kernel_size=kernel_size[0], power=power, which='perc',\n                         offset=kernel_size[1] // 2)\n    _enhance_either_hpss(mag_specgrams_padded, out=harm, kernel_size=kernel_size[1], power=power, which='harm',\n                         offset=kernel_size[0] // 2)\n\n    if hard:\n        mask_harm = harm > perc\n        mask_perc = harm < perc\n    else:\n        mask_harm = (harm + eps) / (harm + perc + eps)\n        mask_perc = (perc + eps) / (harm + perc + eps)\n\n    if mask_only:\n        return None, None, mask_harm, mask_perc\n\n    return mag_specgrams * mask_harm, mag_specgrams * mask_perc, mask_harm, mask_perc\n\n# def pss_src(x, kernel_size=31, power=2.0, hard=False):\n#     \"\"\"perform percussive source separation using `hpss()`.\n#     x: (batch, time)\"\"\"\n#     n_fft = 1024\n#     hop_length = 256\n#     x_stft = torch.stft(x, n_fft=n_fft, hop_length=hop_length)\n#     x_mag = x_stft.pow(2).sum(-1).unsqueeze(1)  # add channel dim\n#     _, _, _, mask_perc = hpss(x_mag, kernel_size, power, hard, mask_only=True)\n#     mask_perc.squeeze_(1).unsqueeze_(3)  # remove channel, add last dim for complex\n#     x_perc = time_freq.istft(x_stft * mask_perc, hop_length=hop_length, length=x.shape[1])\n#     return x_perc\n"
  },
  {
    "path": "torchaudio_contrib/functional.py",
    "content": "import torch\nimport math\n\n\ndef _mel_to_hertz(mel, htk):\n    \"\"\"\n    Converting mel values into frequency\n    \"\"\"\n    mel = torch.as_tensor(mel).type(torch.get_default_dtype())\n\n    if htk:\n        return 700. * (10 ** (mel / 2595.) - 1.)\n\n    f_min = 0.0\n    f_sp = 200.0 / 3\n    hz = f_min + f_sp * mel\n\n    min_log_hz = 1000.0\n    min_log_mel = (min_log_hz - f_min) / f_sp\n    logstep = math.log(6.4) / 27.0\n\n    return torch.where(mel >= min_log_mel, min_log_hz *\n                       torch.exp(logstep * (mel - min_log_mel)), hz)\n\n\ndef _hertz_to_mel(hz, htk):\n    \"\"\"\n    Converting frequency into mel values\n    \"\"\"\n    hz = torch.as_tensor(hz).type(torch.get_default_dtype())\n\n    if htk:\n        return 2595. * torch.log10(torch.tensor(1., dtype=torch.get_default_dtype()) + (hz / 700.))\n\n    f_min = 0.0\n    f_sp = 200.0 / 3\n\n    mel = (hz - f_min) / f_sp\n\n    min_log_hz = 1000.0\n    min_log_mel = (min_log_hz - f_min) / f_sp\n    logstep = math.log(6.4) / 27.0\n\n    return torch.where(hz >= min_log_hz, min_log_mel +\n                       torch.log(hz / min_log_hz) / logstep, mel)\n\n\ndef stft(waveforms, fft_length, hop_length=None, win_length=None, window=None,\n         center=True, pad_mode='reflect', normalized=False, onesided=True):\n    \"\"\"Compute a short-time Fourier transform of the input waveform(s).\n    It wraps `torch.stft` but after reshaping the input audio\n    to allow for `waveforms` that `.dim()` >= 3.\n    It follows most of the `torch.stft` default value, but for `window`,\n    if it's not specified (`None`), it uses hann window.\n\n    Args:\n        waveforms (Tensor): Tensor of audio signal\n            of size `(*, channel, time)`\n        fft_length (int): FFT size [sample]\n        hop_length (int): Hop size [sample] between STFT frames.\n            Defaults to `fft_length // 4` (75%-overlapping windows)\n            by `torch.stft`.\n        win_length (int): Size of 
STFT window.\n            Defaults to `fft_length` by `torch.stft`.\n        window (Tensor): 1-D Tensor.\n            Defaults to a Hann window of size `win_length`,\n            *unlike* `torch.stft`.\n        center (bool): Whether to pad `waveforms` on both sides so that the\n            `t`-th frame is centered at time `t * hop_length`.\n            Defaults to `True` by `torch.stft`.\n        pad_mode (str): padding method (see `torch.nn.functional.pad`).\n            Defaults to `'reflect'` by `torch.stft`.\n        normalized (bool): Whether the results are normalized.\n            Defaults to `False` by `torch.stft`.\n        onesided (bool): Whether only the half + 1 frequency bins\n            are returned, removing the symmetric part of the STFT\n            of a real-valued signal. Defaults to `True`\n            by `torch.stft`.\n\n    Returns:\n        complex_specgrams (Tensor): `(*, channel, num_freqs, time, complex=2)`\n\n    Example:\n        >>> waveforms = torch.randn(16, 2, 10000)  # (batch, channel, time)\n        >>> x = stft(waveforms, 2048, 512)\n        >>> x.shape\n        torch.Size([16, 2, 1025, 20, 2])\n    \"\"\"\n    leading_dims = waveforms.shape[:-1]\n\n    waveforms = waveforms.reshape(-1, waveforms.size(-1))\n\n    if window is None:\n        if win_length is None:\n            window = torch.hann_window(fft_length)\n        else:\n            window = torch.hann_window(win_length)\n\n    complex_specgrams = torch.stft(waveforms,\n                                   n_fft=fft_length,\n                                   hop_length=hop_length,\n                                   win_length=win_length,\n                                   window=window,\n                                   center=center,\n                                   pad_mode=pad_mode,\n                                   normalized=normalized,\n                                   onesided=onesided)\n\n    complex_specgrams = complex_specgrams.reshape(\n        leading_dims +\n     
   complex_specgrams.shape[1:])\n\n    return complex_specgrams\n\n\ndef complex_norm(complex_tensor, power=1.0):\n    \"\"\"Compute the norm of a complex tensor input\n\n    Args:\n        complex_tensor (Tensor): Tensor shape of `(*, complex=2)`\n        power (float): Power of the norm. Defaults to `1.0`.\n\n    Returns:\n        Tensor: norm of the input tensor raised to `power`, shape of `(*, )`\n    \"\"\"\n    if power == 1.0:\n        return torch.norm(complex_tensor, 2, -1)\n    return torch.norm(complex_tensor, 2, -1).pow(power)\n\n\ndef create_mel_filter(num_freqs, num_mels, min_freq, max_freq, htk):\n    \"\"\"\n    Creates a filter matrix to transform fft frequency bins\n    into mel frequency bins.\n    Equivalent to librosa.filters.mel(sample_rate,\n                                      fft_len,\n                                      htk=htk,\n                                      norm=None).\n\n    Args:\n        num_freqs (int): number of frequency bins from the stft.\n        num_mels (int): number of mel bins.\n        min_freq (float): minimum frequency.\n        max_freq (float): maximum frequency.\n        htk (bool): whether to follow the htk-mel scale or not\n\n    Returns:\n        mel_filterbank (Tensor): (num_freqs, num_mels)\n    \"\"\"\n    # Convert to find mel lower/upper bounds\n    m_min = _hertz_to_mel(min_freq, htk)\n    m_max = _hertz_to_mel(max_freq, htk)\n\n    # Compute stft frequency values\n    stft_freqs = torch.linspace(min_freq, max_freq, num_freqs)\n\n    # Find mel values, and convert them to frequency units\n    m_pts = torch.linspace(m_min, m_max, num_mels + 2)\n    f_pts = _mel_to_hertz(m_pts, htk)\n    f_diff = f_pts[1:] - f_pts[:-1]  # (num_mels + 1)\n\n    # (num_freqs, num_mels + 2)\n    slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1)\n\n    down_slopes = (-1. 
* slopes[:, :-2]) / f_diff[:-1]  # (num_freqs, num_mels)\n    up_slopes = slopes[:, 2:] / f_diff[1:]  # (num_freqs, num_mels)\n    mel_filterbank = torch.clamp(torch.min(down_slopes, up_slopes), min=0.)\n\n    return mel_filterbank\n\n\ndef apply_filterbank(mag_specgrams, filterbank):\n    \"\"\"\n    Transform a spectrogram given a filterbank matrix.\n\n    Args:\n        mag_specgrams (Tensor): (batch, channel, num_freqs, time)\n        filterbank (Tensor): (num_freqs, num_bands)\n\n    Returns:\n        (Tensor): (batch, channel, num_bands, time)\n    \"\"\"\n    return torch.matmul(mag_specgrams.transpose(-2, -1),\n                        filterbank).transpose(-2, -1)\n\n\ndef angle(complex_tensor):\n    \"\"\"\n    Return the angle of a complex tensor with shape (*, 2).\n    \"\"\"\n    return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])\n\n\ndef magphase(complex_tensor, power=1.):\n    \"\"\"\n    Separate a complex-valued spectrogram with shape (*, 2)\n    into its magnitude and phase.\n    \"\"\"\n    mag = complex_norm(complex_tensor, power)\n    phase = angle(complex_tensor)\n    return mag, phase\n\n\ndef phase_vocoder(complex_specgrams, rate, phase_advance):\n    \"\"\"\n    Phase vocoder. Given an STFT tensor, speed it up in time\n    by a factor of `rate` without modifying pitch.\n\n    Args:\n        complex_specgrams (Tensor):\n            (*, channel, num_freqs, time, complex=2)\n        rate (float): Speed-up factor.\n        phase_advance (Tensor): Expected phase advance in\n            each bin. 
(num_freqs, 1).\n\n    Returns:\n        complex_specgrams_stretch (Tensor):\n            (*, channel, num_freqs, ceil(time/rate), complex=2).\n\n    Example:\n        >>> num_freqs, hop_length = 1025, 512\n        >>> # (batch, channel, num_freqs, time, complex=2)\n        >>> complex_specgrams = torch.randn(16, 1, num_freqs, 300, 2)\n        >>> rate = 1.3  # Speed up by 30%\n        >>> phase_advance = torch.linspace(\n        ...    0, math.pi * hop_length, num_freqs)[..., None]\n        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)\n        >>> x.shape # with 231 == ceil(300 / 1.3)\n        torch.Size([16, 1, 1025, 231, 2])\n    \"\"\"\n    ndim = complex_specgrams.dim()\n    time_slice = [slice(None)] * (ndim - 2)\n\n    time_steps = torch.arange(0, complex_specgrams.size(\n        -2), rate, device=complex_specgrams.device)\n\n    alphas = torch.remainder(time_steps,\n                             torch.tensor(1., device=complex_specgrams.device))\n    phase_0 = angle(complex_specgrams[time_slice + [slice(1)]])\n\n    # Time Padding\n    complex_specgrams = torch.nn.functional.pad(\n        complex_specgrams, [0, 0, 0, 2])\n\n    complex_specgrams_0 = complex_specgrams[time_slice +\n                                            [time_steps.long()]]\n    # (new_bins, num_freqs, 2)\n    complex_specgrams_1 = complex_specgrams[time_slice +\n                                            [(time_steps + 1).long()]]\n\n    angle_0 = angle(complex_specgrams_0)\n    angle_1 = angle(complex_specgrams_1)\n\n    norm_0 = torch.norm(complex_specgrams_0, dim=-1)\n    norm_1 = torch.norm(complex_specgrams_1, dim=-1)\n\n    phase = angle_1 - angle_0 - phase_advance\n    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))\n\n    # Compute Phase Accum\n    phase = phase + phase_advance\n    phase = torch.cat([phase_0, phase[time_slice + [slice(-1)]]], dim=-1)\n    phase_acc = torch.cumsum(phase, -1)\n\n    mag = alphas * norm_1 + (1 - alphas) * 
norm_0\n\n    real_stretch = mag * torch.cos(phase_acc)\n    imag_stretch = mag * torch.sin(phase_acc)\n\n    complex_specgrams_stretch = torch.stack(\n        [real_stretch, imag_stretch],\n        dim=-1)\n\n    return complex_specgrams_stretch\n\n\ndef amplitude_to_db(x, ref=1.0, amin=1e-7):\n    \"\"\"\n    Amplitude-to-decibel conversion (logarithmic mapping with base=10)\n    By using `amin=1e-7`, it assumes 32-bit floating point input. If the\n    data precision differs, use an appropriate `amin` accordingly.\n\n    Args:\n        x (Tensor): Input amplitude\n        ref (float): Amplitude value that is equivalent to 0 decibel\n        amin (float): Minimum amplitude. Any input that is smaller than `amin` is\n            clamped to `amin`.\n\n    Returns:\n        (Tensor): same size as x, after conversion\n\n    Example:\n        >>> amplitude_to_db(torch.tensor([1.0, 0.1]))  # approximately [0., -20.]\n    \"\"\"\n    x = x.pow(2.)\n    x = torch.clamp(x, min=amin)\n    return 10.0 * (torch.log10(x) - torch.log10(torch.tensor(ref,\n                                                             device=x.device,\n                                                             requires_grad=False,\n                                                             dtype=x.dtype)))\n\n\ndef db_to_amplitude(x, ref=1.0):\n    \"\"\"\n    Decibel-to-amplitude conversion (exponential mapping with base=10)\n\n    Args:\n        x (Tensor): Input in decibel to be converted\n        ref (float): Amplitude value that is equivalent to 0 decibel\n\n    Returns:\n        (Tensor): same size as x, after conversion\n    \"\"\"\n    power_spec = torch.pow(10.0, x / 10.0 + torch.log10(torch.tensor(ref,\n                                                        device=x.device,\n                                                        requires_grad=False,\n                                                        dtype=x.dtype)))\n    return power_spec.pow(0.5)\n\n\ndef mu_law_encoding(x, n_quantize=256):\n    \"\"\"Apply mu-law encoding to the input tensor.\n    Usually applied to 
waveforms\n\n    Args:\n        x (Tensor): input value\n        n_quantize (int): quantization level. For 8-bit encoding, set 256 (2 ** 8).\n\n    Returns:\n        (Tensor): same size as x, after encoding\n\n    Example:\n        >>> mu_law_encoding(torch.tensor([-1.0, 0.0, 1.0]))\n        tensor([  0, 128, 255])\n    \"\"\"\n    if not x.dtype.is_floating_point:\n        x = x.to(torch.float)\n    mu = torch.tensor(n_quantize - 1, dtype=x.dtype, requires_grad=False)  # scalar tensor in x's floating-point dtype\n\n    x_mu = x.sign() * torch.log1p(mu * x.abs()) / torch.log1p(mu)\n    x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()\n    return x_mu\n\n\ndef mu_law_decoding(x_mu, n_quantize=256, dtype=torch.get_default_dtype()):\n    \"\"\"Apply mu-law decoding (expansion) to the input tensor.\n\n    Args:\n        x_mu (Tensor): mu-law encoded input\n        n_quantize (int): quantization level. For 8-bit decoding, set 256 (2 ** 8).\n        dtype: specifies `dtype` for the decoded value. Default: `torch.get_default_dtype()`\n\n    Returns:\n        (Tensor): mu-law decoded tensor\n    \"\"\"\n    if not x_mu.dtype.is_floating_point:\n        x_mu = x_mu.to(dtype)\n    mu = torch.tensor(n_quantize - 1, dtype=x_mu.dtype, requires_grad=False)  # scalar tensor in x_mu's floating-point dtype\n    x = (x_mu / mu) * 2 - 1.\n    x = x.sign() * (torch.exp(x.abs() * torch.log1p(mu)) - 1.) / mu\n    return x\n"
  },
  {
    "path": "torchaudio_contrib/layers.py",
"content": "import torch\nimport math\nimport torch.nn as nn\n\nfrom .functional import stft, complex_norm, \\\n    create_mel_filter, phase_vocoder, apply_filterbank, \\\n    amplitude_to_db, db_to_amplitude, \\\n    mu_law_encoding, mu_law_decoding\n\n\nclass _ModuleNoStateBuffers(nn.Module):\n    \"\"\"\n    Extension of nn.Module that removes buffers\n    from state_dict.\n    \"\"\"\n\n    def state_dict(self, destination=None, prefix='', keep_vars=False):\n        ret = super(_ModuleNoStateBuffers, self).state_dict(\n            destination, prefix, keep_vars)\n        for k in self._buffers:\n            del ret[prefix + k]\n        return ret\n\n    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):\n        # temporarily hide the buffers; we do not want to restore them\n\n        buffers = self._buffers\n        self._buffers = {}\n        result = super(_ModuleNoStateBuffers, self)._load_from_state_dict(\n            state_dict, prefix, *args, **kwargs)\n        self._buffers = buffers\n        return result\n\n\nclass STFT(_ModuleNoStateBuffers):\n    \"\"\"Compute a short-time Fourier transform of the input waveform(s).\n    It essentially wraps `torch.stft`, after reshaping the input audio\n    to allow `waveforms` with `dim() >= 3`.\n    It follows most of the `torch.stft` default values; for `window`, however,\n    if it is not specified (`None`), a Hann window is used.\n\n    Args:\n        fft_length (int): FFT size [sample]\n        hop_length (int): Hop size [sample] between STFT frames.\n            Defaults to `fft_length // 4` (75%-overlapping windows) by `torch.stft`.\n        win_length (int): Size of STFT window.\n            Defaults to `fft_length` by `torch.stft`.\n        window (Tensor): 1-D Tensor.\n            Defaults to a Hann window of size `win_length` *unlike* `torch.stft`.\n        center (bool): Whether to pad `waveforms` on both sides so that the\n            `t`-th frame is centered at time `t * 
hop_length`.\n            Defaults to `True` by `torch.stft`.\n        pad_mode (str): padding method (see `torch.nn.functional.pad`).\n            Defaults to `'reflect'` by `torch.stft`.\n        normalized (bool): Whether the results are normalized.\n            Defaults to `False` by `torch.stft`.\n        onesided (bool): Whether the half + 1 frequency bins are returned to remove\n            the symmetric part of STFT of real-valued signal.\n            Defaults to `True` by `torch.stft`.\n    \"\"\"\n\n    def __init__(self, fft_length, hop_length=None, win_length=None,\n                 window=None, center=True, pad_mode='reflect',\n                 normalized=False, onesided=True):\n        super(STFT, self).__init__()\n\n        self.fft_length = fft_length\n        self.hop_length = hop_length\n        self.win_length = win_length\n\n        self.center = center\n        self.pad_mode = pad_mode\n        self.normalized = normalized\n        self.onesided = onesided\n\n        if window is None:\n            if win_length is None:\n                window = torch.hann_window(fft_length)\n            else:\n                window = torch.hann_window(win_length)\n\n        self.register_buffer('window', window)\n\n    def forward(self, waveforms):\n        \"\"\"\n        Args:\n            waveforms (Tensor): Tensor of audio signal of size `(*, channel, time)`\n\n        Returns:\n            complex_specgrams (Tensor): `(*, channel, num_freqs, time, complex=2)`\n        \"\"\"\n\n        complex_specgrams = stft(waveforms, self.fft_length,\n                                 hop_length=self.hop_length,\n                                 win_length=self.win_length,\n                                 window=self.window,\n                                 center=self.center,\n                                 pad_mode=self.pad_mode,\n                                 normalized=self.normalized,\n                                 onesided=self.onesided)\n\n        
return complex_specgrams\n\n    def __repr__(self):\n        param_str1 = '(fft_length={}, hop_length={}, win_length={})'.format(\n            self.fft_length, self.hop_length, self.win_length)\n        param_str2 = '(center={}, pad_mode={}, normalized={}, onesided={})'.format(\n            self.center, self.pad_mode, self.normalized, self.onesided)\n        return self.__class__.__name__ + param_str1 + param_str2\n\n\nclass ComplexNorm(nn.Module):\n    \"\"\"Compute the norm of a complex tensor input.\n\n    Args:\n        power (float): Power of the norm. Defaults to `1.0`.\n\n    \"\"\"\n\n    def __init__(self, power=1.0):\n        super(ComplexNorm, self).__init__()\n        self.power = power\n\n    def forward(self, complex_tensor):\n        \"\"\"\n        Args:\n            complex_tensor (Tensor): Tensor shape of `(*, complex=2)`\n\n        Returns:\n            Tensor: norm of the input tensor, shape of `(*, )`\n        \"\"\"\n        return complex_norm(complex_tensor, self.power)\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(power={})'.format(self.power)\n\n\nclass ApplyFilterbank(_ModuleNoStateBuffers):\n    \"\"\"\n    Applies a filterbank transform.\n    \"\"\"\n\n    def __init__(self, filterbank):\n        super(ApplyFilterbank, self).__init__()\n        self.register_buffer('filterbank', filterbank)\n\n    def forward(self, mag_specgrams):\n        \"\"\"\n        Args:\n            mag_specgrams (Tensor): (channel, num_freqs, time) or\n                (batch, channel, num_freqs, time).\n\n        Returns:\n            (Tensor): (channel, num_bands, time) or (batch, channel, num_bands, time),\n                where num_bands == filterbank.size(1)\n        \"\"\"\n        return apply_filterbank(mag_specgrams, self.filterbank)\n\n\nclass Filterbank(object):\n    \"\"\"\n    Base class for providing a filterbank matrix.\n    \"\"\"\n\n    def __init__(self):\n        super(Filterbank, self).__init__()\n\n    def get_filterbank(self):\n        raise NotImplementedError\n\n\nclass MelFilterbank(Filterbank):\n    \"\"\"\n    Provides a filterbank 
matrix to convert a spectrogram into a mel frequency spectrogram.\n\n    Args:\n        num_freqs (int, optional): number of filter banks from stft.\n            Defaults to 2048//2 + 1.\n        num_mels (int): number of mel bins. Defaults to 128.\n        min_freq (float): minimum frequency. Defaults to 0.\n        max_freq (float, optional): maximum frequency. Defaults to sample_rate // 2.\n        sample_rate (int): sample rate of audio signal. Defaults to None.\n        htk (bool, optional): use HTK formula instead of Slaney. Defaults to False.\n    \"\"\"\n\n    def __init__(self, num_freqs=1025, num_mels=128,\n                 min_freq=0.0, max_freq=None, sample_rate=None, htk=False):\n        super(MelFilterbank, self).__init__()\n\n        if sample_rate is None and max_freq is None:\n            raise ValueError('Either max_freq or sample_rate should be specified, '\n                             'but both are None.')\n        self.num_freqs = num_freqs\n        self.num_mels = num_mels\n        self.min_freq = min_freq\n        self.max_freq = max_freq if max_freq else sample_rate // 2\n        self.htk = htk\n\n    def get_filterbank(self):\n        return create_mel_filter(\n            num_freqs=self.num_freqs,\n            num_mels=self.num_mels,\n            min_freq=self.min_freq,\n            max_freq=self.max_freq,\n            htk=self.htk)\n\n    def __repr__(self):\n        param_str1 = '(num_freqs={}, num_mels={}'.format(\n            self.num_freqs, self.num_mels)\n        param_str2 = ', min_freq={}, max_freq={}'.format(\n            self.min_freq, self.max_freq)\n        param_str3 = ', htk={})'.format(\n            self.htk)\n        return self.__class__.__name__ + param_str1 + param_str2 + param_str3\n\n\nclass TimeStretch(_ModuleNoStateBuffers):\n    \"\"\"\n    Stretch stft in time without modifying pitch for a given rate.\n\n    Args:\n        hop_length (int): Number of audio frames between STFT columns.\n        num_freqs (int, 
optional): number of filter banks from stft.\n        fixed_rate (float): rate to speed up or slow down by.\n            Defaults to None (in which case a rate must be\n            passed to the forward method per batch).\n    \"\"\"\n\n    def __init__(self, hop_length, num_freqs, fixed_rate=None):\n        super(TimeStretch, self).__init__()\n\n        self.fixed_rate = fixed_rate\n        phase_advance = torch.linspace(\n            0, math.pi * hop_length, num_freqs)[..., None]\n\n        self.register_buffer('phase_advance', phase_advance)\n\n    def forward(self, complex_specgrams, overriding_rate=None):\n        \"\"\"\n\n        Args:\n            complex_specgrams (Tensor): complex spectrogram\n                (*, channel, freq, time, complex=2)\n            overriding_rate (float or None): speed up to apply to this batch.\n                If no rate is passed, use self.fixed_rate.\n\n        Returns:\n            (Tensor): (*, channel, num_freqs, ceil(time/rate), complex=2)\n        \"\"\"\n        if overriding_rate is None:\n            rate = self.fixed_rate\n            if rate is None:\n                raise ValueError(\"If no fixed_rate is specified\"\n                                 \", must pass a valid rate to the forward method.\")\n        else:\n            rate = overriding_rate\n\n        if rate == 1.0:\n            return complex_specgrams\n\n        return phase_vocoder(complex_specgrams, rate, self.phase_advance)\n\n    def __repr__(self):\n        param_str = '(fixed_rate={})'.format(self.fixed_rate)\n        return self.__class__.__name__ + param_str\n\n\ndef Spectrogram(fft_length, hop_length=None, win_length=None,\n                window=None, center=True, pad_mode='reflect',\n                normalized=False, onesided=True, power=1.):\n    \"\"\"Get spectrogram module, which is a Sequential module of\n        `[STFT(), ComplexNorm()]`.\n\n    Args:\n        fft_length (int): FFT size [sample]\n        hop_length (int): Hop size 
[sample] between STFT frames.\n            Defaults to `fft_length // 4` (75%-overlapping windows) by `torch.stft`.\n        win_length (int): Size of STFT window.\n            Defaults to `fft_length` by `torch.stft`.\n        window (Tensor): 1-D Tensor.\n            Defaults to Hann Window of size `win_length` *unlike* `torch.stft`.\n        center (bool): Whether to pad `waveforms` on both sides so that the\n            `t`-th frame is centered at time `t * hop_length`.\n            Defaults to `True` by `torch.stft`.\n        pad_mode (str): padding method (see `torch.nn.functional.pad`).\n            Defaults to `'reflect'` by `torch.stft`.\n        normalized (bool): Whether the results are normalized.\n            Defaults to `False` by `torch.stft`.\n        onesided (bool): Whether the half + 1 frequency bins are returned to remove\n            the symmetric part of STFT of real-valued signal.\n            Defaults to `True` by `torch.stft`.\n        power (float): Exponent of the magnitude. Defaults to `1.0`.\n\n    \"\"\"\n    return nn.Sequential(\n        STFT(\n            fft_length,\n            hop_length,\n            win_length,\n            window,\n            center,\n            pad_mode,\n            normalized,\n            onesided),\n        ComplexNorm(power))\n\n\ndef Melspectrogram(\n        num_mels=128,\n        sample_rate=22050,\n        min_freq=0.0,\n        max_freq=None,\n        num_freqs=None,\n        htk=False,\n        mel_filterbank=None,\n        **kwargs):\n    \"\"\"\n    Get melspectrogram module.\n\n    Args:\n        num_mels (int): number of mel bins. Defaults to 128.\n        sample_rate (int): sample rate of audio signal. Defaults to 22050.\n        min_freq (float): minimum frequency. Defaults to 0.\n        max_freq (float, optional): maximum frequency. 
Defaults to sample_rate // 2.\n        num_freqs (int, optional): number of filter banks from stft.\n            Defaults to fft_length // 2 + 1 if 'fft_length' in kwargs else 1025.\n        htk (bool, optional): use HTK formula instead of Slaney. Defaults to False.\n        mel_filterbank (class, optional): MelFilterbank class to build filterbank matrix\n        **kwargs: torchaudio_contrib.Spectrogram parameters.\n    \"\"\"\n    if num_freqs is None:\n        fft_length = kwargs.get('fft_length', None)\n        num_freqs = fft_length // 2 + 1 if fft_length else 1025\n\n    # Check if custom MelFilterbank is passed\n    if mel_filterbank is None:\n        mel_filterbank = MelFilterbank\n\n    mel_fb_matrix = mel_filterbank(\n        num_mels=num_mels,\n        sample_rate=sample_rate,\n        min_freq=min_freq,\n        max_freq=max_freq,\n        num_freqs=num_freqs,\n        htk=htk).get_filterbank()\n\n    return nn.Sequential(*Spectrogram(power=2., **kwargs),\n                         ApplyFilterbank(mel_fb_matrix))\n\n\nclass AmplitudeToDb(_ModuleNoStateBuffers):\n    \"\"\"\n    Amplitude-to-decibel conversion (logarithmic mapping with base=10)\n    By using `amin=1e-7`, it assumes 32-bit floating point input. If the\n    data precision differs, use an appropriate `amin` accordingly.\n\n    Args:\n        ref (float): Amplitude value that is equivalent to 0 decibel\n        amin (float): Minimum amplitude. 
Any input that is smaller than `amin` is\n            clamped to `amin`.\n    \"\"\"\n\n    def __init__(self, ref=1.0, amin=1e-7):\n        super(AmplitudeToDb, self).__init__()\n        self.ref = ref\n        self.amin = amin\n        assert ref > amin, \"Reference value is expected to be bigger than amin, \" \\\n                           \"but got ref:{} and amin:{}\".format(ref, amin)\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x (Tensor): Input amplitude\n\n        Returns:\n            (Tensor): same size as x, after conversion\n        \"\"\"\n        return amplitude_to_db(x, ref=self.ref, amin=self.amin)\n\n    def __repr__(self):\n        param_str = '(ref={}, amin={})'.format(self.ref, self.amin)\n        return self.__class__.__name__ + param_str\n\n\nclass DbToAmplitude(_ModuleNoStateBuffers):\n    \"\"\"\n    Decibel-to-amplitude conversion (exponential mapping with base=10)\n\n    Args:\n        ref (float): Amplitude value that is equivalent to 0 decibel\n    \"\"\"\n\n    def __init__(self, ref=1.0):\n        super(DbToAmplitude, self).__init__()\n        self.ref = ref\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x (Tensor): Input in decibel to be converted\n\n        Returns:\n            (Tensor): same size as x, after conversion\n        \"\"\"\n        return db_to_amplitude(x, ref=self.ref)\n\n    def __repr__(self):\n        param_str = '(ref={})'.format(self.ref)\n        return self.__class__.__name__ + param_str\n\n\nclass MuLawEncoding(_ModuleNoStateBuffers):\n    \"\"\"Apply mu-law encoding to the input tensor.\n    Usually applied to waveforms\n\n    Args:\n        n_quantize (int): quantization level. 
For 8-bit encoding, set 256 (2 ** 8).\n\n    \"\"\"\n\n    def __init__(self, n_quantize=256):\n        super(MuLawEncoding, self).__init__()\n        self.n_quantize = n_quantize\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x (Tensor): input value\n\n        Returns:\n            (Tensor): same size as x, after encoding\n        \"\"\"\n        return mu_law_encoding(x, self.n_quantize)\n\n    def __repr__(self):\n        param_str = '(n_quantize={})'.format(self.n_quantize)\n        return self.__class__.__name__ + param_str\n\n\nclass MuLawDecoding(_ModuleNoStateBuffers):\n    \"\"\"Apply mu-law decoding (expansion) to the input tensor.\n    Usually applied to waveforms\n\n    Args:\n        n_quantize (int): quantization level. For 8-bit decoding, set 256 (2 ** 8).\n    \"\"\"\n\n    def __init__(self, n_quantize=256):\n        super(MuLawDecoding, self).__init__()\n        self.n_quantize = n_quantize\n\n    def forward(self, x_mu):\n        \"\"\"\n        Args:\n            x_mu (Tensor): mu-law encoded input\n\n        Returns:\n            (Tensor): mu-law decoded tensor\n        \"\"\"\n        return mu_law_decoding(x_mu, self.n_quantize)\n\n    def __repr__(self):\n        param_str = '(n_quantize={})'.format(self.n_quantize)\n        return self.__class__.__name__ + param_str\n"
  }
]