Repository: keunwoochoi/torchaudio-contrib Branch: master Commit: bdbbb1800db9 Files: 13 Total size: 51.0 KB Directory structure: gitextract_ztnwls02/ ├── .flake8 ├── .gitignore ├── .travis.yml ├── README.md ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests/ │ ├── test_functional.py │ └── test_layers.py └── torchaudio_contrib/ ├── __init__.py ├── beta_hpss.py ├── functional.py └── layers.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .flake8 ================================================ [flake8] max-line-length = 120 ignore = E305,E402,E721,E741,F403,F405,F821,F841,F999,W503,W504 exclude = beta_*.py, .git/ ================================================ FILE: .gitignore ================================================ # Intellij Idea .idea/ *.iml # python __pycache__/ *.pyc ## pyenv .python-version ## python virtual environments env/ venv/ .env/ .venv # OSX .DS_Store ._* # swap files for vim users [._]*.swp [._]swp # ctags tags ================================================ FILE: .travis.yml ================================================ language: python python: - "3.5" install: - pip install -e .[tests] script: - python -m pytest ================================================ FILE: README.md ================================================ # SUNSETTING torchaudio-contrib We made some progress in this repo and contributed to the original repo [pytorch/audio](https://github.com/pytorch/audio) which satisfied us. We happily stopped working here :) Please visit [pytorch/audio](https://github.com/pytorch/audio) for any issue/request! . . . . . . . . . . (We keep the existing content as below ⬇) # torchaudio-contrib Goal: To propose audio processing Pytorch codes with nice and easy-to-use APIs and functionality. 
:open_hands: This should be seen as a community-based proposal and the basis for a discussion we should have inside the pytorch audio user community. Everyone should be welcome to join and discuss. Our motivation is: - API design: Clear, readable names for class/functions/arguments, sensible default values, and shapes. - Reference: [librosa](http://librosa.github.io/librosa/) (audio and MIR on Numpy), [kapre](https://github.com/keunwoochoi/kapre) (audio on Keras), [pytorch/audio](https://github.com/pytorch/audio) (audio on Pytorch) - Fast processing on GPU - Methodology: Both layer and functional - Layers (`nn.Module`) for reusability and easier use - and identical implementation with `Functionals` - Simple installation - Multi-channel support ## Contribution Making things quicker and open! We're a `-contrib` repo, hence it's *easy to enter but hard to graduate*. 1. Make a new [Issue](https://github.com/keunwoochoi/torchaudio-contrib/issues) for a potential PR 2. Until it's in good shape, 1. Make a PR that follows the current conventions and includes unit tests 2. Review-merge. 3. Based on it, make a PR to [pytorch/audio](https://github.com/pytorch/audio) Discussion on how to contribute - https://github.com/keunwoochoi/torchaudio-contrib/issues/37 ## Current issues/future work - Better module/sub-module hierarchy - Complex number support - More time-frequency representations - Signal processing modules, e.g., vocoder - Augmentation # API suggestions ## Notes * Audio signals can be multi-channel * `STFT`: short-time Fourier transform, outputting a complex-valued representation * `Spectrogram`: magnitudes of STFT * `Melspectrogram`: mel-filterbank applied to `spectrogram` ## Shapes * audio signals: `(batch, channel, time)` * E.g., `STFT` input shape * Based on `torch.stft` input shape * 2D representations: `(batch, channel, freq, time)` * E.g., `STFT` output shape * Channel-first, following torch convention. 
* Then, `(freq, time)`, following `torch.stft` ## Overview ### `STFT` ```python class STFT(fft_len=2048, hop_len=None, frame_len=None, window=None, pad=0, pad_mode="reflect", **kwargs) def stft(signal, fft_len, hop_len, window, pad=0, pad_mode="reflect", **kwargs) ``` ### `MelFilterbank` ```python class MelFilterbank(num_bands=128, sample_rate=16000, min_freq=0.0, max_freq=None, num_bins=1025, htk=False) def create_mel_filter(num_bands, sample_rate, min_freq, max_freq, num_bins, to_hertz, from_hertz) ``` ### `Spectrogram` ```python def Spectrogram(fft_len=2048, hop_len=None, frame_len=None, window=None, pad=0, pad_mode="reflect", power=1., **kwargs) ``` Creates an `nn.Sequential`: ``` >>> Sequential( >>> (0): STFT(fft_len=2048, hop_len=512, frame_len=2048) >>> (1): ComplexNorm(power=1.0) ) ``` ### `Melspectrogram` ```python def Melspectrogram(num_bands=128, sample_rate=16000, min_freq=0.0, max_freq=None, num_bins=None, htk=False, mel_filterbank=None, **kwargs) ``` Creates an `nn.Sequential`: ``` >>> Sequential( >>> (0): STFT(fft_len=2048, hop_len=512, frame_len=2048) >>> (1): ComplexNorm(power=2.0) >>> (2): ApplyFilterbank() ) ``` ### `AmplitudeToDb`/`amplitude_to_db` ```python class AmplitudeToDb(ref=1.0, amin=1e-7) def amplitude_to_db(x, ref=1.0, amin=1e-7) ``` Argument names and the default value of `ref` follow librosa. The default value of `amin`, however, follows Keras's float32 epsilon, which seems sensible. 
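
The conversion can be illustrated without PyTorch. The following is only a minimal sketch of the formula for scalar inputs, mirroring what `amplitude_to_db` and `db_to_amplitude` do in `functional.py` (square the amplitude, clamp at `amin`, take `10 * log10` relative to `ref`). The helper names `amplitude_to_db_scalar` and `db_to_amplitude_scalar` are hypothetical and are not part of the package.

```python
import math


def amplitude_to_db_scalar(x, ref=1.0, amin=1e-7):
    """Scalar sketch: amplitude -> decibel (hypothetical helper, not the package API)."""
    power = max(x * x, amin)  # amin clamps the squared amplitude, not the amplitude
    return 10.0 * (math.log10(power) - math.log10(ref))


def db_to_amplitude_scalar(db, ref=1.0):
    """Scalar sketch: decibel -> amplitude (hypothetical helper, not the package API)."""
    power = 10.0 ** (db / 10.0 + math.log10(ref))
    return math.sqrt(power)


for amp in (0.001, 0.1, 1.0, 10.0):
    db = amplitude_to_db_scalar(amp)
    # Round trip: converting back should recover the original amplitude.
    print(amp, round(db, 3), round(db_to_amplitude_scalar(db), 6))
```

Because the clamp is applied after squaring, the defaults (`ref=1.0`, `amin=1e-7`) put the floor of the output at about -70 dB, so amplitudes below roughly 3.2e-4 all map to the same value.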
### `DbToAmplitude`/`db_to_amplitude` ```python class DbToAmplitude(ref=1.0) def db_to_amplitude(x, ref=1.0) ``` ### `MuLawEncoding`/`mu_law_encoding` ```python class MuLawEncoding(n_quantize=256) def mu_law_encoding(x, n_quantize=256) ``` ### `MuLawDecoding`/`mu_law_decoding` ```python class MuLawDecoding(n_quantize=256) def mu_law_decoding(x_mu, n_quantize=256) ``` ---------- # A Big Issue - Remove SoX Dependency We propose to remove the SoX dependency because: * Many audio ML tasks don’t require the functionality included in SoX (filtering, cutting, effects) * Many issues in torchaudio are related to the installation with respect to SoX. While this could be simplified by a [conda build or a wheel](https://github.com/pytorch/builder/issues/279), it will continue being difficult to maintain the repo. * SoX doesn’t support MP4 containers, which makes it unusable for multi-stream audio * Loading speed is good with torchaudio but, e.g. for __wav__, it's not faster than other libraries (including the cast to a torch tensor) -- as in the graph below. See more detailed benchmarks [here](https://github.com/faroit/python_audio_loading_benchmark). ![](https://raw.githubusercontent.com/faroit/python_audio_loading_benchmark/master/results/benchmark_pytorch.png) ## Proposal Introduce I/O backends and move the functions that depend on `_torch_sox` to a `backend_sox.py`, which is *not* required to install. Additionally, we could then introduce more backends like scipy.io or pysoundfile. Each backend then imports the (optional) lib within the backend file and each backend includes a minimum spec such as: ```python import _torch_sox def load(...) # returns audio, rate def save(...) # write file def info(...) # returns metadata without reading the full file ``` ### Backend proposals * `scipy.io` or `soundfile` as default for __wav__ files * `aubio` or `audioread` for __mp3__ and __mp4__ ### Installation ```bash pip install -e . 
``` ### Importing import torchaudio_contrib ## Authors Keunwoo Choi, Faro Stöter, Kiran Sanjeevan, Jan Schlüter ================================================ FILE: requirements.txt ================================================ torch ================================================ FILE: setup.cfg ================================================ [tool:pytest] xfail_strict = true ================================================ FILE: setup.py ================================================ from setuptools import setup setup(name='torchaudio_contrib', version='0.1', description='To propose audio processing Pytorch codes with nice and easy-to-use APIs and functionality', url='https://github.com/keunwoochoi/torchaudio-contrib', author='Keunwoo Choi, Faro Stöter, Kiran Sanjeevan, Jan Schlüter', author_email='gnuchoi@gmail.com', license='MIT', install_requires=['torch'], extras_require={'tests': ['pytest', 'librosa']}, packages=['torchaudio_contrib'], zip_safe=False) ================================================ FILE: tests/test_functional.py ================================================ import pytest import librosa import numpy as np import torch from torchaudio_contrib.functional import (stft, phase_vocoder, magphase, amplitude_to_db, db_to_amplitude, complex_norm, mu_law_encoding, mu_law_decoding, apply_filterbank ) xfail = pytest.mark.xfail def _num_stft_bins(signal_len, fft_len, hop_length, pad): return (signal_len + 2 * pad - fft_len + hop_length) // hop_length def _approx_all_equal(x, y, atol=1e-7): return torch.all(torch.lt(torch.abs(torch.add(x, -y)), atol)) def _all_equal(x, y): return torch.all(torch.eq(x, y)) @pytest.mark.parametrize('fft_length', [512]) @pytest.mark.parametrize('hop_length', [256]) @pytest.mark.parametrize('waveform', [ (torch.randn(1, 100000)), (torch.randn(1, 2, 100000)), pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)), ]) @pytest.mark.parametrize('pad_mode', [ # 'constant', 'reflect', ]) def 
test_stft(waveform, fft_length, hop_length, pad_mode): """ Test STFT for multi-channel signals. Padding: Value in having padding outside of torch.stft? """ pad = fft_length // 2 window = torch.hann_window(fft_length) complex_spec = stft(waveform, fft_length=fft_length, hop_length=hop_length, window=window, pad_mode=pad_mode) mag_spec, phase_spec = magphase(complex_spec) # == Test shape expected_size = list(waveform.size()[:-1]) expected_size += [fft_length // 2 + 1, _num_stft_bins( waveform.size(-1), fft_length, hop_length, pad), 2] assert complex_spec.dim() == waveform.dim() + 2 assert complex_spec.size() == torch.Size(expected_size) # == Test values fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode) # note that librosa *automatically* pads with fft_length // 2. expected_complex_spec = np.apply_along_axis(librosa.stft, -1, waveform.numpy(), **fft_config) expected_mag_spec, _ = librosa.magphase(expected_complex_spec) # Convert torch to np.complex complex_spec = complex_spec.numpy() complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1] assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5) assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5) @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3]) @pytest.mark.parametrize('complex_specgrams', [ torch.randn(1, 2, 1025, 400, 2), torch.randn(1, 1025, 400, 2) ]) @pytest.mark.parametrize('hop_length', [256]) def test_phase_vocoder(complex_specgrams, rate, hop_length): class use_double_precision: def __enter__(self): self.default_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float64) def __exit__(self, type, value, traceback): torch.set_default_dtype(self.default_dtype) # Due to the cumulative sum, numerical error in using torch.float32 will # result in bottom right values of the stretched spectrogram not # matching librosa with use_double_precision(): complex_specgrams = complex_specgrams.type(torch.get_default_dtype()) phase_advance = 
torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-3])[..., None] complex_specgrams_stretch = phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance) # == Test shape expected_size = list(complex_specgrams.size()) expected_size[-2] = int(np.ceil(expected_size[-2] / rate)) assert complex_specgrams.dim() == complex_specgrams_stretch.dim() assert complex_specgrams_stretch.size() == torch.Size(expected_size) # == Test values index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3 mono_complex_specgram = complex_specgrams[index].numpy() mono_complex_specgram = mono_complex_specgram[..., 0] + \ mono_complex_specgram[..., 1] * 1j expected_complex_stretch = librosa.phase_vocoder( mono_complex_specgram, rate=rate, hop_length=hop_length) complex_stretch = complex_specgrams_stretch[index].numpy() complex_stretch = complex_stretch[..., 0] + \ 1j * complex_stretch[..., 1] assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5) @pytest.mark.parametrize('complex_tensor', [ torch.randn(1, 2, 1025, 400, 2), torch.randn(1025, 400, 2) ]) @pytest.mark.parametrize('power', [1, 2, 0.7]) def test_complex_norm(complex_tensor, power): expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2) norm_tensor = complex_norm(complex_tensor, power) assert _approx_all_equal(expected_norm_tensor, norm_tensor, atol=1e-5) @pytest.mark.parametrize('new_len', [120, 36]) @pytest.mark.parametrize('mag_spec', [ torch.randn(1, 257, 391), torch.randn(1, 2, 257, 391), ]) def test_apply_filterbank(mag_spec, new_len): filterbank = torch.randn(mag_spec.size(-2), new_len) mag_spec_filterbanked = apply_filterbank(mag_spec, filterbank) assert mag_spec.size(-1) == mag_spec_filterbanked.size(-1) assert mag_spec_filterbanked.size(-2) == new_len assert mag_spec.dim() == mag_spec_filterbanked.dim() @pytest.mark.parametrize('amplitude,db', [ (torch.Tensor([0.000001, 0.0001, 0.1, 1.0, 10.0, 1000000.0]), torch.Tensor([-60.0, -40.0, -10.0, 0.0, 10.0, 60.0])) ]) 
def test_amplitude_db(amplitude, db): """Test amplitude_to_db and db_to_amplitude.""" amplitude = np.sqrt(amplitude) assert _approx_all_equal(db, amplitude_to_db(amplitude, ref=1.0)) assert _approx_all_equal(amplitude, db_to_amplitude(db, ref=1.0)) # both ways assert _approx_all_equal(db_to_amplitude(amplitude_to_db(amplitude, ref=1.0), ref=1.0), amplitude) assert _approx_all_equal(amplitude_to_db(db_to_amplitude(db, ref=1.0), ref=1.0), db, atol=1e-5) @pytest.mark.parametrize('waveform', [ torch.randn(1, 100000), (torch.randn(1, 2, 100000)), ]) @pytest.mark.parametrize('n_quantize', [256]) def test_mu_law(waveform, n_quantize): """test mu-law encoding and decoding""" def _test_mu_encoding(waveform, n_quantize): waveform = 2 * (waveform - 0.5) # manual computation mu = torch.tensor(n_quantize - 1, dtype=waveform.dtype) waveform_mu = waveform.sign() * torch.log1p(mu * waveform.abs()) / torch.log1p(mu) waveform_mu = ((waveform_mu + 1) / 2 * mu + 0.5).long() assert _all_equal(mu_law_encoding(waveform, n_quantize), waveform_mu) def _test_mu_decoding(waveform, n_quantize): waveform_mu = torch.randint(low=0, high=n_quantize - 1, size=(1, 1024)) # manual computation waveform_mu = waveform_mu.float() mu = torch.tensor(n_quantize - 1, dtype=waveform_mu.dtype) # confused about dtype here.. waveform = (waveform_mu / mu) * 2 - 1. waveform = waveform.sign() * (torch.exp(waveform.abs() * torch.log1p(mu)) - 1.) / mu assert _all_equal(mu_law_decoding(waveform_mu, n_quantize), waveform) def _test_both_ways(waveform, n_quantize): waveform_mu = torch.randint(low=0, high=n_quantize - 1, size=(1, 1024)) assert _all_equal(waveform_mu, mu_law_encoding(mu_law_decoding(waveform_mu, n_quantize), n_quantize)) _test_mu_encoding(waveform, n_quantize) _test_mu_decoding(waveform, n_quantize) _test_both_ways(waveform, n_quantize) ================================================ FILE: tests/test_layers.py ================================================ """ Test the layers. 
Currently only on cpu since travis doesn't have GPU. """ import unittest import pytest import librosa import numpy as np import torch import torch.nn as nn from torchaudio_contrib.layers import (STFT, Spectrogram, MelFilterbank, AmplitudeToDb, TimeStretch, ComplexNorm, ApplyFilterbank) from test_functional import _num_stft_bins xfail = pytest.mark.xfail @pytest.mark.parametrize('fft_len', [512]) @pytest.mark.parametrize('hop_length', [256]) @pytest.mark.parametrize('waveform', [ torch.randn(1, 100000) ]) @pytest.mark.parametrize('pad_mode', [ # 'constant', 'reflect', ]) def test_STFT(waveform, fft_len, hop_length, pad_mode): """ Test STFT for multi-channel signals. Padding: Value in having padding outside of torch.stft? """ pad = fft_len // 2 layer = STFT(fft_length=fft_len, hop_length=hop_length, pad_mode=pad_mode) assert torch.is_tensor(layer.window) assert not layer.window.requires_grad assert layer.window.size(0) <= layer.fft_length @pytest.mark.parametrize('rate', [0.7]) @pytest.mark.parametrize('complex_specgrams', [ torch.randn(1, 2, 1025, 400, 2), torch.randn(1, 1025, 400, 2) ]) @pytest.mark.parametrize('hop_length', [256]) def test_TimeStretch(complex_specgrams, rate, hop_length): layer = TimeStretch(hop_length=hop_length, num_freqs=complex_specgrams.shape[-3]) assert torch.is_tensor(layer.phase_advance) assert not layer.phase_advance.requires_grad @pytest.mark.parametrize('fft_length', [512]) @pytest.mark.parametrize('hop_length', [256]) @pytest.mark.parametrize('waveform', [ torch.randn(1, 100000), torch.randn(1, 2, 100000) ]) @pytest.mark.parametrize('pad_mode', [ # 'constant', 'reflect', ]) def test_SpectrogramDb(waveform, fft_length, hop_length, pad_mode): ref, amin = 1.0, 1e-7 window = torch.hann_window(fft_length) model = torch.nn.Sequential(*Spectrogram(fft_length, hop_length=hop_length, window=window, pad_mode=pad_mode), AmplitudeToDb(ref=ref, amin=amin)) db_spec = model(waveform).numpy() fft_config = dict(n_fft=fft_length, hop_length=hop_length, 
pad_mode=pad_mode) expected_db_spec = np.abs(np.apply_along_axis(librosa.stft, -1, waveform.numpy(), **fft_config)) db_config = dict(ref=ref, amin=amin, top_db=None) expected_db_spec = np.apply_along_axis(librosa.power_to_db, -1, expected_db_spec**2, **db_config) assert np.allclose(db_spec, expected_db_spec, atol=1e-2), np.abs(expected_db_spec - db_spec).max() @pytest.mark.parametrize('fft_length', [512]) @pytest.mark.parametrize('num_mels', [128]) @pytest.mark.parametrize('hop_length', [256]) @pytest.mark.parametrize('waveform', [ torch.randn(1, 2, 100000), torch.randn(4, 100000) ]) @pytest.mark.parametrize('rate', [0.7]) def test_MelspectrogramStretch(waveform, fft_length, num_mels, hop_length, rate): num_freqs = fft_length // 2 + 1 fb = MelFilterbank(num_freqs=num_freqs, num_mels=num_mels, max_freq=1.0).get_filterbank() model = nn.Sequential(STFT(fft_length, hop_length=hop_length), TimeStretch(hop_length=hop_length, num_freqs=num_freqs, fixed_rate=rate), ComplexNorm(power=2.0), ApplyFilterbank(fb)) mel_spec = model(waveform) num_bins = _num_stft_bins(waveform.size(-1), fft_length, hop_length, fft_length // 2) assert mel_spec.size(-2) == num_mels assert mel_spec.size(-1) == np.ceil(num_bins / rate) if __name__ == '__main__': unittest.main() ================================================ FILE: torchaudio_contrib/__init__.py ================================================ from .functional import * # noqa: F401 from .layers import * # noqa: F401 ================================================ FILE: torchaudio_contrib/beta_hpss.py ================================================ """This is a beta-version of harmonic-percussive source separation. Currently it only returns the separated magnitude spectrograms. Once we have inverse-STFT, we can extend it to get waveform results. TODO: add test """ import torch import torch.nn as nn import torch.nn.functional as F class HPSS(nn.Module): """ Wrap hpss. Args and Returns --> see `hpss`. 
""" def __init__(self, kernel_size=31, power=2.0, hard=False, mask_only=False): super(HPSS, self).__init__() self.kernel_size = kernel_size self.power = power self.hard = hard self.mask_only = mask_only def forward(self, mag_specgrams): return hpss(mag_specgrams, self.kernel_size, self.power, self.hard, self.mask_only) def __repr__(self): return self.__class__.__name__ + \ '(kernel_size={}, power={}, hard={}, mask_only={})'.format( self.kernel_size, self.power, self.hard, self.mask_only) def hpss(mag_specgrams, kernel_size=31, power=2.0, hard=False, mask_only=False): """ A function that performs harmonic-percussive source separation. Original method is by Derry Fitzgerald (https://www.researchgate.net/publication/254583990_HarmonicPercussive_Separation_using_Median_Filtering). Args: mag_specgrams (Tensor): any magnitude spectrograms in batch, (not in a decibel scale!) in a shape of (batch, ch, freq, time) kernel_size (int or (int, int)): odd-numbered if tuple, 1st: width of percussive-enhancing filter (one along freq axis) 2nd: width of harmonic-enhancing filter (one along time axis) if int, it's applied for both perc/harm filters power (float): to which the enhanced spectrograms are used in computing soft masks. hard (bool): whether the mask will be binarized (True) or not mask_only (bool): if true, returns the masks only. 
Returns: ret (Tuple): A tuple of four ret[0]: magnitude spectrograms - harmonic parts (Tensor, in same size with `mag_specgrams`) ret[1]: magnitude spectrograms - percussive parts (Tensor, in same size with `mag_specgrams`) ret[2]: harmonic mask (Tensor, in same size with `mag_specgrams`) ret[3]: percussive mask (Tensor, in same size with `mag_specgrams`) """ def _enhance_either_hpss(mag_specgrams_padded, out, kernel_size, power, which, offset): """ A helper function for HPSS Args: mag_specgrams_padded (Tensor): one that median filtering can be directly applied out (Tensor): The tensor to store the result kernel_size (int): The kernel size of median filter power (float): to which the enhanced spectrograms are used in computing soft masks. which (str): either 'harm' or 'perc' offset (int): the padded length """ if which == 'harm': for t in range(out.shape[3]): out[:, :, :, t] = torch.median(mag_specgrams_padded[:, :, offset:-offset, t:t + kernel_size], dim=3)[0] elif which == 'perc': for f in range(out.shape[2]): out[:, :, f, :] = torch.median(mag_specgrams_padded[:, :, f:f + kernel_size, offset:-offset], dim=2)[0] else: raise NotImplementedError('it should be either but you passed which={}'.format(which)) if power != 1.0: out.pow_(power) # end of the helper function eps = 1e-6 if not (isinstance(kernel_size, tuple) or isinstance(kernel_size, int)): raise TypeError('kernel_size is expected to be either tuple of input, but it is: %s' % type(kernel_size)) if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) pad = (kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2,) harm, perc, ret = torch.empty_like(mag_specgrams), torch.empty_like(mag_specgrams), torch.empty_like(mag_specgrams) mag_specgrams_padded = F.pad(mag_specgrams, pad=pad, mode='reflect') _enhance_either_hpss(mag_specgrams_padded, out=perc, kernel_size=kernel_size[0], power=power, which='perc', offset=kernel_size[1] // 2) 
_enhance_either_hpss(mag_specgrams_padded, out=harm, kernel_size=kernel_size[1], power=power, which='harm', offset=kernel_size[0] // 2) if hard: mask_harm = harm > perc mask_perc = harm < perc else: mask_harm = (harm + eps) / (harm + perc + eps) mask_perc = (perc + eps) / (harm + perc + eps) if mask_only: return None, None, mask_harm, mask_perc return mag_specgrams * mask_harm, mag_specgrams * mask_perc, mask_harm, mask_perc # def pss_src(x, kernel_size=31, power=2.0, hard=False): # """perform percusive source separation using `hpss()`. # x: (batch, time)""" # n_fft = 1024 # hop_length = 256 # x_stft = torch.stft(x, n_fft=n_fft, hop_length=hop_length) # x_mag = x_stft.pow(2).sum(-1).unsqueeze(1) # add channel dim # _, _, _, mask_perc = hpss(x_mag, kernel_size, power, hard, mask_only=True) # mask_perc.squeeze_(1).unsqueeze_(3) # remove channel, add last dim for complex # x_perc = time_freq.istft(x_stft * mask_perc, hop_length=hop_length, length=x.shape[1]) # return x_perc ================================================ FILE: torchaudio_contrib/functional.py ================================================ import torch import math def _mel_to_hertz(mel, htk): """ Converting mel values into frequency """ mel = torch.as_tensor(mel).type(torch.get_default_dtype()) if htk: return 700. * (10 ** (mel / 2595.) - 1.) f_min = 0.0 f_sp = 200.0 / 3 hz = f_min + f_sp * mel min_log_hz = 1000.0 min_log_mel = (min_log_hz - f_min) / f_sp logstep = math.log(6.4) / 27.0 return torch.where(mel >= min_log_mel, min_log_hz * torch.exp(logstep * (mel - min_log_mel)), hz) def _hertz_to_mel(hz, htk): """ Converting frequency into mel values """ hz = torch.as_tensor(hz).type(torch.get_default_dtype()) if htk: return 2595. 
* torch.log10(torch.tensor(1., dtype=torch.get_default_dtype()) + (hz / 700.)) f_min = 0.0 f_sp = 200.0 / 3 mel = (hz - f_min) / f_sp min_log_hz = 1000.0 min_log_mel = (min_log_hz - f_min) / f_sp logstep = math.log(6.4) / 27.0 return torch.where(hz >= min_log_hz, min_log_mel + torch.log(hz / min_log_hz) / logstep, mel) def stft(waveforms, fft_length, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', normalized=False, onesided=True): """Compute a short-time Fourier transform of the input waveform(s). It wraps `torch.stft` but after reshaping the input audio to allow for `waveforms` that `.dim()` >= 3. It follows most of the `torch.stft` default values, but for `window`, if it's not specified (`None`), it uses a Hann window. Args: waveforms (Tensor): Tensor of audio signal of size `(*, channel, time)` fft_length (int): FFT size [sample] hop_length (int): Hop size [sample] between STFT frames. Defaults to `fft_length // 4` (75%-overlapping windows) by `torch.stft`. win_length (int): Size of STFT window. Defaults to `fft_length` by `torch.stft`. window (Tensor): 1-D Tensor. Defaults to Hann Window of size `win_length` *unlike* `torch.stft`. center (bool): Whether to pad `waveforms` on both sides so that the `t`-th frame is centered at time `t * hop_length`. Defaults to `True` by `torch.stft`. pad_mode (str): padding method (see `torch.nn.functional.pad`). Defaults to `'reflect'` by `torch.stft`. normalized (bool): Whether the results are normalized. Defaults to `False` by `torch.stft`. onesided (bool): Whether the half + 1 frequency bins are returned to remove the symmetric part of STFT of real-valued signal. Defaults to `True` by `torch.stft`. 
Returns: complex_specgrams (Tensor): `(*, channel, num_freqs, time, complex=2)` Example: >>> waveforms = torch.randn(16, 2, 10000) # (batch, channel, time) >>> x = stft(waveforms, 2048, 512) >>> x.shape torch.Size([16, 2, 1025, 20, 2]) """ leading_dims = waveforms.shape[:-1] waveforms = waveforms.reshape(-1, waveforms.size(-1)) if window is None: if win_length is None: window = torch.hann_window(fft_length) else: window = torch.hann_window(win_length) complex_specgrams = torch.stft(waveforms, n_fft=fft_length, hop_length=hop_length, win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized, onesided=onesided) complex_specgrams = complex_specgrams.reshape( leading_dims + complex_specgrams.shape[1:]) return complex_specgrams def complex_norm(complex_tensor, power=1.0): """Compute the norm of complex tensor input Args: complex_tensor (Tensor): Tensor shape of `(*, complex=2)` power (float): Power of the norm. Defaults to `1.0`. Returns: Tensor: power of the normed input tensor, shape of `(*, )` """ if power == 1.0: return torch.norm(complex_tensor, 2, -1) return torch.norm(complex_tensor, 2, -1).pow(power) def create_mel_filter(num_freqs, num_mels, min_freq, max_freq, htk): """ Creates filter matrix to transform fft frequency bins into mel frequency bins. Equivalent to librosa.filters.mel(sample_rate, fft_len, htk=True, norm=None). Args: num_freqs (int): number of filter banks from stft. num_mels (int): number of mel bins. min_freq (float): minimum frequency. max_freq (float): maximum frequency. 
htk (bool): whether following htk-mel scale or not Returns: mel_filterbank (Tensor): (num_freqs, num_mels) """ # Convert to find mel lower/upper bounds m_min = _hertz_to_mel(min_freq, htk) m_max = _hertz_to_mel(max_freq, htk) # Compute stft frequency values stft_freqs = torch.linspace(min_freq, max_freq, num_freqs) # Find mel values, and convert them to frequency units m_pts = torch.linspace(m_min, m_max, num_mels + 2) f_pts = _mel_to_hertz(m_pts, htk) f_diff = f_pts[1:] - f_pts[:-1] # (num_mels + 1) # (num_freqs, num_mels + 2) slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (num_freqs, num_mels) up_slopes = slopes[:, 2:] / f_diff[1:] # (num_freqs, num_mels) mel_filterbank = torch.clamp(torch.min(down_slopes, up_slopes), min=0.) return mel_filterbank def apply_filterbank(mag_specgrams, filterbank): """ Transform spectrogram given a filterbank matrix. Args: mag_specgrams (Tensor): (batch, channel, num_freqs, time) filterbank (Tensor): (num_freqs, num_bands) Returns: (Tensor): (batch, channel, num_bands, time) """ return torch.matmul(mag_specgrams.transpose(-2, -1), filterbank).transpose(-2, -1) def angle(complex_tensor): """ Return angle of a complex tensor with shape (*, 2). """ return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) def magphase(complex_tensor, power=1.): """ Separate a complex-valued spectrogram with shape (*,2) into its magnitude and phase. """ mag = complex_norm(complex_tensor, power) phase = angle(complex_tensor) return mag, phase def phase_vocoder(complex_specgrams, rate, phase_advance): """ Phase vocoder. Given a STFT tensor, speed up in time without modifying pitch by a factor of `rate`. Args: complex_specgrams (Tensor): (*, channel, num_freqs, time, complex=2) rate (float): Speed-up factor. phase_advance (Tensor): Expected phase advance in each bin. (num_freqs, 1). Returns: complex_specgrams_stretch (Tensor): (*, channel, num_freqs, ceil(time/rate), complex=2). 
Example: >>> num_freqs, hop_length = 1025, 512 >>> # (batch, channel, num_freqs, time, complex=2) >>> complex_specgrams = torch.randn(16, 1, num_freqs, 300, 2) >>> rate = 1.3 # Speed up by 30% >>> phase_advance = torch.linspace( >>> 0, math.pi * hop_length, num_freqs)[..., None] >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) >>> x.shape # with 231 == ceil(300 / 1.3) torch.Size([16, 1, 1025, 231, 2]) """ ndim = complex_specgrams.dim() time_slice = [slice(None)] * (ndim - 2) time_steps = torch.arange(0, complex_specgrams.size( -2), rate, device=complex_specgrams.device) alphas = torch.remainder(time_steps, torch.tensor(1., device=complex_specgrams.device)) phase_0 = angle(complex_specgrams[time_slice + [slice(1)]]) # Time Padding complex_specgrams = torch.nn.functional.pad( complex_specgrams, [0, 0, 0, 2]) complex_specgrams_0 = complex_specgrams[time_slice + [time_steps.long()]] # (new_bins, num_freqs, 2) complex_specgrams_1 = complex_specgrams[time_slice + [(time_steps + 1).long()]] angle_0 = angle(complex_specgrams_0) angle_1 = angle(complex_specgrams_1) norm_0 = torch.norm(complex_specgrams_0, dim=-1) norm_1 = torch.norm(complex_specgrams_1, dim=-1) phase = angle_1 - angle_0 - phase_advance phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi)) # Compute Phase Accum phase = phase + phase_advance phase = torch.cat([phase_0, phase[time_slice + [slice(-1)]]], dim=-1) phase_acc = torch.cumsum(phase, -1) mag = alphas * norm_1 + (1 - alphas) * norm_0 real_stretch = mag * torch.cos(phase_acc) imag_stretch = mag * torch.sin(phase_acc) complex_specgrams_stretch = torch.stack( [real_stretch, imag_stretch], dim=-1) return complex_specgrams_stretch def amplitude_to_db(x, ref=1.0, amin=1e-7): """ Amplitude-to-decibel conversion (logarithmic mapping with base=10) By using `amin=1e-7`, it assumes 32-bit floating point input. If the data precision differs, use an appropriate `amin` accordingly. 
Args: x (Tensor): Input amplitude ref (float): Amplitude value that is equivalent to 0 decibel amin (float): Minimum amplitude. Any input that is smaller than `amin` is clamped to `amin`. Returns: (Tensor): same size of x, after conversion """ x = x.pow(2.) x = torch.clamp(x, min=amin) return 10.0 * (torch.log10(x) - torch.log10(torch.tensor(ref, device=x.device, requires_grad=False, dtype=x.dtype))) def db_to_amplitude(x, ref=1.0): """ Decibel-to-amplitude conversion (exponential mapping with base=10) Args: x (Tensor): Input in decibel to be converted ref (float): Amplitude value that is equivalent to 0 decibel Returns: (Tensor): same size of x, after conversion """ power_spec = torch.pow(10.0, x / 10.0 + torch.log10(torch.tensor(ref, device=x.device, requires_grad=False, dtype=x.dtype))) return power_spec.pow(0.5) def mu_law_encoding(x, n_quantize=256): """Apply mu-law encoding to the input tensor. Usually applied to waveforms Args: x (Tensor): input value n_quantize (int): quantization level. For 8-bit encoding, set 256 (2 ** 8). Returns: (Tensor): same size of x, after encoding """ if not x.dtype.is_floating_point: x = x.to(torch.float) mu = torch.tensor(n_quantize - 1, dtype=x.dtype, requires_grad=False) # confused about dtype here.. x_mu = x.sign() * torch.log1p(mu * x.abs()) / torch.log1p(mu) x_mu = ((x_mu + 1) / 2 * mu + 0.5).long() return x_mu def mu_law_decoding(x_mu, n_quantize=256, dtype=torch.get_default_dtype()): """Apply mu-law decoding (expansion) to the input tensor. Args: x_mu (Tensor): mu-law encoded input n_quantize (int): quantization level. For 8-bit decoding, set 256 (2 ** 8). dtype: specifies `dtype` for the decoded value. Default: `torch.get_default_dtype()` Returns: (Tensor): mu-law decoded tensor """ if not x_mu.dtype.is_floating_point: x_mu = x_mu.to(dtype) mu = torch.tensor(n_quantize - 1, dtype=x_mu.dtype, requires_grad=False) # confused about dtype here.. x = (x_mu / mu) * 2 - 1. 
    x = x.sign() * (torch.exp(x.abs() * torch.log1p(mu)) - 1.) / mu
    return x

================================================
FILE: torchaudio_contrib/layers.py
================================================
import torch
import math
import torch.nn as nn

from .functional import stft, complex_norm, \
    create_mel_filter, phase_vocoder, apply_filterbank, \
    amplitude_to_db, db_to_amplitude, \
    mu_law_encoding, mu_law_decoding


class _ModuleNoStateBuffers(nn.Module):
    """
    Extension of nn.Module that removes buffers
    from state_dict.
    """

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        ret = super(_ModuleNoStateBuffers, self).state_dict(
            destination, prefix, keep_vars)
        for k in self._buffers:
            del ret[prefix + k]
        return ret

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # temporarily hide the buffers; we do not want to restore them

        buffers = self._buffers
        self._buffers = {}
        result = super(_ModuleNoStateBuffers, self)._load_from_state_dict(
            state_dict, prefix, *args, **kwargs)
        self._buffers = buffers
        return result


class STFT(_ModuleNoStateBuffers):
    """Compute a short-time Fourier transform of the input waveform(s).
    It essentially wraps `torch.stft`, after reshaping the input audio
    to allow for `waveforms` with `.dim()` >= 3.
    It follows most of the `torch.stft` default values, except for `window`:
    if it is not specified (`None`), a Hann window is used.

    Args:
        fft_length (int): FFT size [sample]
        hop_length (int): Hop size [sample] between STFT frames.
            Defaults to `fft_length // 4` (75%-overlapping windows) by `torch.stft`.
        win_length (int): Size of STFT window.
            Defaults to `fft_length` by `torch.stft`.
        window (Tensor): 1-D Tensor.
            Defaults to Hann Window of size `win_length` *unlike* `torch.stft`.
        center (bool): Whether to pad `waveforms` on both sides so that the
            `t`-th frame is centered at time `t * hop_length`.
            Defaults to `True` by `torch.stft`.
        pad_mode (str): padding method (see `torch.nn.functional.pad`).
            Defaults to `'reflect'` by `torch.stft`.
        normalized (bool): Whether the results are normalized.
            Defaults to `False` by `torch.stft`.
        onesided (bool): Whether the half + 1 frequency bins are returned to remove
            the symmetric part of STFT of real-valued signal.
            Defaults to `True` by `torch.stft`.
    """

    def __init__(self, fft_length, hop_length=None, win_length=None,
                 window=None, center=True, pad_mode='reflect',
                 normalized=False, onesided=True):
        super(STFT, self).__init__()

        self.fft_length = fft_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.center = center
        self.pad_mode = pad_mode
        self.normalized = normalized
        self.onesided = onesided

        if window is None:
            if win_length is None:
                window = torch.hann_window(fft_length)
            else:
                window = torch.hann_window(win_length)

        self.register_buffer('window', window)

    def forward(self, waveforms):
        """
        Args:
            waveforms (Tensor): Tensor of audio signal of size `(*, channel, time)`

        Returns:
            complex_specgrams (Tensor): `(*, channel, num_freqs, time, complex=2)`
        """

        complex_specgrams = stft(waveforms, self.fft_length,
                                 hop_length=self.hop_length,
                                 win_length=self.win_length,
                                 window=self.window,
                                 center=self.center,
                                 pad_mode=self.pad_mode,
                                 normalized=self.normalized,
                                 onesided=self.onesided)

        return complex_specgrams

    def __repr__(self):
        param_str1 = '(fft_length={}, hop_length={}, win_length={})'.format(
            self.fft_length, self.hop_length, self.win_length)
        param_str2 = '(center={}, pad_mode={}, normalized={}, onesided={})'.format(
            self.center, self.pad_mode, self.normalized, self.onesided)
        return self.__class__.__name__ + param_str1 + param_str2


class ComplexNorm(nn.Module):
    """Compute the norm of complex tensor input

    Args:
        power (float): Power of the norm. Defaults to `1.0`.
""" def __init__(self, power=1.0): super(ComplexNorm, self).__init__() self.power = power def forward(self, complex_tensor): """ Args: complex_tensor (Tensor): Tensor shape of `(*, complex=2)` Returns: Tensor: norm of the input tensor, shape of `(*, )` """ return complex_norm(complex_tensor, self.power) def __repr__(self): return self.__class__.__name__ + '(power={})'.format(self.power) class ApplyFilterbank(_ModuleNoStateBuffers): """ Applies a filterbank transform. """ def __init__(self, filterbank): super(ApplyFilterbank, self).__init__() self.register_buffer('filterbank', filterbank) def forward(self, mag_specgrams): """ Args: mag_specgrams (Tensor): (channel, time, freq) or (batch, channel, time, freq). Returns: (Tensor): freq -> filterbank.size(0) """ return apply_filterbank(mag_specgrams, self.filterbank) class Filterbank(object): """ Base class for providing a filterbank matrix. """ def __init__(self): super(Filterbank, self).__init__() def get_filterbank(self): raise NotImplementedError class MelFilterbank(Filterbank): """ Provides a filterbank matrix to convert a spectrogram into a mel frequency spectrogram. Args: num_freqs (int, optional): number of filter banks from stft. Defaults to 2048//2 + 1. num_mels (int): number of mel bins. Defaults to 128. min_freq (float): minimum frequency. Defaults to 0. max_freq (float, optional): maximum frequency. Defaults to sample_rate // 2. sample_rate (int): sample rate of audio signal. Defaults to None. htk (bool, optional): use HTK formula instead of Slaney. Defaults to False. """ def __init__(self, num_freqs=1025, num_mels=128, min_freq=0.0, max_freq=None, sample_rate=None, htk=False): super(MelFilterbank, self).__init__() if sample_rate is None and max_freq is None: raise ValueError('Either max_freq or sample_rate should be specified.' 
                             ', but both are None.')

        self.num_freqs = num_freqs
        self.num_mels = num_mels
        self.min_freq = min_freq
        self.max_freq = max_freq if max_freq else sample_rate // 2
        self.htk = htk

    def get_filterbank(self):
        return create_mel_filter(
            num_freqs=self.num_freqs,
            num_mels=self.num_mels,
            min_freq=self.min_freq,
            max_freq=self.max_freq,
            htk=self.htk)

    def __repr__(self):
        # Build a single, balanced "(...)" argument list
        param_str1 = '(num_freqs={}, num_mels={}'.format(
            self.num_freqs, self.num_mels)
        param_str2 = ', min_freq={}, max_freq={}'.format(
            self.min_freq, self.max_freq)
        param_str3 = ', htk={})'.format(
            self.htk)
        return self.__class__.__name__ + param_str1 + param_str2 + param_str3


class TimeStretch(_ModuleNoStateBuffers):
    """
    Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int): Number of audio samples between STFT columns.
        num_freqs (int): number of filter banks from stft.
        fixed_rate (float): rate to speed up or slow down by.
            Defaults to None (in which case a rate must be
            passed to the forward method per batch).
    """

    def __init__(self, hop_length, num_freqs, fixed_rate=None):
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate
        phase_advance = torch.linspace(
            0, math.pi * hop_length, num_freqs)[..., None]

        self.register_buffer('phase_advance', phase_advance)

    def forward(self, complex_specgrams, overriding_rate=None):
        """

        Args:
            complex_specgrams (Tensor): complex spectrogram
                (*, channel, freq, time, complex=2)
            overriding_rate (float or None): speed up to apply to this batch.
                If no rate is passed, use self.fixed_rate.

        Returns:
            (Tensor): (*, channel, num_freqs, ceil(time/rate), complex=2)
        """
        if overriding_rate is None:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")
        else:
            rate = overriding_rate

        if rate == 1.0:
            return complex_specgrams

        return phase_vocoder(complex_specgrams, rate, self.phase_advance)

    def __repr__(self):
        param_str = '(fixed_rate={})'.format(self.fixed_rate)
        return self.__class__.__name__ + param_str


def Spectrogram(fft_length, hop_length=None, win_length=None,
                window=None, center=True, pad_mode='reflect',
                normalized=False, onesided=True, power=1.):
    """Get spectrogram module, which is a Sequential module of `[STFT(), ComplexNorm()]`.

    Args:
        fft_length (int): FFT size [sample]
        hop_length (int): Hop size [sample] between STFT frames.
            Defaults to `fft_length // 4` (75%-overlapping windows) by `torch.stft`.
        win_length (int): Size of STFT window.
            Defaults to `fft_length` by `torch.stft`.
        window (Tensor): 1-D Tensor.
            Defaults to Hann Window of size `win_length` *unlike* `torch.stft`.
        center (bool): Whether to pad `waveforms` on both sides so that the
            `t`-th frame is centered at time `t * hop_length`.
            Defaults to `True` by `torch.stft`.
        pad_mode (str): padding method (see `torch.nn.functional.pad`).
            Defaults to `'reflect'` by `torch.stft`.
        normalized (bool): Whether the results are normalized.
            Defaults to `False` by `torch.stft`.
        onesided (bool): Whether the half + 1 frequency bins are returned to remove
            the symmetric part of STFT of real-valued signal.
            Defaults to `True` by `torch.stft`.
        power (float): Exponent of the magnitude. Defaults to `1.0`.
    """
    return nn.Sequential(
        STFT(
            fft_length,
            hop_length,
            win_length,
            window,
            center,
            pad_mode,
            normalized,
            onesided),
        ComplexNorm(power))


def Melspectrogram(
        num_mels=128,
        sample_rate=22050,
        min_freq=0.0,
        max_freq=None,
        num_freqs=None,
        htk=False,
        mel_filterbank=None,
        **kwargs):
    """
    Get melspectrogram module.

    Args:
        num_mels (int): number of mel bins. Defaults to 128.
        sample_rate (int): sample rate of audio signal. Defaults to 22050.
        min_freq (float): minimum frequency. Defaults to 0.
        max_freq (float, optional): maximum frequency. Defaults to sample_rate // 2.
        num_freqs (int, optional): number of filter banks from stft.
            Defaults to `fft_length // 2 + 1` if `fft_length` is in kwargs, else 1025.
            The passed value is currently ignored and always recomputed from `fft_length`.
        htk (bool, optional): use HTK formula instead of Slaney. Defaults to False.
        mel_filterbank (class, optional): MelFilterbank class to build filterbank matrix
        **kwargs: torchaudio_contrib.Spectrogram parameters.
    """
    fft_length = kwargs.get('fft_length', None)
    num_freqs = fft_length // 2 + 1 if fft_length else 1025
    # keunwoo: Why is num_freqs specified like this and not by the passed argument?

    # Check if custom MelFilterbank is passed
    if mel_filterbank is None:
        mel_filterbank = MelFilterbank

    mel_fb_matrix = mel_filterbank(
        num_mels=num_mels,
        sample_rate=sample_rate,
        min_freq=min_freq,
        max_freq=max_freq,
        num_freqs=num_freqs,
        htk=htk).get_filterbank()

    return nn.Sequential(*Spectrogram(power=2., **kwargs),
                         ApplyFilterbank(mel_fb_matrix))


class AmplitudeToDb(_ModuleNoStateBuffers):
    """
    Amplitude-to-decibel conversion (logarithmic mapping with base=10)
    By using `amin=1e-7`, it assumes 32-bit floating point input. If the
    data precision differs, use an appropriate `amin` accordingly.

    Args:
        ref (float): Amplitude value that is equivalent to 0 decibel
        amin (float): Minimum amplitude. Any input that is smaller than `amin` is
            clamped to `amin`.
""" def __init__(self, ref=1.0, amin=1e-7): super(AmplitudeToDb, self).__init__() self.ref = ref self.amin = amin assert ref > amin, "Reference value is expected to be bigger than amin, but I have" \ "ref:{} and amin:{}".format(ref, amin) def forward(self, x): """ Args: x (Tensor): Input amplitude Returns: (Tensor): same size of x, after conversion """ return amplitude_to_db(x, ref=self.ref, amin=self.amin) def __repr__(self): param_str = '(ref={}, amin={})'.format(self.ref, self.amin) return self.__class__.__name__ + param_str class DbToAmplitude(_ModuleNoStateBuffers): """ Decibel-to-amplitude conversion (exponential mapping with base=10) Args: x (Tensor): Input in decibel to be converted ref (float): Amplitude value that is equivalent to 0 decibel Returns: (Tensor): same size of x, after conversion """ def __init__(self, ref=1.0): super(DbToAmplitude, self).__init__() self.ref = ref def forward(self, x): """ Args: x (Tensor): Input in decibel to be converted Returns: (Tensor): same size of x, after conversion """ return db_to_amplitude(x, ref=self.ref) def __repr__(self): param_str = '(ref={})'.format(self.ref) return self.__class__.__name__ + param_str class MuLawEncoding(_ModuleNoStateBuffers): """Apply mu-law encoding to the input tensor. Usually applied to waveforms Args: n_quantize (int): quantization level. For 8-bit encoding, set 256 (2 ** 8). """ def __init__(self, n_quantize=256): super(MuLawEncoding, self).__init__() self.n_quantize = n_quantize def forward(self, x): """ Args: x (Tensor): input value Returns: (Tensor): same size of x, after encoding """ return mu_law_encoding(x, self.n_quantize) def __repr__(self): param_str = '(n_quantize={})'.format(self.n_quantize) return self.__class__.__name__ + param_str class MuLawDecoding(_ModuleNoStateBuffers): """Apply mu-law decoding (expansion) to the input tensor. Usually applied to waveforms Args: n_quantize (int): quantization level. For 8-bit decoding, set 256 (2 ** 8). 
""" def __init__(self, n_quantize=256): super(MuLawDecoding, self).__init__() self.n_quantize = n_quantize def forward(self, x_mu): """ Args: x_mu (Tensor): mu-law encoded input Returns: (Tensor): mu-law decoded tensor """ return mu_law_decoding(x_mu, self.n_quantize) def __repr__(self): param_str = '(n_quantize={})'.format(self.n_quantize) return self.__class__.__name__ + param_str