Repository: tencent-ailab/bddm
Branch: main
Commit: 2cebe0e6b7fd
Files: 43
Total size: 70.9 MB
Directory structure:
bddm/
├── .gitignore
├── LICENSE
├── README.md
├── bddm/
│ ├── loader/
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ └── stft.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── diffwave.py
│ │ └── galr.py
│ ├── sampler/
│ │ ├── __init__.py
│ │ └── sampler.py
│ ├── trainer/
│ │ ├── __init__.py
│ │ ├── bmuf.py
│ │ ├── ema.py
│ │ ├── loss.py
│ │ └── trainer.py
│ └── utils/
│ ├── check_utils.py
│ ├── diffusion_utils.py
│ └── log_utils.py
└── egs/
├── lj/
│ ├── DiffWave.pkl
│ ├── DiffWave_GALR.pkl
│ ├── conf.yml
│ ├── demo_case/
│ │ ├── LJ001-0001.mel
│ │ ├── LJ004-0233.mel
│ │ └── LJ050-0270.mel
│ ├── eval.list
│ ├── generate.sh
│ ├── main.py
│ ├── noise_schedules/
│ │ ├── 12steps_PESQ4.07_STOI0.978.ns
│ │ └── 7steps_PESQ4.05_STOI0.975.ns
│ ├── schedule.sh
│ ├── train.sh
│ └── valid.list
└── vctk/
├── DiffWave.pkl
├── DiffWave_GALR.pkl
├── conf.yml
├── demo_case/
│ ├── p259_311.mel
│ └── p287_123.mel
├── generate.sh
├── main.py
├── noise_schedules/
│ └── 8steps_PESQ3.88_STOI0.987.ns
├── schedule.sh
└── train.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
egs/lj/exp
egs/lj/LJSpeech1.1
*.log
*.pyc
__pycache__/
*/__pycache__/
*.zip
*/*.zip
source/
make.bat
Makefile
*.swp
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2022 Tencent
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Bilateral Denoising Diffusion Models (BDDMs)
[GitHub](https://github.com/tencent-ailab/bddm) | [arXiv](https://arxiv.org/abs/2203.13508) | [Demo](https://bilateral-denoising-diffusion-model.github.io)
This is the official PyTorch implementation of the following paper:
> **BDDM: BILATERAL DENOISING DIFFUSION MODELS FOR FAST AND HIGH-QUALITY SPEECH SYNTHESIS** \
> Max W. Y. Lam, Jun Wang, Dan Su, Dong Yu
> **Abstract**: *Diffusion probabilistic models (DPMs) and their extensions have emerged as competitive generative models yet confront challenges of efficient sampling. We propose a new bilateral denoising diffusion model (BDDM) that parameterizes both the forward and reverse processes with a schedule network and a score network, which can train with a novel bilateral modeling objective. We show that the new surrogate objective can achieve a lower bound of the log marginal likelihood tighter than a conventional surrogate. We also find that BDDM allows inheriting pre-trained score network parameters from any DPMs and consequently enables speedy and stable learning of the schedule network and optimization of a noise schedule for sampling. Our experiments demonstrate that BDDMs can generate high-fidelity audio samples with as few as three sampling steps. Moreover, compared to other state-of-the-art diffusion-based neural vocoders, BDDMs produce comparable or higher quality samples indistinguishable from human speech, notably with only seven sampling steps (143x faster than WaveGrad and 28.6x faster than DiffWave).*
> **Paper**: Published at ICLR 2022 on [OpenReview](https://openreview.net/pdf?id=L7wzpQttNO)

This implementation supports model training and audio generation, and also provides pre-trained models for the benchmark [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) and [VCTK](https://datashare.ed.ac.uk/handle/10283/2651) datasets.
Visit our [demo page](https://bilateral-denoising-diffusion-model.github.io) for audio samples.
### Updates:
- **May 20, 2022:** Released our follow-up work [FastDiff](https://github.com/Rongjiehuang/FastDiff) on GitHub, where we further optimized the speed-and-quality trade-off.
- **May 10, 2022:** Added the experiment configurations and model checkpoints for the VCTK dataset.
- **May 9, 2022:** Added the searched noise schedules for the LJSpeech and VCTK datasets.
- **March 20, 2022:** Released the PyTorch implementation of BDDM with pre-trained models for the LJSpeech dataset.
### Recipes:
- (Option 1) To train the BDDM scheduling network yourself, you can download the pre-trained score network from [philsyn/DiffWave-Vocoder](https://github.com/philsyn/DiffWave-Vocoder/blob/master/exp/ch128_T200_betaT0.02/logs/checkpoint/1000000.pkl) (provided at ```egs/lj/DiffWave.pkl```), and follow the training steps below. **(Start from Step I.)**
- (Option 2) To search for noise schedules using BDDM, we provide a pre-trained BDDM for LJSpeech at ```egs/lj/DiffWave_GALR.pkl``` and for VCTK at ```egs/vctk/DiffWave_GALR.pkl```. **(Start from Step III.)**
- (Option 3) To directly generate samples using BDDM, we provide the searched schedules for LJSpeech at ```egs/lj/noise_schedules``` and for VCTK at ```egs/vctk/noise_schedules``` (check ```conf.yml``` for the respective configurations). **(Start from Step IV.)**
## Getting Started
We provide an example of how you can generate high-fidelity samples using BDDMs.
To try BDDM on your own dataset, simply clone this repo to a local machine with an NVIDIA GPU and CUDA/cuDNN installed, and follow the instructions below.
### Dependencies
- [pytorch](https://github.com/pytorch/pytorch)>=1.7.1
- [librosa](https://github.com/librosa/librosa)>=0.7.1
- [pystoi](https://github.com/mpariente/pystoi)==0.3.3
- [pypesq](https://github.com/youngjamespark/python-pypesq)==0.2.0
### Step I. Data Preparation and Configuration ###
Download the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset.
For training, we first need to set up a **conf.yml** file that configures the data loader, the score and schedule networks, the training procedure, and the noise scheduling and sampling parameters.
**Note:** Modify the paths in ```"train_data_dir"``` and ```"valid_data_dir"``` for training, and the path in ```"gen_data_dir"``` for sampling. Each path should point to a directory that stores the waveform audios (in **.wav**) or the Mel-spectrogram files (in **.mel**).
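As a reference for how these directories are consumed, the loader's file discovery can be sketched as follows (a minimal stdlib-only sketch mirroring ```find_all_wavs_in_dir```/```find_all_mels_in_dir``` in ```bddm/loader/dataset.py```; the helper name is ours, not part of the repo):

```python
from pathlib import Path

def find_data_files(data_dir, is_sampling):
    """Mirror the loader's rule: training reads .wav files, while
    sampling prefers pre-computed .mel files and falls back to .wav."""
    if is_sampling:
        mel_files = sorted(Path(data_dir).glob('*.mel'))
        if mel_files:
            return mel_files
    return sorted(Path(data_dir).glob('*.wav'))
```

The actual dataset class additionally shuffles the discovered list on each ```reset()```.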
### Step II. Training a Schedule Network ###
Suppose a well-trained score network (theta) is stored at ```$theta_path```; we start by setting ```"load": $theta_path``` in **conf.yml**.
After modifying the relevant hyperparameters for the schedule network (especially ```"tau"```), we can train the schedule network (f_phi in the paper) using:
```bash
# Training on device 0 (supports multi-GPU training)
sh train.sh 0 conf.yml
```
**Note**: In practice, we found that **10K** training steps are enough to obtain a promising schedule network; this normally takes no more than half an hour on a single GPU.
### Step III. Searching for Noise Schedules ###
Given a well-trained BDDM (theta, phi), we can now run the noise scheduling algorithm to find the best schedule (optimizing the trade-off between quality and speed).
First, we set ```"load"``` in ```conf.yml``` to the path of the trained BDDM.
After setting the maximum number of sampling steps in scheduling (```"N"```), we run:
```bash
# Scheduling on device 0 (only supports single-GPU scheduling)
sh schedule.sh 0 conf.yml
```
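The searched schedules shipped with this repo encode their quality metrics in the filename (e.g. ```12steps_PESQ4.07_STOI0.978.ns```), which makes it easy to pick a schedule by its speed/quality trade-off. A small helper for parsing that naming pattern (this function is ours, not part of the repo):

```python
import re

def parse_schedule_name(fname):
    """Extract steps/PESQ/STOI from names like '12steps_PESQ4.07_STOI0.978.ns'."""
    m = re.match(r'(\d+)steps_PESQ([\d.]+)_STOI([\d.]+)\.ns$', fname)
    if m is None:
        return None
    return {'steps': int(m.group(1)),
            'pesq': float(m.group(2)),
            'stoi': float(m.group(3))}
```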
### Step IV. Evaluation or Generation ###
For evaluation, we set ```"gen_data_dir"``` in ```conf.yml``` to the path of a directory that stores the test set of audios (in ```.wav```).
For generation, we set ```"gen_data_dir"``` in ```conf.yml``` to the path of a directory that stores the Mel-spectrograms (by default in ```.mel```, generated by [TacotronSTFT](https://github.com/NVIDIA/tacotron2/blob/master/layers.py) or by our dataset loader ```bddm/loader/dataset.py```).
Then, we run:
```bash
# Generation/evaluation on device 0 (only supports single-GPU generation)
sh generate.sh 0 conf.yml
```
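When checking generated outputs, keep in mind the loader's framing convention (```seg_len = n_mels * hop_len``` in ```bddm/loader/dataset.py```): a Mel-spectrogram with L frames corresponds to L × hop_len waveform samples. A trivial sanity-check sketch, with the hop length taken from the default 22.05 kHz configuration:

```python
def expected_num_samples(n_mel_frames, hop_len=256):
    """Waveform length (in samples) matching a mel input of n_mel_frames frames."""
    return n_mel_frames * hop_len

def expected_num_frames(n_samples, hop_len=256):
    """Number of whole mel frames covering a waveform of n_samples samples."""
    return n_samples // hop_len
```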
## Acknowledgements
This implementation uses parts of the code from the following GitHub repos:\
[Tacotron2](https://github.com/NVIDIA/tacotron2)\
[DiffWave-Vocoder](https://github.com/philsyn/DiffWave-Vocoder)\
as described in our code.
## Citations ##
```
@inproceedings{lam2022bddm,
title={BDDM: Bilateral Denoising Diffusion Models for Fast and High-Quality Speech Synthesis},
author={Lam, Max WY and Wang, Jun and Su, Dan and Yu, Dong},
booktitle={International Conference on Learning Representations},
year={2022}
}
```
## License ##
Copyright 2022 Tencent
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
## Disclaimer ##
This is not an officially supported Tencent product.
================================================
FILE: bddm/loader/__init__.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Loader
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
================================================
FILE: bddm/loader/dataset.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Dataset and DataLoader for Neural Vocoding
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
from scipy.io.wavfile import read
from .stft import TacotronSTFT
MAX_WAV_VALUE = 32768
class SpectrogramDataset(data.Dataset):
def __init__(self, data_dir, n_gpus, is_sampling, sampling_rate,
seg_len, fil_len, hop_len, win_len, mel_fmin, mel_fmax):
"""
A torch.data.Dataset class that loads the audio files for training
or loads the Mel-spectrogram files for sampling.
Parameters:
data_dir (str): the path to the directory storing .wav/.mel files
n_gpus (int): the number of GPUs for training
is_sampling (bool): whether the dataset is used for sampling or not
sampling_rate (int): the sampling rate of audios
seg_len (int): the segment length (number of samples) for training
fil_len (int): the filter length for computing STFT
hop_len (int): the hop length for computing STFT
win_len (int): the window length for computing STFT
mel_fmin (int): the minimum frequency for computing STFT
mel_fmax (int): the maximum frequency for computing STFT
"""
self.n_gpus = n_gpus
self.is_sampling = is_sampling
self.seg_len = seg_len
self.hop_len = hop_len
self.n_mels = self.seg_len // self.hop_len
self.sampling_rate = sampling_rate
if is_sampling:
# Find all Mel-spectrogram files in the given data directory
self.mel_files = self.find_all_mels_in_dir(data_dir)
if len(self.mel_files) == 0:
# Find audios when no pre-computed mel spectrograms can be found
self.audio_files = self.find_all_wavs_in_dir(data_dir)
# Note that no mel file is loaded for generation
self.mel_files = None
else:
# Note that no audio file is loaded for generation
self.audio_files = None
else:
# Find all audio files in the given data directory
self.audio_files = self.find_all_wavs_in_dir(data_dir)
# Use the standard STFT operation defined in Tacotron 2
self.stft = TacotronSTFT(filter_length=fil_len,
hop_length=hop_len,
win_length=win_len,
sampling_rate=sampling_rate,
mel_fmin=mel_fmin,
mel_fmax=mel_fmax)
self.reset()
def reset(self):
"""
Reset the loader by shuffling the file list
"""
if self.is_sampling and self.audio_files is None:
np.random.shuffle(self.mel_files)
else:
np.random.shuffle(self.audio_files)
self.n_mels = self.seg_len // self.hop_len
self.n_mels = int(self.n_mels * (1 + np.random.rand()))
# Make sure the number of samples is divisible by n_gpus
if self.audio_files is not None and len(self.audio_files) % self.n_gpus != 0:
remainder = len(self.audio_files) % self.n_gpus
self.audio_files = self.audio_files[:-remainder]
def find_all_wavs_in_dir(self, data_dir):
"""
Load all .wav files in data_dir
Parameters:
data_dir (str): the path to the directory storing .wav files
Returns:
files_list (list): the list of wav file paths
"""
files = [f for f in Path(data_dir).glob('*.wav')]
if len(files) == 0:
files = [f for f in Path(data_dir).glob('*_wav.npy')]
return files
def find_all_mels_in_dir(self, data_dir):
"""
Load all .mel files in data_dir
Parameters:
data_dir (str): the path to the directory storing .mel files
Returns:
files_list (list): the list of mel file paths
"""
files = [f for f in Path(data_dir).glob('*.mel')]
return files
def crop_audio_and_mel(self, audio, mel_spec):
"""
Randomly crop audio and mel_spec into a fixed-length segment
Parameters:
audio (tensor): the full audio
mel_spec (tensor): the full mel-spectrogram computed using TacotronSTFT
Returns:
audio (tensor): the cropped audio
mel_spec (tensor): the cropped mel-spectrogram
"""
n_mels = self.n_mels
seg_len = n_mels * self.hop_len
if audio.size(-1) >= seg_len:
if mel_spec.size(-1) > n_mels:
max_mel_start = mel_spec.size(-1) - n_mels
mel_start = np.random.randint(0, max_mel_start)
mel_spec = mel_spec[..., mel_start:mel_start+n_mels]
audio_start = mel_start * self.hop_len
audio = audio[..., audio_start:audio_start+seg_len]
elif mel_spec.size(-1) == n_mels:
audio = audio[..., :seg_len]
else:
audio = audio[..., :seg_len]
mel_spec = F.pad(mel_spec, (0, n_mels - mel_spec.size(-1)), 'constant', 0)
else:
audio = F.pad(audio, (0, seg_len - audio.size(-1)), 'constant', 0)
mel_spec = F.pad(mel_spec, (0, n_mels - mel_spec.size(-1)), 'constant', 0)
return audio, mel_spec
def __getitem__(self, index):
"""
Get a pair of data (mel-spectrogram, audio) given an index
Parameters:
index (int): the index for loading one sample
Returns:
audio_key (str): the audio key
mel_spec (tensor): the mel-spectrogram computed using TacotronSTFT
audio (tensor): the ground-truth audio
"""
if self.is_sampling and self.audio_files is None:
# Load Mel-spectrogram for sampling
mel_spec = torch.load(self.mel_files[index], map_location='cpu').float()
if mel_spec.ndim == 3:
mel_spec = mel_spec[0]
audio_key = str(self.mel_files[index])
# Try to find the paired source audio
wav_path = str(self.mel_files[index])[:-4]+'.wav'
if os.path.isfile(wav_path):
sampling_rate, audio = read(wav_path)
audio = torch.from_numpy(audio[None]).float() / MAX_WAV_VALUE
assert sampling_rate == self.sampling_rate
return audio_key, mel_spec, audio
else:
return audio_key, mel_spec, []
mel_spec = None
# Load the audio file to torch.FloatTensor
if str(self.audio_files[index])[-3:] == 'npy':
audio = np.load(self.audio_files[index])[0]
mel_spec = np.load(str(self.audio_files[index]).replace('wav', 'mel'))[0]
# Load into torch.FloatTensor
audio = torch.from_numpy(audio).float()
mel_spec = torch.from_numpy(mel_spec).float()
else:
sampling_rate, audio = read(self.audio_files[index])
# Make sure the sampling rate is correctly defined
assert sampling_rate == self.sampling_rate
# Normalize the audio into [-1, 1]
audio = audio / MAX_WAV_VALUE
# Load into torch.FloatTensor
audio = torch.from_numpy(audio).float()
# Compute Mel-spectrogram (shape: [T] -> [n_mel_channels, L])
if mel_spec is None:
audio = audio[None]
mel_spec = self.stft.mel_spectrogram(audio)
mel_spec = torch.squeeze(mel_spec, 0)
if not self.is_sampling:
audio, mel_spec = self.crop_audio_and_mel(audio, mel_spec)
if self.is_sampling:
# Save ground-truth Mel-spectrogram into .mel file
torch.save(mel_spec, str(self.audio_files[index])[:-4]+'.mel')
return str(self.audio_files[index]), mel_spec, audio
return mel_spec, audio
def __len__(self):
"""
Get the number of data
Returns:
data_len (int): the number of .wav/.mel files found in data_dir
"""
if self.audio_files is None:
return len(self.mel_files)
return len(self.audio_files)
def create_train_and_valid_dataloader(config):
"""
Create two torch.data.DataLoader for training and validation
Parameters:
config (namespace): BDDM Configuration
Returns:
tr_loader (DataLoader): the data loader for training
vl_loader (DataLoader): the data loader for validation
"""
n_gpus = 1
if 'WORLD_SIZE' in os.environ.keys():
n_gpus = int(os.environ['WORLD_SIZE'])
conf_keys = SpectrogramDataset.__init__.__code__.co_varnames
data_config = {k: v for k, v in vars(config).items() if k in conf_keys}
data_config["data_dir"] = config.train_data_dir
data_config["is_sampling"] = False
data_config["n_gpus"] = n_gpus
dataset = SpectrogramDataset(**data_config)
assert len(dataset) > 0, f"Error: No .wav can be found at {config.train_data_dir} !"
sampler = data.distributed.DistributedSampler(dataset) if n_gpus > 1 else None
tr_loader = data.DataLoader(dataset,
sampler=sampler,
batch_size=config.batch_size,
num_workers=config.n_worker,
pin_memory=False,
drop_last=True)
data_config["data_dir"] = config.valid_data_dir
dataset = SpectrogramDataset(**data_config)
assert len(dataset) > 0, f"Error: No .wav can be found at {config.valid_data_dir} !"
sampler = data.distributed.DistributedSampler(dataset) if n_gpus > 1 else None
vl_loader = data.DataLoader(dataset,
sampler=sampler,
batch_size=1,
num_workers=config.n_worker,
pin_memory=False)
return tr_loader, vl_loader
def create_generation_dataloader(config):
"""
Create a torch.data.DataLoader for generation
Parameters:
config (namespace): BDDM Configuration
Returns:
gen_loader (DataLoader): the data loader for generation
"""
conf_keys = SpectrogramDataset.__init__.__code__.co_varnames
data_config = {k: v for k, v in vars(config).items() if k in conf_keys}
data_config["data_dir"] = config.gen_data_dir
data_config["is_sampling"] = True
data_config["n_gpus"] = 1
dataset = SpectrogramDataset(**data_config)
gen_loader = data.DataLoader(dataset,
batch_size=1, # variable audio length
num_workers=config.n_worker,
pin_memory=False)
return gen_loader
if __name__ == "__main__":
from argparse import Namespace
# Smoke test with placeholder paths and minimal hyper-parameters
data_config = {
"train_data_dir": "LJSpeech-1.1/train_wavs",
"valid_data_dir": "LJSpeech-1.1/valid_wavs",
"sampling_rate": 22050,
"seg_len": 22050,
"fil_len": 1024,
"hop_len": 256,
"win_len": 1024,
"mel_fmin": 0.0,
"mel_fmax": 8000.0,
"batch_size": 1,
"n_worker": 0,
"extra_key": "extra_val"  # unrecognized keys are filtered out
}
data_config = Namespace(**data_config)
tr_loader, vl_loader = create_train_and_valid_dataloader(data_config)
for mel_spec, audio in tr_loader:
print(mel_spec.shape, audio.shape)
break
================================================
FILE: bddm/loader/stft.py
================================================
"""
BSD 3-Clause License
Copyright (c) 2017, Prem Seetharaman
All rights reserved.
* Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.filters import mel as mel_func
from librosa.util import pad_center
class STFT(nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
def __init__(self, filter_length=800, hop_length=200, win_length=800,
window='hann'):
super(STFT, self).__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.forward_transform = None
scale = self.filter_length / self.hop_length
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
np.imag(fourier_basis[:cutoff, :])])
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis = torch.FloatTensor(
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
if window is not None:
assert(filter_length >= win_length)
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = pad_center(fft_window, size=filter_length)
fft_window = torch.from_numpy(fft_window).float()
# window the bases
forward_basis *= fft_window
inverse_basis *= fft_window
self.register_buffer('forward_basis', forward_basis.float())
self.register_buffer('inverse_basis', inverse_basis.float())
def transform(self, input_data):
num_batches = input_data.size(0)
num_samples = input_data.size(1)
self.num_samples = num_samples
# similar to librosa, reflect-pad the input
input_data = input_data.view(num_batches, 1, num_samples)
input_data = F.pad(
input_data.unsqueeze(1),
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
mode='reflect')
input_data = input_data.squeeze(1)
forward_transform = F.conv1d(
input_data,
Variable(self.forward_basis, requires_grad=False),
stride=self.hop_length,
padding=0)
cutoff = int((self.filter_length / 2) + 1)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part**2 + imag_part**2)
phase = torch.autograd.Variable(
torch.atan2(imag_part.data, real_part.data))
return magnitude, phase
    def forward(self, input_data):
        # NOTE: relies on an `inverse(magnitude, phase)` method from the
        # original pytorch-stft implementation that is not included in this
        # trimmed class; TacotronSTFT below only uses `transform`.
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
class TacotronSTFT(nn.Module):
"""
Adapted from https://github.com/NVIDIA/tacotron2/blob/master/layers.py
"""
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
mel_fmax=8000.0):
super(TacotronSTFT, self).__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length)
mel_basis = mel_func(
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer('mel_basis', mel_basis)
def spectral_normalize(self, x):
return torch.log(torch.clamp(x, min=1e-5))
def spectral_de_normalize(self, x):
return torch.exp(x)
def mel_spectrogram(self, y):
"""
Compute mel-spectrograms from a batch of waves
"""
assert(torch.min(y.data) >= -1)
assert(torch.max(y.data) <= 1)
magnitudes, _ = self.stft_fn.transform(y)
magnitudes = magnitudes.data
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output)
return mel_output
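TacotronSTFT's dynamic-range compression is a clamped log, so `spectral_de_normalize` inverts it exactly for any magnitude above the `1e-5` floor. A minimal NumPy sketch of the pair (a stand-in for the torch versions above):

```python
import numpy as np

# NumPy stand-ins for TacotronSTFT.spectral_normalize / spectral_de_normalize:
# clamp to 1e-5 before the log so silent frames cannot produce -inf.
def spectral_normalize(x):
    return np.log(np.clip(x, 1e-5, None))

def spectral_de_normalize(x):
    return np.exp(x)

mags = np.array([0.0, 1e-6, 0.5, 2.0])
restored = spectral_de_normalize(spectral_normalize(mags))
# Magnitudes above the floor round-trip exactly; the rest clamp to 1e-5.
```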
================================================
FILE: bddm/models/__init__.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Model builders for the score network (DiffWave) and
# the schedule network (GALR)
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
from .diffwave import DiffWave
from .galr import GALR
def get_score_network(config):
    if config.score_net == 'DiffWave':
        conf_keys = DiffWave.__init__.__code__.co_varnames
        model_config = {k: v for k, v in vars(config).items() if k in conf_keys}
        return DiffWave(**model_config)
    raise ValueError('Unsupported score network: %s' % config.score_net)
def get_schedule_network(config):
    if config.schedule_net == 'GALR':
        conf_keys = GALR.__init__.__code__.co_varnames
        model_config = {k: v for k, v in vars(config).items() if k in conf_keys}
        return GALR(**model_config)
    raise ValueError('Unsupported schedule network: %s' % config.schedule_net)
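Both builders rely on the same introspection trick: `__init__.__code__.co_varnames` yields the constructor's parameter names, which are then used to filter the configuration namespace down to accepted keyword arguments. A minimal sketch with a hypothetical `Toy` model and config (note that `co_varnames` also lists locals defined inside `__init__`, so config keys must not collide with them):

```python
from types import SimpleNamespace

class Toy:
    def __init__(self, a, b=1):
        self.a, self.b = a, b

# Hypothetical config carrying extra keys the constructor does not accept
config = SimpleNamespace(a=10, b=2, unrelated='ignored')
conf_keys = Toy.__init__.__code__.co_varnames   # ('self', 'a', 'b')
model_config = {k: v for k, v in vars(config).items() if k in conf_keys}
model = Toy(**model_config)                     # 'unrelated' is filtered out
```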
================================================
FILE: bddm/models/diffwave.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# DiffWave: A Versatile Diffusion Model for Audio Synthesis
# (https://arxiv.org/abs/2009.09761)
# Modified from https://github.com/philsyn/DiffWave-Vocoder
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
"""
Embed a diffusion step $t$ into a higher dimensional space
E.g. the embedding vector in the 128-dimensional space is
    [sin(t * 10^(-0*4/63)), ... , sin(t * 10^(-63*4/63)),
     cos(t * 10^(-0*4/63)), ... , cos(t * 10^(-63*4/63))]
Parameters:
diffusion_steps (torch.long tensor, shape=(batchsize, 1)):
diffusion steps for batch data
diffusion_step_embed_dim_in (int, default=128):
dimensionality of the embedding space for discrete diffusion steps
Returns:
the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)):
"""
assert diffusion_step_embed_dim_in % 2 == 0
half_dim = diffusion_step_embed_dim_in // 2
_embed = np.log(10000) / (half_dim - 1)
_embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
_embed = diffusion_steps * _embed
diffusion_step_embed = torch.cat((torch.sin(_embed),
torch.cos(_embed)), 1)
return diffusion_step_embed
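The same embedding can be sketched in NumPy (CPU-only, no `.cuda()`), which makes the layout explicit: the first half of the vector holds sines, the second half cosines, over geometrically spaced frequencies:

```python
import numpy as np

# NumPy sketch of calc_diffusion_step_embedding.
def step_embedding(steps, dim):
    assert dim % 2 == 0
    half = dim // 2
    freqs = np.exp(np.arange(half) * -(np.log(10000) / (half - 1)))
    args = np.asarray(steps, dtype=float).reshape(-1, 1) * freqs
    return np.concatenate([np.sin(args), np.cos(args)], axis=1)

emb = step_embedding([0, 5, 50], 128)   # shape (3, 128)
```

For step 0 the sine half is all zeros and the cosine half all ones, which is a quick sanity check on the layout.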
"""
Below scripts were borrowed from
https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
"""
def swish(x):
return x * torch.sigmoid(x)
# dilated conv layer with kaiming_normal initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
class Conv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
super().__init__()
self.padding = dilation * (kernel_size - 1) // 2
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
dilation=dilation, padding=self.padding)
self.conv = nn.utils.weight_norm(self.conv)
nn.init.kaiming_normal_(self.conv.weight)
def forward(self, x):
out = self.conv(x)
return out
# conv1x1 layer with zero initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed
class ZeroConv1d(nn.Module):
def __init__(self, in_channel, out_channel):
super().__init__()
self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
self.conv.weight.data.zero_()
self.conv.bias.data.zero_()
def forward(self, x):
out = self.conv(x)
return out
# every residual block (named residual layer in paper)
# contains one noncausal dilated conv
class ResidualBlock(nn.Module):
def __init__(self, res_channels, skip_channels, dilation,
diffusion_step_embed_dim_out):
super().__init__()
self.res_channels = res_channels
# Use a FC layer for diffusion step embedding
self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)
# Dilated conv layer
self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
kernel_size=3, dilation=dilation)
# Add mel spectrogram upsampler and conditioner conv1x1 layer
self.upsample_conv2d = nn.ModuleList()
for s in [16, 16]:
conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
padding=(1, s // 2),
stride=(1, s))
conv_trans2d = nn.utils.weight_norm(conv_trans2d)
nn.init.kaiming_normal_(conv_trans2d.weight)
self.upsample_conv2d.append(conv_trans2d)
# 80 is mel bands
self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1)
# Residual conv1x1 layer, connect to next residual layer
self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
self.res_conv = nn.utils.weight_norm(self.res_conv)
nn.init.kaiming_normal_(self.res_conv.weight)
# Skip conv1x1 layer, add to all skip outputs through skip connections
self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
self.skip_conv = nn.utils.weight_norm(self.skip_conv)
nn.init.kaiming_normal_(self.skip_conv.weight)
def forward(self, input_data):
x, mel_spec, diffusion_step_embed = input_data
h = x
batch_size, n_channels, seq_len = x.shape
assert n_channels == self.res_channels
# Add in diffusion step embedding
part_t = self.fc_t(diffusion_step_embed)
part_t = part_t.view([batch_size, self.res_channels, 1])
h += part_t
# Dilated conv layer
h = self.dilated_conv_layer(h)
# Upsample spectrogram to size of audio
mel_spec = torch.unsqueeze(mel_spec, dim=1)
mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
mel_spec = torch.squeeze(mel_spec, dim=1)
assert mel_spec.size(2) >= seq_len
if mel_spec.size(2) > seq_len:
mel_spec = mel_spec[:, :, :seq_len]
mel_spec = self.mel_conv(mel_spec)
h += mel_spec
# Gated-tanh nonlinearity
out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
# Residual and skip outputs
res = self.res_conv(out)
assert x.shape == res.shape
skip = self.skip_conv(out)
# Normalize for training stability
return (x + res) * math.sqrt(0.5), skip
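The two `ConvTranspose2d` layers each upsample the mel-frame axis by their stride of 16, so a mel sequence is stretched by 16 × 16 = 256, matching TacotronSTFT's default `hop_length=256`. The standard transposed-convolution length formula confirms this arithmetic:

```python
# Transposed-convolution output length along the upsampled (time) axis:
#   L_out = (L_in - 1) * stride - 2 * padding + kernel_size
def conv_transpose_len(l_in, stride, padding, kernel_size):
    return (l_in - 1) * stride - 2 * padding + kernel_size

frames = 100                        # number of mel frames (toy value)
for s in [16, 16]:                  # the two upsampling layers above
    frames = conv_transpose_len(frames, stride=s,
                                padding=s // 2, kernel_size=2 * s)
# 100 mel frames -> 100 * 256 audio-rate positions, matching hop_length=256
```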
class ResidualGroup(nn.Module):
def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out):
super().__init__()
self.num_res_layers = num_res_layers
self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
# Use the shared two FC layers for diffusion step embedding
self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)
# Stack all residual blocks with dilations 1, 2, ... , 512, ... , 1, 2, ..., 512
self.residual_blocks = nn.ModuleList()
for n in range(self.num_res_layers):
self.residual_blocks.append(
ResidualBlock(res_channels, skip_channels,
dilation=2 ** (n % dilation_cycle),
diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))
def forward(self, input_data):
x, mel_spectrogram, diffusion_steps = input_data
# Embed diffusion step t
diffusion_step_embed = calc_diffusion_step_embedding(
diffusion_steps, self.diffusion_step_embed_dim_in)
diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
# Pass all residual layers
h = x
skip = 0
for n in range(self.num_res_layers):
# Use the output from last residual layer
h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, diffusion_step_embed))
# Accumulate all skip outputs
skip += skip_n
# Normalize for training stability
return skip * math.sqrt(1.0 / self.num_res_layers)
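The `sqrt(1/N)` rescaling of the accumulated skip outputs keeps the activation variance roughly constant: a sum of N independent unit-variance signals has variance N, which the rescaling maps back to about 1. A quick numerical check:

```python
import numpy as np

rng = np.random.default_rng(0)
num_layers = 30
# Sum num_layers independent unit-variance "skip outputs" ...
skip_sum = rng.standard_normal((num_layers, 100_000)).sum(axis=0)
# ... then rescale by sqrt(1/N), as in ResidualGroup.forward
normalized = skip_sum * np.sqrt(1.0 / num_layers)
variance = normalized.var()          # stays close to 1 instead of N
```

The `sqrt(0.5)` on the residual path in `ResidualBlock.forward` plays the same role for the sum of two branches.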
class DiffWave(nn.Module):
def __init__(self, in_channels, res_channels, skip_channels, out_channels,
num_res_layers, dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out):
super().__init__()
# Initial conv1x1 with relu
self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
# All residual layers
self.residual_layer = ResidualGroup(res_channels,
skip_channels,
num_res_layers,
dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out)
# Final conv1x1 -> relu -> zeroconv1x1
self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
nn.ReLU(inplace=False), ZeroConv1d(skip_channels, out_channels))
def forward(self, input_data):
audio, mel_spectrogram, diffusion_steps = input_data
x = audio
x = self.init_conv(x).clone()
x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
return self.final_conv(x)
================================================
FILE: bddm/models/galr.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Globally Attentive Locally Recurrent (GALR) Networks
# (https://arxiv.org/abs/2101.05014)
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class Encoder(nn.Module):
def __init__(self, enc_dim, win_len):
"""
        1D Convolution-based Waveform Encoder
Parameters:
enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
win_len (int): Window length for processing raw signal samples (e.g. a
common choice: ``16``). By default, the windows are half-overlapping.
"""
super().__init__()
# 1D convolutional layer
self.enc_conv = nn.Conv1d(1, enc_dim,
kernel_size=win_len,
stride=win_len//2, bias=False)
def forward(self, signals):
"""
Non-linearly encode signals from raw waveform to frame sequences.
Parameters:
signals (tensor): A batch of signals in shape `[B, T]`, where `T` is the
maximum length of these `B` signals.
Returns:
frames (tensor): A batch of encoded feature (a.k.a. frames) sequences in shape
`[B, D, L]`, where `B` is the batch size, `D` is the encoded feature
dimension (enc_dim), and `L` is the length of this feature sequence.
"""
frames = F.relu(self.enc_conv(signals.unsqueeze(1)))
return frames
class BiLSTMproj(nn.Module):
def __init__(self, enc_dim, hid_dim):
"""
Locally Recurrent Layer (Sec 2.2.1 in https://arxiv.org/abs/2101.05014).
It consists of a bi-directional LSTM followed by a linear projection.
Parameters:
enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
hid_dim (int): Number of hidden nodes used in the Bi-LSTM.
"""
super().__init__()
# Bi-LSTM with learnable (h_0, c_0) state
self.rnn = nn.LSTM(enc_dim, hid_dim,
1, dropout=0, batch_first=True, bidirectional=True)
self.cell_init = nn.Parameter(torch.rand(1, 1, hid_dim))
self.hidden_init = nn.Parameter(torch.rand(1, 1, hid_dim))
# Linear projection layer
self.proj = nn.Linear(hid_dim * 2, enc_dim)
def forward(self, intra_segs):
"""
Process through a locally recurrent layer along the intra-segment
direction.
Parameters:
            intra_segs (tensor): A batch of intra-segments in shape `[B*S, K, D]`, where
                `B` is the batch size, `S` is the number of segments, `K` is the
                segment length (seg_len) and `D` is the feature dimension (enc_dim).
Returns:
lr_output (tensor): A batch of processed segments with the same shape as the input.
"""
batch_size_seq_len = intra_segs.size(0)
cell = self.cell_init.repeat(2, batch_size_seq_len, 1)
hidden = self.hidden_init.repeat(2, batch_size_seq_len, 1)
rnn_output, _ = self.rnn(intra_segs, (hidden, cell))
lr_output = self.proj(rnn_output)
return lr_output
class AttnPositionalEncoding(nn.Module):
def __init__(self, enc_dim, attn_max_len=5000):
"""
Positional Encoding for Multi-Head Attention
Parameters:
enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
attn_max_len (int): Maximum length of the sequence to be processed by
multi-head attention.
"""
super().__init__()
pe = torch.zeros(attn_max_len, enc_dim)
position = torch.arange(0, attn_max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, enc_dim, 2).float() * (-math.log(10000.0)/enc_dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Compute positional encoding
Parameters:
            x (tensor): the sequence to which the positional encoding is added
Returns:
output (tensor): the encoded input
"""
output = x + self.pe[:x.size(0), :]
return output
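A NumPy sketch of the table built in `__init__` (before the unsqueeze/transpose used to match the attention input layout) makes the interleaving visible: even columns hold sines, odd columns cosines:

```python
import numpy as np

# NumPy sketch of AttnPositionalEncoding's table.
def positional_encoding(max_len, enc_dim):
    pe = np.zeros((max_len, enc_dim))
    position = np.arange(max_len, dtype=float)[:, None]
    div_term = np.exp(np.arange(0, enc_dim, 2, dtype=float)
                      * (-np.log(10000.0) / enc_dim))
    pe[:, 0::2] = np.sin(position * div_term)   # even columns: sines
    pe[:, 1::2] = np.cos(position * div_term)   # odd columns: cosines
    return pe

pe = positional_encoding(50, 128)
```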
class GlobalAttnLayer(nn.Module):
def __init__(self, enc_dim, n_attn_head, attn_dropout):
"""
Globally Attentive Layer (Sec 2.2.2 in
https://arxiv.org/abs/2101.05014)
Parameters:
enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
n_attn_head (int): Number of heads for the multi-head attention. (e.g.
choice in paper: ``8``)
attn_dropout (float): Dropout rate for multi-head attention (e.g. choice in
paper: ``0.1``).
"""
super().__init__()
self.attn = nn.MultiheadAttention(enc_dim,
n_attn_head, dropout=attn_dropout)
self.norm = nn.LayerNorm(enc_dim)
self.dropout = nn.Dropout(attn_dropout)
def forward(self, inter_segs):
"""
Process through a globally attentive layer along the inter-segment
direction.
Parameters:
inter_segs (tensor): A batch of inter-segments in shape `[S, B*K, D]`, where
                `B` is the batch size, `S` is the number of segments, `K` is the
segment length (seg_len) and `D` is the feature dimension (enc_dim).
Returns:
output (tensor): A batch of processed segments with the same shape as the input.
"""
        attn_output, _ = self.attn(inter_segs, inter_segs, inter_segs)
        # Residual connection from the attention input, followed by layer norm
        output = self.norm(inter_segs + self.dropout(attn_output))
return output
class DeepGlobalAttnLayer(nn.Module):
def __init__(self, enc_dim, n_attn_head, attn_dropout, n_attn_layer=1):
"""
A Stack of Globally Attentive Layers (Sec 2.2.2 in
https://arxiv.org/abs/2101.05014)
Parameters:
enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
n_attn_head (int): Number of heads for the multi-head attention. (e.g.
choice in paper: ``8``)
attn_dropout (float): Dropout rate for multi-head attention (e.g. choice in
paper: ``0.1``).
n_attn_layer (int): Number of globally attentive layers stacked. The
setting in paper by default is ``1``.
"""
super().__init__()
self.attn_in_norm = nn.LayerNorm(enc_dim)
self.pos_enc = AttnPositionalEncoding(enc_dim)
self.attn_layer = nn.ModuleList([
GlobalAttnLayer(enc_dim, n_attn_head, attn_dropout) for _ in range(n_attn_layer)])
def forward(self, inter_segs):
"""
Process through a stack of globally attentive layers.
Parameters:
inter_segs (tensor): A batch of inter-segments in shape `[S, B*K, D]`, where
                `B` is the batch size, `S` is the number of segments, `K` is the
segment length (seg_len) and `D` is the feature dimension (enc_dim).
Returns:
output (tensor): A batch of processed segments with the same shape as the input.
"""
output = self.attn_in_norm(inter_segs)
output = self.pos_enc(output)
for block in self.attn_layer:
output = block(output)
return output
class GALRBlock(nn.Module):
def __init__(self, enc_dim, hid_dim, seg_len, low_dim=8, n_attn_head=8, attn_dropout=0.1):
"""
Globally Attentive Locally Recurrent (GALR) Block (Sec. 2.2 in
https://arxiv.org/abs/2101.05014)
Parameters:
enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
hid_dim (int): Number of hidden nodes used in the Bi-LSTM.
seg_len (int): Segment length for processing frame sequence (e.g. a
``64`` when win_len is ``16``). By default, the segments are
half-overlapping.
low_dim (int): Lower dimension for speeding up GALR. (Sec. 2.2.3)
(e.g. ``8`` when seg_len is ``64``).
n_attn_head (int): Number of heads for the multi-head attention. (e.g.
choice in paper: ``8``)
attn_dropout (float): Dropout rate for multi-head attention (e.g. choice in
paper: ``0.1``).
"""
super().__init__()
self.low_dim = low_dim
self.local_rnn = BiLSTMproj(enc_dim, hid_dim)
self.local_norm = nn.GroupNorm(1, enc_dim)
self.global_rnn = DeepGlobalAttnLayer(enc_dim, n_attn_head, attn_dropout)
self.global_norm = nn.GroupNorm(1, enc_dim)
self.ld_map = nn.Linear(seg_len, low_dim)
self.ld_inv_map = nn.Linear(low_dim, seg_len)
def forward(self, segments):
"""
Process through a GALR block.
Parameters:
segments (tensor): A batch of 3D segments in shape `[B, D, K, S]`, where
                `B` is the batch size, `S` is the number of segments, `K` is the
segment length (seg_len) and `D` is the feature dimension (enc_dim).
Returns:
segments (tensor): A batch of processed segments with the same shape as the input.
"""
batch_size, feat_dim, seg_len, n_segs = segments.size()
# Change the sequence direction for intra-segment processing
local_input = segments.transpose(1, 3).reshape(batch_size * n_segs, seg_len, feat_dim)
# Process through a locally recurrent layer
local_output = self.local_rnn(local_input)
# Reshape to match the dimensionality of the input for residual connection
local_output = local_output.view(batch_size, n_segs, seg_len, feat_dim).transpose(1, 3).contiguous()
# Add a layer normalization before the residual connection
local_output = self.local_norm(local_output)
# Add residual connection
segments = segments + local_output
        # Change the sequence direction for inter-segment processing
global_input = segments.permute(3, 2, 0, 1).contiguous().view(n_segs, seg_len, batch_size, feat_dim)
# Perform low-dimensional mapping for speeding up GALR
global_input = self.ld_map(global_input.transpose(1, -1))
        # Reshape for inter-segment processing (sequence, batch size, feature dim)
global_input = global_input.transpose(1, -1).contiguous().view(n_segs, -1, feat_dim)
# Process through a globally attentive layer
global_output = self.global_rnn(global_input)
# Reshape for low-dimensional inverse mapping
global_output = global_output.view(n_segs, self.low_dim, -1, feat_dim).transpose(1, -1)
# Map the low-dimensional features back to the original size
global_output = self.ld_inv_map(global_output)
# Reshape to match the dimensionality of the input for residual connection
global_output = global_output.permute(2, 1, 3, 0).contiguous()
# Add a layer normalization before the residual connection
global_output = self.global_norm(global_output)
# Add residual connection
segments = segments + global_output
return segments
class _GALR(nn.Module):
def __init__(self, n_block, enc_dim, hid_dim, win_len, seg_len,
low_dim=8, n_attn_head=8, attn_dropout=0.1):
"""
Globally Attentive Locally Recurrent (GALR) Networks
(https://arxiv.org/abs/2101.05014)
Parameters:
n_block (int): Number of GALR blocks (e.g. choice in paper: ``6``).
enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
hid_dim (int): Number of hidden nodes used in the Bi-LSTM.
win_len (int): Window length for processing raw signal samples (e.g. a
common choice: ``8``). By default, the windows are half-overlapping.
seg_len (int): Segment length for processing frame sequence (e.g. a
``64`` when win_len is ``8``). By default, the segments are half-overlapping.
low_dim (int): Lower dimension for speeding up GALR. (Sec. 2.2.3)
(e.g. ``8`` when seg_len is ``64``).
n_attn_head (int): Number of heads for the multi-head attention. (e.g.
choice in paper: ``8``)
            attn_dropout (float): Dropout rate for multi-head attention (e.g. choice in
paper: ``0.1``).
"""
super().__init__()
self.win_len = win_len
self.seg_len = seg_len
self.encoder = Encoder(enc_dim, win_len)
self.bottleneck = nn.Conv1d(enc_dim, enc_dim, 1, bias=False)
# GALR blocks
self.blocks = nn.ModuleList([
GALRBlock(enc_dim, hid_dim, seg_len, low_dim, n_attn_head, attn_dropout)
for i in range(n_block)])
# Many-to-one gated layer applied to GALR's last block
self.block_gate = nn.Sequential(
nn.PReLU(),
nn.Conv2d(enc_dim, 1, 1)
)
def forward(self, noisy_signals):
"""
Process through a GALR network for noise estimation
Parameters:
noisy_signals (tensor): A batch of 1D signals in shape `[B, T]`, where
`B` is the batch size, `T` is the maximum length of the `B` signals.
Returns:
est_ratios (tensor): A batch of scalar noise scale ratios in `[B, 1]` shape.
"""
noisy_signals_padded, _ = self.pad_zeros(noisy_signals)
mix_frames = self.encoder(noisy_signals_padded)
mix_frames = self.bottleneck(mix_frames)
block_feature, _ = self.split_feature(mix_frames)
for block in self.blocks:
block_feature = block(block_feature)
est_segments = self.block_gate(block_feature)
est_ratios = torch.sigmoid(est_segments.mean([2, 3]))
est_ratios = est_ratios.mean(1, keepdim=True)
return est_ratios
def pad_zeros(self, signals):
"""
Pad a batch of signals with zeros before encoding.
Parameters:
signals (tensor): A batch of 1D signals in shape `[B, T]`, where `B` is
the batch size, `T` is the maximum length of the `B` signals.
Returns:
signals (tensor): A batch of padded signals and the length of zeros used for
padding. (in shape `[B, T]`)
rest (int): the redundant spaces created in padding
"""
batch_size, sig_len = signals.shape
self.hop_size = self.win_len // 2
rest = self.win_len - (self.hop_size + sig_len % self.win_len) % self.win_len
if rest > 0:
pad = torch.zeros(batch_size, rest).type(signals.type())
signals = torch.cat([signals, pad], 1)
pad_aux = torch.zeros(batch_size, self.hop_size).type(signals.type())
signals = torch.cat([pad_aux, signals, pad_aux], 1)
return signals, rest
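The padding arithmetic alone can be isolated: `rest` zeros make the hop-shifted signal divide evenly into half-overlapping windows, and the extra `hop_size` zeros on both ends center the first and last windows. A pure-Python mirror of the resulting length:

```python
# Pure-Python mirror of the length bookkeeping in pad_zeros: after appending
# `rest` zeros and a half-window (hop) of zeros on both ends, the encoder's
# Conv1d (kernel=win_len, stride=win_len//2) leaves no remainder.
def padded_length(sig_len, win_len):
    hop = win_len // 2
    rest = win_len - (hop + sig_len % win_len) % win_len
    return sig_len + (rest if rest > 0 else 0) + 2 * hop

total = padded_length(22050, win_len=8)   # one second at 22.05 kHz (toy value)
leftover = (total - 8) % 4                # windows fit exactly: remainder is 0
```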
def pad_segment(self, frames):
"""
Pad a batch of frames with zeros before segmentation.
Parameters:
frames (tensor): A batch of 2D frames in shape `[B, D, L]`, where `B` is
the batch size, `D` is the encoded feature dimension (enc_dim), and
`L` is the length of this feature sequence.
Returns:
frames (tensor): A batch of padded frames and the length of zeros used for
padding. (in shape `[B, D, L]`)
rest (int): the redundant spaces created in padding
"""
batch_size, feat_dim, seq_len = frames.shape
stride = self.seg_len // 2
rest = self.seg_len - (stride + seq_len % self.seg_len) % self.seg_len
if rest > 0:
pad = Variable(torch.zeros(batch_size, feat_dim, rest)).type(frames.type())
frames = torch.cat([frames, pad], 2)
pad_aux = Variable(torch.zeros(batch_size, feat_dim, stride)).type(frames.type())
frames = torch.cat([pad_aux, frames, pad_aux], 2)
return frames, rest
def split_feature(self, frames):
"""
Perform segmentation by dividing every K (seg_len) consecutive frames
into S segments.
Parameters:
            frames (tensor): A batch of 2D frames in shape `[B, D, L]`, where `B` is
                the batch size, `D` is the encoded feature dimension (enc_dim), and
                `L` is the length of this feature sequence.
Returns:
segments (tensor): A batch of 3D segments and the length of zeros used for
padding. (in shape `[B, D, K, S]`)
rest (int): the redundant spaces created in padding
"""
frames, rest = self.pad_segment(frames)
batch_size, feat_dim, _ = frames.shape
stride = self.seg_len // 2
lsegs = frames[:, :, :-stride].contiguous().view(batch_size, feat_dim, -1, self.seg_len)
rsegs = frames[:, :, stride:].contiguous().view(batch_size, feat_dim, -1, self.seg_len)
segments = torch.cat([lsegs, rsegs], -1).view(batch_size, feat_dim, -1, self.seg_len)
segments = segments.transpose(2, 3).contiguous()
return segments, rest
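The fold performed by `split_feature` can be sketched in NumPy on a toy sequence (assuming the frames were already padded so that `L % seg_len == seg_len // 2`): interleaving the left- and right-shifted views produces segments that advance by half a segment each:

```python
import numpy as np

# NumPy sketch of split_feature: fold padded frames [B, D, L] into
# half-overlapping segments [B, D, K, S] (K = seg_len).
def split_feature(frames, seg_len):
    batch, dim, _ = frames.shape
    stride = seg_len // 2
    lsegs = frames[:, :, :-stride].reshape(batch, dim, -1, seg_len)
    rsegs = frames[:, :, stride:].reshape(batch, dim, -1, seg_len)
    segs = np.concatenate([lsegs, rsegs], -1).reshape(batch, dim, -1, seg_len)
    return segs.transpose(0, 1, 3, 2)

frames = np.arange(10, dtype=float).reshape(1, 1, 10)
segments = split_feature(frames, seg_len=4)
# Consecutive segments start at offsets 0, 2, 4, 6 (half-overlapping)
```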
class GALR(nn.Module):
def __init__(self, blocks=2, input_dim=128, hidden_dim=128, window_length=8, segment_size=64):
"""
GALR Schedule Network
Parameters:
blocks (int): Number of GALR blocks (e.g. choice in paper: ``6``).
input_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
hidden_dim (int): Number of hidden nodes used in the Bi-LSTM.
window_length (int): Window length for processing raw signal samples (e.g. a
common choice: ``8``). By default, the windows are half-overlapping.
segment_size (int): Segment length for processing frame sequence (e.g. a
``64`` when win_len is ``8``). By default, the segments are half-overlapping.
"""
super().__init__()
self.ratio_nn = _GALR(blocks, input_dim, hidden_dim, window_length, segment_size)
def forward(self, audio, scales):
"""
Estimate the next beta scale using the GALR schedule network
Parameters:
audio (tensor): A batch of 1D signals in shape `[B, T]`, where
`B` is the batch size, `T` is the maximum length of the `B` signals.
scales (list): [beta_nxt, delta] for computing the upper bound
Returns:
beta (tensor): A batch of scalar noise scale ratios in `[B, 1, 1]` shape.
"""
beta_nxt, delta = scales
bounds = torch.cat([beta_nxt, delta], 1)
mu, _ = torch.min(bounds, 1)
mu = mu[:, None]
ratio = self.ratio_nn(audio)
beta = mu * ratio
return beta.view(audio.size(0), 1, 1)
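Because the network's ratio passes through a sigmoid, the predicted beta is strictly bounded above by `mu = min(beta_nxt, delta)`, which keeps the schedule search stable. A scalar NumPy sketch of the bound (`raw_logit` is a hypothetical pre-sigmoid network output, not part of the repo's API):

```python
import numpy as np

# Scalar sketch of the bound in GALR.forward: a sigmoid ratio in (0, 1)
# scales mu = min(beta_nxt, delta), so beta never exceeds either bound.
def bound_beta(raw_logit, beta_nxt, delta):
    ratio = 1.0 / (1.0 + np.exp(-raw_logit))
    mu = min(beta_nxt, delta)
    return mu * ratio

beta = bound_beta(raw_logit=3.2, beta_nxt=0.02, delta=0.5)
```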
================================================
FILE: bddm/sampler/__init__.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Sampler
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
from .sampler import Sampler
================================================
FILE: bddm/sampler/sampler.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# BDDM Sampler (Supports Noise Scheduling and Sampling)
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
from __future__ import absolute_import
import os
import librosa
import torch
import numpy as np
from scipy.io.wavfile import write as wavwrite
from pystoi import stoi
from pypesq import pesq
from bddm.utils.log_utils import log
from bddm.utils.check_utils import check_score_network
from bddm.utils.check_utils import check_schedule_network
from bddm.utils.diffusion_utils import compute_diffusion_params
from bddm.utils.diffusion_utils import map_noise_scale_to_time_step
from bddm.models import get_score_network, get_schedule_network
from bddm.loader.dataset import create_generation_dataloader
from bddm.loader.dataset import create_train_and_valid_dataloader
class Sampler(object):
metric2index = {"PESQ": 0, "STOI": 1}
def __init__(self, config):
"""
Sampler Class, implements the Noise Scheduling and Sampling algorithms in BDDMs
Parameters:
config (namespace): BDDM Configuration
"""
self.config = config
self.exp_dir = config.exp_dir
self.clip = config.grad_clip
self.load = config.load
self.model = get_score_network(config).cuda().eval()
self.schedule = None
# Initialize diffusion parameters using a pre-specified linear schedule
noise_schedule = torch.linspace(config.beta_0, config.beta_T, config.T).cuda()
self.diff_params = compute_diffusion_params(noise_schedule)
if self.config.command != 'train':
# Find schedule net, if not trained then use DDPM or DDIM sampling mode
schedule_net_trained, schedule_net_path = check_schedule_network(config)
if not schedule_net_trained:
_, score_net_path = check_score_network(config)
assert score_net_path is not None, 'Error: No score network can be found!'
self.config.load = score_net_path
else:
self.config.load = schedule_net_path
self.model.schedule_net = get_schedule_network(config).cuda().eval()
            # Load a pre-searched noise schedule (.ns file) when one is given;
            # otherwise noise scheduling is run at sampling time
if self.config.command == 'generate' and self.config.sampling_noise_schedule != '':
# Generation mode given pre-searched noise schedule
self.schedule = torch.load(self.config.sampling_noise_schedule, map_location='cpu')
self.reset()
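The pre-specified linear schedule and the cumulative quantities typically derived from it can be sketched in NumPy. Hedged: the exact fields returned by `compute_diffusion_params` live in `bddm/utils/diffusion_utils.py`; the values below are common DiffWave-style defaults assumed for illustration (the real ones come from `conf.yml`):

```python
import numpy as np

# Assumed defaults: beta_0=1e-4, beta_T=0.02, T=1000.
T, beta_0, beta_T = 1000, 1e-4, 0.02
beta = np.linspace(beta_0, beta_T, T)
alpha_bar = np.cumprod(1.0 - beta)   # cumulative signal level
alpha = np.sqrt(alpha_bar)           # sqrt cumulative alpha, strictly decreasing
```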
def reset(self):
"""
Reset sampling environment
"""
if self.config.command != 'train' and self.load != '':
package = torch.load(self.load, map_location=lambda storage, loc: storage.cuda())
self.model.load_state_dict(package['model_state_dict'])
log('Loaded checkpoint %s' % self.load, self.config)
if self.config.command == 'generate':
# Given Mel-spectrogram directory for speech vocoding
self.gen_loader = create_generation_dataloader(self.config)
else:
# Sample a reference audio sample for noise scheduling
_, self.vl_loader = create_train_and_valid_dataloader(self.config)
self.draw_reference_data_pair()
def draw_reference_data_pair(self):
"""
Draw a new input-output pair for noise scheduling
"""
self.ref_spec, self.ref_audio = next(iter(self.vl_loader))
self.ref_spec, self.ref_audio = self.ref_spec.cuda(), self.ref_audio.cuda()
def generate(self):
"""
Start the generation process
"""
generate_dir = os.path.join(self.exp_dir, 'generated')
os.makedirs(generate_dir, exist_ok=True)
scores = {metric: [] for metric in self.metric2index}
for filepath, mel_spec, audio in self.gen_loader:
mel_spec = mel_spec.cuda()
generated_audio, n_steps = self.sampling(schedule=self.schedule,
condition=mel_spec,
ddim=self.config.use_ddim_steps)
audio_key = '.'.join(filepath[0].split('/')[-1].split('.')[:-1])
if len(audio) > 0:
# Assess the generated audio with the reference audio
self.ref_audio = audio
score = self.assess(generated_audio, audio_key=audio_key)
for metric in self.metric2index:
scores[metric].append(score[self.metric2index[metric]])
model_name = 'BDDM' if self.schedule is not None else (
'DDIM' if self.config.use_ddim_steps else 'DDPM')
generated_file = os.path.join(generate_dir,
'%s_by_%s-%d.wav'%(audio_key, model_name, n_steps))
wavwrite(generated_file, self.config.sampling_rate,
generated_audio.squeeze().cpu().numpy())
log('Generated '+generated_file, self.config)
        log('Avg Scores: PESQ = %.3f +/- %.3f, STOI = %.5f +/- %.5f' % (
            np.mean(scores['PESQ']), 1.96 * np.std(scores['PESQ']),
            np.mean(scores['STOI']), 1.96 * np.std(scores['STOI'])), self.config)
def noise_scheduling(self, ddim=False):
"""
Start the noise scheduling process
Parameters:
ddim (bool): whether to use the DDIM's p_theta for noise scheduling or not
Returns:
ts_infer (tensor): the step indices estimated by BDDM
a_infer (tensor): the alphas estimated by BDDM
b_infer (tensor): the betas estimated by BDDM
s_infer (tensor): the std. deviations estimated by BDDM
"""
max_steps = self.diff_params["N"]
alpha = self.diff_params["alpha"]
alpha_param = self.diff_params["alpha_param"]
beta_param = self.diff_params["beta_param"]
min_beta = self.diff_params["beta"].min()
betas = []
x = torch.normal(0, 1, size=self.ref_audio.shape).cuda()
with torch.no_grad():
b_cur = torch.ones(1, 1, 1).cuda() * beta_param
a_cur = torch.ones(1, 1, 1).cuda() * alpha_param
for n in range(max_steps - 1, -1, -1):
step = map_noise_scale_to_time_step(a_cur.squeeze().item(), alpha)
if step >= 0:
betas.append(b_cur.squeeze().item())
else:
break
ts = (step * torch.ones((1, 1))).cuda()
e = self.model((x, self.ref_spec.clone(), ts,))
a_nxt = a_cur / (1 - b_cur).sqrt()
if ddim:
c1 = a_nxt / a_cur
c2 = -(1 - a_cur**2.).sqrt() * c1
x = c1 * x + c2 * e
c3 = (1 - a_nxt**2.).sqrt()
x = x + c3 * e
else:
x = x - b_cur / torch.sqrt(1 - a_cur**2.) * e
x = x / torch.sqrt(1 - b_cur)
if n > 0:
z = torch.normal(0, 1, size=self.ref_audio.shape).cuda()
x = x + torch.sqrt((1 - a_nxt**2.) / (1 - a_cur**2.) * b_cur) * z
a_nxt, beta_nxt = a_cur, b_cur
a_cur = a_nxt / (1 - beta_nxt).sqrt()
if a_cur > 1:
break
b_cur = self.model.schedule_net(x.squeeze(1),
(beta_nxt.view(-1, 1), (1 - a_cur**2.).view(-1, 1)))
if b_cur.squeeze().item() < min_beta:
break
b_infer = torch.FloatTensor(betas[::-1]).cuda()
a_infer = 1 - b_infer
        s_infer = b_infer.clone()
for n in range(1, len(b_infer)):
a_infer[n] *= a_infer[n-1]
s_infer[n] *= (1 - a_infer[n-1]) / (1 - a_infer[n])
a_infer = torch.sqrt(a_infer)
s_infer = torch.sqrt(s_infer)
# Mapping noise scales to time steps
ts_infer = []
for n in range(len(b_infer)):
step = map_noise_scale_to_time_step(a_infer[n], alpha)
if step >= 0:
ts_infer.append(step)
ts_infer = torch.FloatTensor(ts_infer)
return ts_infer, a_infer, b_infer, s_infer
def sampling(self, schedule=None, condition=None,
ddim=0, return_sequence=False, audio_size=None):
"""
Perform the sampling algorithm
Parameters:
schedule (list): the [ts_infer, a_infer, b_infer, s_infer]
returned by the noise scheduling algorithm
condition (tensor): the condition for computing scores
ddim (bool): whether to use the DDIM for sampling or not
return_sequence (bool): whether returning all steps' samples or not
audio_size (list): the shape of the audio to be sampled
Returns:
xs (list): (if return_sequence) the list of generated audios
x (tensor): the generated audio(s) in shape=audio_size
N (int): the number of sampling steps
"""
n_steps = self.diff_params["T"]
if condition is None:
condition = self.ref_spec
if audio_size is None:
audio_length = condition.size(-1) * self.config.hop_len
audio_size = (1, 1, audio_length)
if schedule is None:
if ddim > 1:
# Use DDIM (linear) for sampling ({ddim} steps)
ts_infer = torch.linspace(0, n_steps - 1, ddim).long()
a_infer = self.diff_params["alpha"].index_select(0, ts_infer.cuda())
b_infer = self.diff_params["beta"].index_select(0, ts_infer.cuda())
s_infer = self.diff_params["sigma"].index_select(0, ts_infer.cuda())
else:
# Use DDPM for sampling (complete T steps)
# P.S. if ddim = 1, run DDIM reverse process for T steps
ts_infer = torch.linspace(0, n_steps - 1, n_steps)
a_infer = self.diff_params["alpha"]
b_infer = self.diff_params["beta"]
s_infer = self.diff_params["sigma"]
else:
ts_infer, a_infer, b_infer, s_infer = schedule
sampling_steps = len(ts_infer)
x = torch.normal(0, 1, size=audio_size).cuda()
if return_sequence:
xs = []
with torch.no_grad():
for n in range(sampling_steps - 1, -1, -1):
if sampling_steps > 50 and (sampling_steps - n) % 50 == 0:
# Log progress per 50 steps when sampling_steps is large
log('\tComputed %d / %d steps !'%(
sampling_steps - n, sampling_steps), self.config)
ts = (ts_infer[n] * torch.ones((1, 1))).cuda()
e = self.model((x, condition, ts,))
if ddim:
if n > 0:
a_nxt = a_infer[n - 1]
else:
a_nxt = a_infer[n] / (1 - b_infer[n]).sqrt()
c1 = a_nxt / a_infer[n]
c2 = -(1 - a_infer[n]**2.).sqrt() * c1
c3 = (1 - a_nxt**2.).sqrt()
x = c1 * x + (c2 + c3) * e
else:
x = x - b_infer[n] / torch.sqrt(1 - a_infer[n]**2.) * e
x = x / torch.sqrt(1 - b_infer[n])
if n > 0:
z = torch.normal(0, 1, size=audio_size).cuda()
x = x + s_infer[n] * z
if return_sequence:
xs.append(x)
if return_sequence:
return xs
return x, sampling_steps
def noise_scheduling_with_params(self, alpha_param, beta_param):
"""
Run noise scheduling for once given the (alpha_param, beta_param) pair
Parameters:
alpha_param (float): a hyperparameter defining the alpha_hat value at step N
beta_param (float): a hyperparameter defining the beta_hat value at step N
"""
log('TRY alpha_param=%.2f, beta_param=%.2f:'%(
alpha_param, beta_param), self.config)
# Define the pair key
key = '%.2f,%.2f' % (alpha_param, beta_param)
# Set alpha_param and beta_param in self.diff_params
self.diff_params['alpha_param'] = alpha_param
self.diff_params['beta_param'] = beta_param
# Use DDPM reverse process for noise scheduling
ddpm_schedule = self.noise_scheduling(ddim=False)
log("\tSearched a %d-step schedule using DDPM reverse process" % (
len(ddpm_schedule[0])), self.config)
generated_audio, _ = self.sampling(schedule=ddpm_schedule)
# Compute objective scores
ddpm_score = self.assess(generated_audio)
# Get the number of sampling steps with this schedule
steps = len(ddpm_schedule[0])
# Compare the performance with previous same-step schedule using the metric
if steps not in self.steps2score:
# Save the first schedule with this number of steps
self.steps2score[steps] = [key, ] + ddpm_score
self.steps2schedule[steps] = ddpm_schedule
log('\tFound the first %d-step schedule: (PESQ, STOI) = (%.2f, %.3f)'%(
steps, ddpm_score[0], ddpm_score[1]), self.config)
elif ddpm_score[0] > self.steps2score[steps][1] and ddpm_score[1] > self.steps2score[steps][2]:
# Found a better same-step schedule achieving a higher score
log('\tFound a better %d-step schedule: (PESQ, STOI) = (%.2f, %.3f) -> (%.2f, %.3f)'%(
steps, self.steps2score[steps][1], self.steps2score[steps][2],
ddpm_score[0], ddpm_score[1]), self.config)
self.steps2score[steps] = [key, ] + ddpm_score
self.steps2schedule[steps] = ddpm_schedule
# Use DDIM reverse process for noise scheduling
ddim_schedule = self.noise_scheduling(ddim=True)
log("\tSearched a %d-step schedule using DDIM reverse process" % (
len(ddim_schedule[0])), self.config)
generated_audio, _ = self.sampling(schedule=ddim_schedule)
# Compute objective scores
ddim_score = self.assess(generated_audio)
# Get the number of sampling steps with this schedule
steps = len(ddim_schedule[0])
# Compare the performance with previous same-step schedule using the metric
if steps not in self.steps2score:
# Save the first schedule with this number of steps
self.steps2score[steps] = [key, ] + ddim_score
self.steps2schedule[steps] = ddim_schedule
log('\tFound the first %d-step schedule: (PESQ, STOI) = (%.2f, %.3f)'%(
steps, ddim_score[0], ddim_score[1]), self.config)
elif ddim_score[0] > self.steps2score[steps][1] and ddim_score[1] > self.steps2score[steps][2]:
# Found a better same-step schedule achieving a higher score
log('\tFound a better %d-step schedule: (PESQ, STOI) = (%.2f, %.3f) -> (%.2f, %.3f)'%(
steps, self.steps2score[steps][1], self.steps2score[steps][2],
ddim_score[0], ddim_score[1]), self.config)
self.steps2score[steps] = [key, ] + ddim_score
self.steps2schedule[steps] = ddim_schedule
def noise_scheduling_without_params(self):
"""
Search for the best noise scheduling hyperparameters: (alpha_param, beta_param)
"""
# Noise scheduling mode, given N
self.reverse_process = 'BDDM'
assert 'N' in vars(self.config).keys(), 'Error: N is undefined for BDDM!'
self.diff_params["N"] = self.config.N
# Init search result dictionaries
self.steps2schedule, self.steps2score = {}, {}
search_bins = int(self.config.bddm_search_bins)
# Define search range of alpha_param
alpha_last = self.diff_params["alpha"][-1].squeeze().item()
alpha_first = self.diff_params["alpha"][0].squeeze().item()
alpha_diff = (alpha_first - alpha_last) / (search_bins + 1)
alpha_param_list = [alpha_last + alpha_diff * (i + 1) for i in range(search_bins)]
# Define search range of beta_param
beta_diff = 1. / (search_bins + 1)
beta_param_list = [beta_diff * (i + 1) for i in range(search_bins)]
# Search for beta_param and alpha_param, take O(search_bins^2)
for beta_param in beta_param_list:
for alpha_param in alpha_param_list:
if alpha_param > (1 - beta_param) ** 0.5:
# Invalid range
continue
# Update the scores and noise schedules with (alpha_param, beta_param)
self.noise_scheduling_with_params(alpha_param, beta_param)
# Lastly, repeat the random starting point (x_hat_N) and choose the best schedule
noise_schedule_dir = os.path.join(self.exp_dir, 'noise_schedules')
os.makedirs(noise_schedule_dir, exist_ok=True)
steps_list = list(self.steps2score.keys())
for steps in steps_list:
log("-"*80, self.config)
log("Select the best out of %d x_hat_N ~ N(0,I) for %d steps:"%(
self.config.noise_scheduling_attempts, steps), self.config)
# Get current best pair
key = self.steps2score[steps][0]
# Get back the best (alpha_param, beta_param) pair for a given steps
alpha_param, beta_param = list(map(float, key.split(',')))
# Repeat K times for a given number of steps
for _ in range(self.config.noise_scheduling_attempts):
# Random +/- 5%
_alpha_param = alpha_param * (0.95 + np.random.rand() * 0.1)
_beta_param = beta_param * (0.95 + np.random.rand() * 0.1)
# Update the scores and noise schedules with (alpha_param, beta_param)
self.noise_scheduling_with_params(_alpha_param, _beta_param)
        # Save the best searched noise schedules ({N}steps_PESQ{pesq}_STOI{stoi}.ns)
for steps in sorted(self.steps2score.keys(), reverse=True):
filepath = os.path.join(noise_schedule_dir, '%dsteps_PESQ%.2f_STOI%.3f.ns'%(
steps, self.steps2score[steps][1], self.steps2score[steps][2]))
torch.save(self.steps2schedule[steps], filepath)
log("Saved searched schedule: %s" % filepath, self.config)
def assess(self, generated_audio, audio_key=None):
"""
Assess the generated audio using objective metrics: PESQ and STOI.
P.S. Users should first install pypesq and pystoi using pip
Parameters:
generated_audio (tensor): the generated audio to be assessed
audio_key (str): the key of the respective audio
Returns:
pesq_score (float): the PESQ score (the higher the better)
stoi_score (float): the STOI score (the higher the better)
"""
est_audio = generated_audio.squeeze().cpu().numpy()
ref_audio = self.ref_audio.squeeze().cpu().numpy()
if est_audio.shape[-1] > ref_audio.shape[-1]:
est_audio = est_audio[..., :ref_audio.shape[-1]]
else:
ref_audio = ref_audio[..., :est_audio.shape[-1]]
# Compute STOI using PySTOI
# PySTOI (https://github.com/mpariente/pystoi)
stoi_score = stoi(ref_audio, est_audio, self.config.sampling_rate, extended=False)
        # Resample audio to 16 kHz to compute PESQ using PyPESQ (supports only 8 kHz / 16 kHz)
# PyPESQ (https://github.com/vBaiCai/python-pesq)
if self.config.sampling_rate != 16000:
est_audio_16k = librosa.resample(est_audio, self.config.sampling_rate, 16000)
ref_audio_16k = librosa.resample(ref_audio, self.config.sampling_rate, 16000)
pesq_score = pesq(ref_audio_16k, est_audio_16k, 16000)
else:
pesq_score = pesq(ref_audio, est_audio, 16000)
# Log scores
log('\t%sScores: PESQ = %.3f, STOI = %.5f'%(
'' if audio_key is None else audio_key+' ', pesq_score, stoi_score), self.config)
# Return scores: the higher the better
return [pesq_score, stoi_score]
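The beta-to-schedule conversion at the end of `noise_scheduling` (building cumulative alphas and posterior standard deviations from the searched betas) can be checked in isolation. A pure-Python sketch of those recurrences, using hand-picked betas rather than schedule-network outputs:

```python
import math

def schedule_params(betas):
    """Convert per-step betas into cumulative alpha_hat and sigma values,
    mirroring the recurrences at the end of Sampler.noise_scheduling."""
    a = [1.0 - b for b in betas]
    s = list(betas)
    for n in range(1, len(betas)):
        a[n] *= a[n - 1]                     # cumulative product of (1 - beta_n)
        s[n] *= (1 - a[n - 1]) / (1 - a[n])  # posterior variance (beta_tilde)
    return [math.sqrt(v) for v in a], [math.sqrt(v) for v in s]

# Example: a hand-picked 3-step schedule
a_infer, s_infer = schedule_params([0.1, 0.2, 0.3])
```

The returned `a_infer` values are the alpha_hat terms that `map_noise_scale_to_time_step` later maps back to training-time step indices.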
================================================
FILE: bddm/trainer/__init__.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Trainer
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
from .trainer import Trainer
================================================
FILE: bddm/trainer/bmuf.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# BMUF Multi-GPU Training Method
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
import os
import sys
import psutil
import torch
import torch.nn as nn
import torch.distributed as dist
class BmufTrainer(object):
def __init__(self, master_node, rank, world_size, model, block_momentum, block_lr):
"""
Basic BMUF Trainer Class, implements Nesterov Block Momentum
Parameters:
master_node (int): master node index, zero in most cases
rank (int): local rank, eg, 0-7 if 8GPUs are used
world_size (int): total number of workers
model (nn.Module): PyTorch model
block_momentum (float): block momentum value
block_lr (float): block learning rate
"""
assert torch.cuda.is_available(), "Distributed mode requires CUDA."
self.master_node = 0 # By default, use device 0 as master node
self.model = model
self.rank = rank
self.world_size = world_size
self.block_momentum = block_momentum
self.block_lr = block_lr
dist.init_process_group(backend="nccl", init_method="env://")
param_vec = nn.utils.parameters_to_vector(model.parameters())
self.param = param_vec.data.clone()
dist.broadcast(tensor=self.param, src=self.master_node, async_op=False)
num_param = self.param.numel()
        if self.rank == self.master_node:
            # Only the master node accumulates the momentum-smoothed delta
            self.delta_prev = torch.zeros(num_param).cuda(self.master_node)
        else:
            self.delta_prev = None
self._copy_vec_to_param(self.param)
        # for finding the best model among different ranks
self.min_val_loss = torch.FloatTensor([float('inf')]).to(self.param.device)
def check_all_processes_running(self):
"""
Check whether all processes are healthy
"""
pid = os.getpid()
for p in range(pid-self.rank, pid-self.rank+self.world_size):
if not psutil.pid_exists(p):
sys.exit(-1)
def get_average_valid_loss(self, val_loss):
"""
        Get average validation loss through all-reduce operations
        Parameters:
            val_loss (tensor): calculated validation loss in each process
        Returns:
            save_or_not (bool): whether to save the current validated model or not
"""
save_or_not = False
dist.all_reduce(tensor=val_loss)
val_loss = val_loss / float(self.world_size)
# save the min valid loss among all gpus
if 'min_val_loss' not in self.__dict__ or val_loss < self.min_val_loss:
self.min_val_loss = val_loss.clone()
save_or_not = True
return save_or_not
def update_and_sync(self, val_loss=None):
"""
Update and synchronize block gradients
Parameters:
val_loss (tensor): calculated validation loss in each process
Returns:
            save_or_not (bool): whether to save the current validated model or not
"""
self.check_all_processes_running()
save_or_not = False
if val_loss is not None:
save_or_not = self.get_average_valid_loss(val_loss)
cur_param_vec = nn.utils.parameters_to_vector(self.model.parameters()).data
delta = self.param - cur_param_vec
# Gather block gradients into delta
dist.reduce(tensor=delta, dst=self.master_node)
# Check if model params are still healthy
if torch.isnan(delta).sum().item():
print('Found nan, exit!', flush=True)
sys.exit(-1)
if self.rank == self.master_node:
# Local rank is master node
delta = delta / float(self.world_size)
self.delta_prev = self.block_momentum * self.delta_prev + \
(self.block_lr * (1 - self.block_momentum) * delta)
self.param -= (1+self.block_momentum) * self.delta_prev
dist.broadcast(tensor=self.param, src=self.master_node, async_op=False)
self._copy_vec_to_param(self.param)
return save_or_not
def _copy_vec_to_param(self, vec):
"""
Copy a vectorized array to the model parameters
Parameters:
vec (tensor): a single vector represents the parameters of a model.
"""
# Ensure vec of type Tensor
if not isinstance(vec, torch.Tensor):
raise TypeError('expected torch.Tensor, but got: {}'
.format(torch.typename(vec)))
# Pointer for slicing the vector for each parameter
pointer = 0
for param in self.model.parameters():
# The length of the parameter
num_param = param.numel()
# Slice the vector, reshape it, and replace the old data of the parameter
param.data = param.data.copy_(vec[pointer:pointer + num_param]
.view_as(param).data)
# Increment the pointer
pointer += num_param
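The master-node update in `update_and_sync` is Nesterov block momentum: average the workers' parameter deltas, smooth them with the block momentum, and step with a (1 + momentum) overshoot. A single-process scalar sketch (the momentum, learning rate, and parameter values here are hypothetical):

```python
def bmuf_step(global_param, worker_params, delta_prev,
              block_momentum=0.9, block_lr=1.0):
    """One BMUF update on a scalar parameter.
    worker_params holds each worker's locally trained value."""
    # Block gradient: mean distance from the global param to local results
    delta = sum(global_param - w for w in worker_params) / len(worker_params)
    # Momentum-smoothed update (mirrors delta_prev in BmufTrainer)
    delta_prev = block_momentum * delta_prev + block_lr * (1 - block_momentum) * delta
    # Nesterov-style overshoot by (1 + block_momentum)
    global_param -= (1 + block_momentum) * delta_prev
    return global_param, delta_prev

param, delta_prev = bmuf_step(1.0, [0.8, 0.6], delta_prev=0.0)
```

In the real trainer `delta` is the all-reduced parameter vector difference, and the updated `self.param` is broadcast back to every rank.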
================================================
FILE: bddm/trainer/ema.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# EMA Helper Class
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
import torch.nn as nn
class EMAHelper(object):
def __init__(self, mu=0.999):
"""
Exponential Moving Average Training Helper Class
Parameters:
            mu (float): decay rate
"""
self.mu = mu
self.shadow = {}
def register(self, module):
"""
Register module by copying all learnable parameters to self.shadow
Parameters:
module (nn.Module): model to be trained
"""
if isinstance(module, nn.DataParallel):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self, module):
"""
Update self.shadow using the module parameters
Parameters:
module (nn.Module): model in training
"""
if isinstance(module, nn.DataParallel):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
self.shadow[name].data = (
1. - self.mu) * param.data + self.mu * self.shadow[name].data
def ema(self, module):
"""
Copy self.shadow to the module parameters
Parameters:
module (nn.Module): model in training
"""
if isinstance(module, nn.DataParallel):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
param.data.copy_(self.shadow[name].data)
def ema_copy(self, module):
"""
Initialize a new module using self.shadow as the parameters
Parameters:
module (nn.Module): model in training
"""
if isinstance(module, nn.DataParallel):
inner_module = module.module
module_copy = type(inner_module)(
inner_module.config).to(inner_module.config.device)
module_copy.load_state_dict(inner_module.state_dict())
module_copy = nn.DataParallel(module_copy)
else:
module_copy = type(module)(module.config).to(module.config.device)
module_copy.load_state_dict(module.state_dict())
self.ema(module_copy)
return module_copy
def state_dict(self):
"""
Get self.shadow as the state dict
Returns:
shadow (dict): state dict
"""
return self.shadow
def load_state_dict(self, state_dict):
"""
Load a state dict to self.shadow
Parameters:
            state_dict (dict): state dict to be copied to self.shadow
"""
self.shadow = state_dict
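The shadow update in `EMAHelper.update` is the standard exponential moving average, shadow = (1 - mu) * param + mu * shadow. A minimal dict-based sketch (the parameter names and values are hypothetical):

```python
def ema_update(shadow, params, mu=0.999):
    """Blend the current parameters into the EMA shadow copy."""
    for name, value in params.items():
        shadow[name] = (1.0 - mu) * value + mu * shadow[name]
    return shadow

# With mu = 0.5 the shadow converges quickly toward a constant parameter
shadow = {"w": 0.0}
for _ in range(3):
    ema_update(shadow, {"w": 1.0}, mu=0.5)
```

With the default mu = 0.999 used by the trainer, the shadow tracks a long-horizon average of the weights, which is what gets serialized and later sampled from.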
================================================
FILE: bddm/trainer/loss.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Implements Training Losses for BDDMs
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
import torch
import torch.nn as nn
class ScoreLoss(nn.Module):
def __init__(self, config, diff_params):
"""
Score Loss Class, implements DDPM's simplified loss (see Eq. 5 in BDDM's paper)
Parameters:
config (namespace): BDDM Configuration
diff_params (dict): Dictionary that stores pre-computed diffusion parameters
"""
super().__init__()
self.config = config
self.num_steps = diff_params["T"]
self.alpha = diff_params["alpha"]
def forward(self, model, mels, audios):
"""
Compute the training loss for learning theta
Parameters:
model (nn.Module): the score network
mels (tensor): shape=(batch size, frames, spectrogram dim)
audios (tensor): shape=(batch size, 1, length of audio)
Returns:
score loss (tensor): shape=(batch size,)
"""
batch_size = audios.size(0)
ts = torch.randint(low=0, high=self.num_steps, size=(batch_size, 1, 1)).cuda()
noise_scales = self.alpha[ts]
z = torch.normal(0, 1, size=audios.shape).cuda()
noisy_audios = noise_scales * audios + (1 - noise_scales**2.).sqrt() * z
e = model((noisy_audios, mels, ts.view(batch_size, 1),))
# Use WaveGrad's L1Loss for speech generation
theta_loss = nn.L1Loss(reduction='none')(e, z[:, :, :e.size(-1)]).mean([1, 2])
return theta_loss
class StepLoss(nn.Module):
def __init__(self, config, diff_params):
"""
Step Loss Class, implements BDDM's step loss (see Eq. 14 in BDDM's paper)
Parameters:
config (namespace): BDDM Configuration
diff_params (dict): Dictionary that stores pre-computed diffusion parameters
"""
super().__init__()
self.config = config
self.num_steps = diff_params["T"]
self.alpha = diff_params["alpha"]
self.tau = diff_params["tau"]
def forward(self, model, mels, audios):
"""
Compute the training loss for learning phi
Parameters:
model (nn.Module): the score network & the schedule network
mels (tensor): shape=(batch size, frames, spectrogram dim)
audios (tensor): shape=(batch size, 1, length of audio)
Returns:
step loss (tensor): shape=(batch size,)
"""
batch_size = audios.size(0)
ts = torch.randint(self.tau, self.num_steps-self.tau, size=(batch_size,)).cuda()
alpha_cur = self.alpha.index_select(0, ts).view(batch_size, 1, 1)
alpha_nxt = self.alpha.index_select(0, ts+self.tau).view(batch_size, 1, 1)
b_nxt = 1 - (alpha_nxt / alpha_cur)**2.
delta = (1 - alpha_cur**2.).sqrt()
z = torch.normal(0, 1, size=audios.shape).cuda()
noisy_audios = alpha_cur * audios + delta * z
e = model((noisy_audios, mels, ts.view(batch_size, 1),))
beta_bounds = (b_nxt.view(batch_size, 1), delta.view(batch_size, 1)**2.)
b_hat = model.schedule_net(noisy_audios.squeeze(1), beta_bounds)
delta, b_hat, z, e = delta.squeeze(1), b_hat.squeeze(1), z.squeeze(1), e.squeeze(1)
phi_loss = delta**2. / (2. * (delta**2. - b_hat))
phi_loss = phi_loss * (z - b_hat / (delta**2.) * e).square()
phi_loss = phi_loss + torch.log(1e-8 + delta**2. / (b_hat + 1e-8)) / 4.
phi_loss = phi_loss.sum(-1) + (b_hat / delta**2 - 1) / 2. * audios.size(-1)
return phi_loss
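ScoreLoss boils down to two pieces: the forward-diffusion corruption alpha_t * x0 + sqrt(1 - alpha_t^2) * z, and an L1 penalty between the network output and the injected noise z. A scalar sketch with hypothetical values (the real loss operates on batched waveforms):

```python
import math

def diffuse(x0, alpha_t, z):
    """Forward diffusion q(x_t | x_0) at cumulative noise scale alpha_t."""
    return alpha_t * x0 + math.sqrt(1.0 - alpha_t ** 2) * z

def score_l1(predicted_noise, true_noise):
    """WaveGrad-style L1 loss between predicted and injected noise."""
    return abs(predicted_noise - true_noise)

# Corrupt a clean sample x0 = 0.5 at alpha_t = 0.9 with noise z = 1.2
x_t = diffuse(0.5, 0.9, z=1.2)
loss = score_l1(0.0, 1.2)  # a (bad) model that predicts zero noise
```

As alpha_t approaches 1 the sample is nearly clean; near 0 it is nearly pure noise, which is why the time step must be fed to the score network alongside `x_t`.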
================================================
FILE: bddm/trainer/trainer.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# BDDM Trainer (Support Multi-GPU Training using BMUF method)
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
from __future__ import absolute_import
import os
import time
import copy
import torch
from bddm.sampler import Sampler
from bddm.trainer.ema import EMAHelper
from bddm.trainer.bmuf import BmufTrainer
from bddm.trainer.loss import ScoreLoss, StepLoss
from bddm.utils.log_utils import log
from bddm.utils.check_utils import check_score_network
from bddm.utils.diffusion_utils import compute_diffusion_params
from bddm.models import get_score_network, get_schedule_network
from bddm.loader.dataset import create_train_and_valid_dataloader
class Trainer(object):
def __init__(self, config):
"""
Trainer Class, implements a general multi-GPU training framework in PyTorch
Parameters:
config (namespace): BDDM Configuration
"""
self.config = config
self.exp_dir = config.exp_dir
self.clip = config.grad_clip
self.load = config.load
self.model = get_score_network(config).cuda()
# Define training target
if self.config.resume_training and self.load != '':
self.training_target = 'score_nets'
else:
score_net_trained, score_net_path = check_score_network(config)
self.training_target = 'schedule_nets' if score_net_trained else 'score_nets'
torch.autograd.set_detect_anomaly(True)
# Initialize diffusion parameters using a pre-specified linear schedule
noise_schedule = torch.linspace(config.beta_0, config.beta_T, config.T).cuda()
self.diff_params = compute_diffusion_params(noise_schedule)
if self.training_target == 'schedule_nets':
if self.load == '':
self.load = score_net_path
self.diff_params["tau"] = config.tau
for p in self.model.parameters():
p.requires_grad = False
# Define the schedule net as a sub-module of the score net for convenience
self.model.schedule_net = get_schedule_network(config).cuda()
self.loss_func = StepLoss(config, self.diff_params)
self.n_training_steps = config.schedule_net_training_steps
# In practice using batch size = 1 would lead to much lower step loss
config.batch_size = 1
model_to_train = self.model.schedule_net
else:
self.loss_func = ScoreLoss(config, self.diff_params)
self.n_training_steps = config.score_net_training_steps
model_to_train = self.model
# Define optimizer
self.optimizer = torch.optim.AdamW(model_to_train.parameters(),
lr=config.lr, weight_decay=config.weight_decay, amsgrad=True)
# Define EMA training helper
self.ema_helper = EMAHelper(mu=config.ema_rate)
self.ema_helper.register(model_to_train)
# Initialize BMUF trainer
self.world_size = int(os.environ['WORLD_SIZE'])
self.bmuf_trainer = BmufTrainer(0, config.local_rank, self.world_size,
model_to_train, config.bmuf_block_momentum, config.bmuf_block_lr)
self.sync_period = config.bmuf_sync_period
self.device = torch.device("cuda:{}".format(config.local_rank))
self.local_rank = config.local_rank
# Get data loaders
self.tr_loader, self.vl_loader = create_train_and_valid_dataloader(config)
if self.training_target == 'score_nets':
# Define a Sampler for quality validation (should be added after BMUF)
self.valid_sampler = Sampler(config)
self.valid_sampler.model = get_score_network(config).cuda()
self.reset()
def reset(self):
"""
Reset training environment
"""
self.tr_loss, self.vl_loss = [], []
self.training_step = 0
if self.load != '':
package = torch.load(self.load, map_location=lambda storage, loc: storage.cuda())
init_state_dict = self.model.state_dict()
mismatch_params = set()
# Remove the checkpoint params that are not found in model
for key in list(package['model_state_dict'].keys()):
if key not in init_state_dict.keys():
param = copy.deepcopy(package['model_state_dict'][key])
del package['model_state_dict'][key]
log('ignored: %s in checkpoint not found in model'%key, self.config)
elif package['model_state_dict'][key].size() != init_state_dict[key].size():
log(package['model_state_dict'][key].size(), self.config)
log(init_state_dict[key].size(), self.config)
log('ignored: %s in checkpoint size mismatched'%key, self.config)
del package['model_state_dict'][key]
# Replace the ignored checkpoint params by the init params
for key in list(init_state_dict.keys()):
if key not in package['model_state_dict'].keys():
mismatch_params.add(key)
log('ignored: %s in model not found in checkpoint'%key, self.config)
package['model_state_dict'][key] = init_state_dict[key]
self.model.load_state_dict(package['model_state_dict'])
if self.config.resume_training and len(mismatch_params) == 0:
# Load steps to resume training
if self.training_target == 'score_nets' and 'score_net_training_step' in package:
self.training_step = package['score_net_training_step']
elif self.training_target == 'schedule_nets' and 'schedule_net_training_step' in package:
self.training_step = package['schedule_net_training_step']
if self.config.freeze_checkpoint_params and len(mismatch_params) > 0:
# Only update new parameters defined in model
for key, param in self.model.named_parameters():
if key not in mismatch_params:
param.requires_grad = False
log('Loaded checkpoint %s' % self.load, self.config)
# Create save folder
os.makedirs(self.exp_dir, exist_ok=True)
self.prev_val_loss, self.min_val_loss = float("inf"), float("inf")
self.val_no_impv, self.halving = 0, 0
def train(self):
"""
Start the main training process
"""
best_state = copy.deepcopy(self.model.state_dict())
while self.training_step < self.n_training_steps:
# Train one epoch
log("Start training %s from step %d ......."%(
self.training_target, self.training_step), self.config)
self.bmuf_trainer.check_all_processes_running()
self.model.train()
start = time.time()
tr_avg_loss = self._run_one_epoch(validate=False)
log('-' * 85, self.config)
log('Train Summary | Step {} | Time {:.2f}s | Train Loss {:.5f}'.format(
self.training_step, time.time()-start, tr_avg_loss), self.config)
log('-' * 85, self.config)
# Start validation
log('Start validation ......', self.config)
self.model.eval()
start = time.time()
with torch.no_grad():
val_loss = self._run_one_epoch(validate=True)
log('-' * 85, self.config)
log('Valid Summary | Step {} | Time {:.2f}s | Valid Loss {:.5f}'.format(
self.training_step, time.time()-start, val_loss.item()), self.config)
log('-' * 85, self.config)
save_or_not = self.bmuf_trainer.update_and_sync(val_loss=val_loss)
if save_or_not:
self.val_no_impv = 0
if self.bmuf_trainer.rank == self.bmuf_trainer.master_node:
model_serialized = self.serialize()
file_path = os.path.join(self.exp_dir, self.training_target,
'%d.pkl' % self.training_step)
torch.save(model_serialized, file_path)
log("Found better model, saved to %s" % file_path, self.config)
if val_loss >= self.min_val_loss:
# LR decays
self.val_no_impv += 1
if self.val_no_impv == self.config.patience:
                    log("No improvement for %d epochs, early stopped!"%(
                        self.config.patience), self.config)
                    # Peer ranks exit via check_all_processes_running()
                    # once this process terminates
                    break
if self.val_no_impv >= self.config.patience // 2:
self.model.load_state_dict(best_state)
else:
self.val_no_impv = 0
self.min_val_loss = val_loss
best_state = copy.deepcopy(self.ema_helper.state_dict())
def _run_one_epoch(self, validate=False):
"""
Run one epoch
Parameters:
            validate (bool): whether to run a validation epoch or a training epoch
Returns:
average loss (float): the average training/validation loss
"""
start = time.time()
total_loss, total_cnt = 0, 0
if validate and self.training_target == 'score_nets':
# Use EMA state dict for validation
self.valid_sampler.model.load_state_dict(self.ema_helper.state_dict())
# To validate score nets, we use Sampler to test sample quality instead of loss
generated_audio, _ = self.valid_sampler.sampling()
# Compute objective scores (score = PESQ)
quality_score = self.valid_sampler.assess(generated_audio)[0]
return - torch.FloatTensor([quality_score])[0].cuda()
data_loader = self.vl_loader if validate else self.tr_loader
data_loader.dataset.reset()
start_step = self.training_step
for i, batch in enumerate(data_loader):
mels, audios = list(map(lambda x: x.cuda(), batch))
loss = self.loss_func(self.model, mels, audios)
total_loss += loss.detach().sum()
total_cnt += len(loss)
avg_loss = loss.mean()
if not validate:
self.optimizer.zero_grad()
avg_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
self.optimizer.step()
# Apply block momentum and sync parameters
self.bmuf_trainer.update_and_sync()
self.bmuf_trainer.check_all_processes_running()
# n_gpus * batch_size
self.training_step += len(loss) * self.bmuf_trainer.world_size
if i % self.config.log_period == 0:
log('Train Step {} | Avg. Loss {:.5f} | New Loss {:.5f} | {:.2f}s/step'.format(
self.training_step, total_loss / total_cnt, avg_loss,
(time.time() - start) / (self.training_step - start_step)), self.config)
if self.training_target == 'schedule_nets':
self.ema_helper.update(self.model.schedule_net)
else:
self.ema_helper.update(self.model)
if self.training_step >= self.n_training_steps or\
max(i, self.training_step - start_step) >= self.config.steps_per_epoch:
# Release grad memory
self.optimizer.zero_grad()
return total_loss / total_cnt
else:
if i % self.config.log_period == 0:
log('Valid Step {} | Avg. Loss {:.5f} | New Loss {:.5f} | {:.2f}s/step'.format(
i + 1, total_loss/total_cnt, avg_loss,
(time.time() - start) / (i + 1)), self.config)
if not validate:
# Release grad memory
self.optimizer.zero_grad()
torch.cuda.empty_cache()
return total_loss / total_cnt
def serialize(self):
"""
Pack the model and configurations into a dictionary
Returns:
package (dict): the serialized package to be saved
"""
if self.training_target == 'schedule_nets':
model_state = copy.deepcopy(self.model.state_dict())
ema_state = copy.deepcopy(self.ema_helper.state_dict())
for p in self.ema_helper.state_dict():
model_state['schedule_net.'+p] = ema_state[p]
else:
model_state = copy.deepcopy(self.ema_helper.state_dict())
if self.config.save_fp16:
for p in model_state:
model_state[p] = model_state[p].half()
package = {
# hyper-parameter
'config': self.config,
# state
'model_state_dict': model_state
}
if self.training_target == 'score_nets':
package['score_net_training_step'] = self.training_step
package['schedule_net_training_step'] = 0
else:
package['score_net_training_step'] = self.config.score_net_training_steps
package['schedule_net_training_step'] = self.training_step
return package
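The key remapping that serialize() performs when packaging the schedule network can be sketched in isolation. This is a torch-free illustration; `merge_ema_into_state` and its argument names are not part of the codebase:

```python
# Illustrative sketch of the key renaming in serialize(): EMA parameters of
# the schedule net are stored under a 'schedule_net.' prefix so the packaged
# state dict can be loaded back into the full BDDM model.
def merge_ema_into_state(model_state, ema_state, prefix='schedule_net.'):
    merged = dict(model_state)  # keep the score-net weights as-is
    for name, value in ema_state.items():
        merged[prefix + name] = value  # remap EMA keys into the full namespace
    return merged
```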
================================================
FILE: bddm/utils/check_utils.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Check Utils: Find Checkpoints
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
import os
import glob
import torch
def check_score_network(config):
"""
Check if the score network is trained by searching the ${exp_dir}/score_nets
Parameters:
config (namespace): the configuration given by the user
Returns:
result (bool): a boolean to determine if the score network is trained
path (str): the path to the score network checkpoint if trained
"""
if config.load != '':
ckpt = torch.load(config.load)
if 'score_net_training_step' in ckpt.keys():
if ckpt['score_net_training_step'] >= config.score_net_training_steps:
return True, config.load
else:
# We suppose that an external checkpoint is already well-trained
return True, config.load
max_training_steps = 0
for path in glob.glob(os.path.join(config.exp_dir, 'score_nets', '[0-9]*.pkl')):
n_training_steps = int(path.split('/')[-1].split('.')[0])
max_training_steps = max(max_training_steps, n_training_steps)
if max_training_steps == 0:
return False, None
load_path = os.path.join(config.exp_dir, 'score_nets', '%d.pkl'%(max_training_steps))
return True, load_path
def check_schedule_network(config):
"""
Check if the schedule network is trained by searching the ${exp_dir}/schedule_nets
Parameters:
config (namespace): the configuration given by the user
Returns:
result (bool): a boolean to determine if the schedule network is trained
path (str): the path to the BDDM checkpoint if trained
"""
if config.load != '':
ckpt = torch.load(config.load)
if 'schedule_net_training_step' in ckpt.keys():
return True, config.load
else:
return False, config.load
max_training_steps = 0
for path in glob.glob(os.path.join(config.exp_dir, 'schedule_nets', '[0-9]*.pkl')):
n_training_steps = int(path.split('/')[-1].split('.')[0])
max_training_steps = max(max_training_steps, n_training_steps)
if max_training_steps == 0:
# We suppose that an external checkpoint only contains a well-trained score network
return False, config.load
schedule_net_load_path = os.path.join(
config.exp_dir, 'schedule_nets', '%d.pkl'%(max_training_steps))
return True, schedule_net_load_path
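Both helpers above resolve the newest checkpoint by parsing the training-step count out of the `.pkl` filename and taking the maximum. A torch-free sketch of that selection logic (`latest_checkpoint` and the example paths in the test are hypothetical):

```python
# Pick the checkpoint whose filename encodes the largest training-step count,
# mirroring the glob + max() logic in check_score_network/check_schedule_network.
def latest_checkpoint(paths):
    best_steps, best_path = 0, None
    for path in paths:
        steps = int(path.split('/')[-1].split('.')[0])  # e.g. '5000.pkl' -> 5000
        if steps > best_steps:
            best_steps, best_path = steps, path
    return best_path  # None if no positive-step checkpoint exists
```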
================================================
FILE: bddm/utils/diffusion_utils.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Diffusion Utils: Pre-compute Variables
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################
import torch
def compute_diffusion_params(beta):
"""
Compute the diffusion parameters defined in BDDMs
Parameters:
beta (tensor): the beta schedule
Returns:
diff_params (dict): a dictionary of diffusion hyperparameters including:
T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, ))
These cpu tensors are changed to cuda tensors on each individual gpu
"""
alpha = 1 - beta
sigma = beta.clone()  # copy beta; sigma is updated in-place below
for t in range(1, len(beta)):
alpha[t] *= alpha[t-1]
sigma[t] *= (1-alpha[t-1]) / (1-alpha[t])
alpha = torch.sqrt(alpha)
sigma = torch.sqrt(sigma)
diff_params = {"T": len(beta), "beta": beta, "alpha": alpha, "sigma": sigma}
return diff_params
def map_noise_scale_to_time_step(alpha_infer, alpha):
"""
Map an alpha_infer to an approx. time step in the alpha tensor.
(Modified from Algorithm 3 Fast Sampling in DiffWave)
Parameters:
alpha_infer (float): noise scale at time `n` for inference
alpha (tensor): noise scales used in training, shape=(T, )
Returns:
t (float): approximated time step in alpha tensor
"""
if alpha_infer < alpha[-1]:
return len(alpha) - 1
if alpha_infer > alpha[0]:
return 0
for t in range(len(alpha) - 1):
if alpha[t+1] <= alpha_infer <= alpha[t]:
step_diff = alpha[t] - alpha_infer
step_diff /= alpha[t] - alpha[t+1]
return t + step_diff.item()
return -1
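The two utilities above can be exercised end to end with the linear beta schedule implied by the `T`/`beta_0`/`beta_T` values in conf.yml. Below is a torch-free sketch of the same recurrences; `linear_beta`, `diffusion_params`, and `map_to_time_step` are illustrative names, and the linear schedule is an assumption taken from the config values rather than from this file:

```python
import math

def linear_beta(T=200, beta_0=0.0001, beta_T=0.02):
    # Linear schedule matching the T/beta_0/beta_T values in conf.yml
    return [beta_0 + (beta_T - beta_0) * t / (T - 1) for t in range(T)]

def diffusion_params(beta):
    # Mirrors compute_diffusion_params(): cumulative alpha and sigma recurrences
    alpha = [1.0 - b for b in beta]
    sigma = list(beta)
    for t in range(1, len(beta)):
        alpha[t] *= alpha[t - 1]                           # prod_{s<=t} (1 - beta_s)
        sigma[t] *= (1.0 - alpha[t - 1]) / (1.0 - alpha[t])
    return {
        "T": len(beta),
        "alpha": [math.sqrt(a) for a in alpha],
        "sigma": [math.sqrt(s) for s in sigma],
    }

def map_to_time_step(alpha_infer, alpha):
    # Mirrors map_noise_scale_to_time_step(): linear interpolation between
    # the two training noise scales that bracket alpha_infer
    if alpha_infer < alpha[-1]:
        return len(alpha) - 1
    if alpha_infer > alpha[0]:
        return 0
    for t in range(len(alpha) - 1):
        if alpha[t + 1] <= alpha_infer <= alpha[t]:
            return t + (alpha[t] - alpha_infer) / (alpha[t] - alpha[t + 1])
    return -1
```

Because the cumulative product of values below one is strictly decreasing, every inference-time noise scale inside the training range maps to a unique fractional time step.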
================================================
FILE: bddm/utils/log_utils.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Log Utils: Print Log with Time/PID
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################
import os
import sys
import time
def ctime():
"""
Get time now
Returns:
time_string (str): current time in string
"""
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
def head():
"""
Get header for logging: time and PID of the current process
Returns:
header_string (str): the header string for logging
"""
return "%s %d" % (ctime(), os.getpid())
def log(msg, config):
"""
Save log to device[local_rank].log; the rank-0 process also prints to STDOUT
Parameters:
msg (string): message to be logged
config (namespace): BDDM Configuration
"""
if config.local_rank == 0:
sys.stdout.write("[%s] %s\n" % (head(), msg))
sys.stdout.flush()
os.makedirs(config.exp_dir, exist_ok=True)
with open(os.path.join(config.exp_dir, 'device%d.log'%(config.local_rank)), 'a') as f:
f.write("[%s] %s\n" % (head(), msg))
f.flush()
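Each log line produced above has the shape `[<time> <pid>] <msg>`. A minimal sketch of that formatting (`format_log_line` is an illustrative name, not part of the codebase):

```python
import os
import time

# Reproduce the "[<Y-m-d H:M:S> <pid>] <msg>" layout built by ctime()/head()/log().
def format_log_line(msg):
    stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
    return "[%s %d] %s" % (stamp, os.getpid(), msg)
```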
================================================
FILE: egs/lj/DiffWave.pkl
================================================
[File too large to display: 26.6 MB]
================================================
FILE: egs/lj/DiffWave_GALR.pkl
================================================
[File too large to display: 15.4 MB]
================================================
FILE: egs/lj/conf.yml
================================================
# General config
exp_dir: './exp'
seed: 0
load: 'DiffWave_GALR.pkl'
# Generation config
sampling_noise_schedule: './noise_schedules/12steps_PESQ4.07_STOI0.978.ns'
gen_data_dir: "./demo_case"
#gen_data_dir: "./test_wavs"
use_ddim_steps: 0 # 0: DDPM sampling; 1: DDIM; M>1: DDIM with M steps (used when no noise schedule is given)
# Diffusion config
T: 200
beta_0: 0.0001
beta_T: 0.02
tau: 66
# Noise scheduling config
N: 12
bddm_search_bins: 9
noise_scheduling_attempts: 10
# Score network config
score_net: 'DiffWave'
num_res_layers: 30
in_channels: 1
res_channels: 128
skip_channels: 128
dilation_cycle: 10
out_channels: 1
diffusion_step_embed_dim_in: 128
diffusion_step_embed_dim_mid: 512
diffusion_step_embed_dim_out: 512
# Schedule network config
schedule_net: 'GALR'
blocks: 1
hidden_dim: 128
input_dim: 128
window_length: 8
segment_size: 64
# Trainer config
resume_training: False
save_fp16: True
steps_per_epoch: 1000
score_net_training_steps: 5000000
schedule_net_training_steps: 10000
lr: 0.00001
freeze_checkpoint_params: False # only used when provided load path
save_generated_dir: 'gen_wav'
batch_size: 1
patience: 100
grad_clip: 5.0
weight_decay: 0.00001
log_period: 1
bmuf_block_momentum: 0.5
bmuf_block_lr: 0.7
bmuf_sync_period: 2
ema_rate: 0.999
# Data config
train_data_dir: "LJSpeech-1.1/train_wavs"
valid_data_dir: "LJSpeech-1.1/val_wavs"
n_worker: 10
sampling_rate: 22050
seg_len: 16000
fil_len: 1024
hop_len: 256
win_len: 1024
mel_fmin: 0.0
mel_fmax: 8000.0
================================================
FILE: egs/lj/eval.list
================================================
LJ001-0001.wav
LJ001-0047.wav
LJ003-0140.wav
LJ003-0145.wav
LJ003-0235.wav
LJ003-0240.wav
LJ004-0195.wav
LJ005-0059.wav
LJ005-0206.wav
LJ005-0269.wav
LJ006-0141.wav
LJ006-0301.wav
LJ007-0104.wav
LJ007-0106.wav
LJ007-0192.wav
LJ007-0208.wav
LJ008-0043.wav
LJ008-0265.wav
LJ009-0071.wav
LJ009-0190.wav
LJ010-0056.wav
LJ010-0082.wav
LJ011-0053.wav
LJ011-0062.wav
LJ011-0221.wav
LJ012-0005.wav
LJ012-0015.wav
LJ013-0004.wav
LJ013-0060.wav
LJ015-0063.wav
LJ015-0254.wav
LJ016-0070.wav
LJ016-0318.wav
LJ017-0049.wav
LJ017-0251.wav
LJ018-0029.wav
LJ018-0071.wav
LJ018-0105.wav
LJ019-0329.wav
LJ019-0350.wav
LJ019-0358.wav
LJ020-0086.wav
LJ020-0102.wav
LJ022-0164.wav
LJ022-0185.wav
LJ023-0117.wav
LJ024-0103.wav
LJ025-0143.wav
LJ026-0030.wav
LJ027-0020.wav
LJ028-0130.wav
LJ028-0155.wav
LJ028-0302.wav
LJ028-0432.wav
LJ028-0465.wav
LJ028-0497.wav
LJ028-0499.wav
LJ029-0018.wav
LJ030-0088.wav
LJ031-0072.wav
LJ032-0010.wav
LJ032-0057.wav
LJ035-0061.wav
LJ035-0124.wav
LJ035-0200.wav
LJ036-0093.wav
LJ037-0061.wav
LJ037-0186.wav
LJ037-0195.wav
LJ037-0237.wav
LJ038-0008.wav
LJ038-0051.wav
LJ038-0092.wav
LJ039-0072.wav
LJ040-0210.wav
LJ040-0216.wav
LJ041-0075.wav
LJ042-0092.wav
LJ042-0159.wav
LJ042-0181.wav
LJ042-0224.wav
LJ042-0248.wav
LJ043-0026.wav
LJ043-0150.wav
LJ045-0123.wav
LJ045-0128.wav
LJ045-0147.wav
LJ045-0216.wav
LJ046-0026.wav
LJ046-0043.wav
LJ046-0162.wav
LJ046-0186.wav
LJ047-0058.wav
LJ047-0114.wav
LJ047-0131.wav
LJ048-0022.wav
LJ049-0155.wav
LJ049-0165.wav
LJ050-0043.wav
LJ050-0134.wav
================================================
FILE: egs/lj/generate.sh
================================================
#!/bin/bash
## example usage (single GPU): sh generate.sh 0 conf.yml
cuda=$1
CUDA_VISIBLE_DEVICES=$cuda python3 main.py \
--command generate --config $2
================================================
FILE: egs/lj/main.py
================================================
import os
import sys
sys.path.append('../../')
import json
import shutil
import hashlib
import argparse
import numpy as np
import yaml
import torch
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
from bddm import trainer, sampler
from bddm.utils.log_utils import log
def dict_hash_5char(dictionary):
''' Map a unique dictionary into a 5-character string '''
dhash = hashlib.md5()
encoded = json.dumps(dictionary, sort_keys=True).encode()
dhash.update(encoded)
return dhash.hexdigest()[:5]
def start_exp(config, config_hash):
''' Create experiment directory or set it to an existing directory '''
if config.load != '' and '_nets' in config.load:
config.exp_dir = '/'.join(config.load.split('/')[:-2])
else:
config.exp_dir += '/%s-%s_conf-hash-%s' % (
config.score_net, config.schedule_net, config_hash)
if config.local_rank != 0:
return
log('Experiment directory: %s' % (config.exp_dir), config)
# Backup the config file
shutil.copyfile(config.config, os.path.join(config.exp_dir, 'conf.yml'))
# Create a backup scripts sub-folder
os.makedirs(os.path.join(config.exp_dir, 'backup_scripts'), exist_ok=True)
# Backup all .py files under bddm/
backup_files = []
for root, _, files in os.walk("../../"):
if 'egs' in root:
continue
for f in files:
if f.endswith(".py"):
backup_files.append(os.path.join(root, f))
for src_file in backup_files:
basename = src_file
while '../' in basename:
basename = basename.replace('../', '')
basename = basename.replace('./', '')
dst_file = os.path.join(config.exp_dir, 'backup_scripts', basename)
dst_dir = '/'.join(dst_file.split('/')[:-1])
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
shutil.copyfile(src_file, dst_file)
# Prepare sub-folders for saving model checkpoints
os.makedirs(os.path.join(config.exp_dir, 'score_nets'), exist_ok=True)
os.makedirs(os.path.join(config.exp_dir, 'schedule_nets'), exist_ok=True)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Bilateral Denoising Diffusion Models')
parser.add_argument('--command',
type=str,
default='train',
help='available commands: train | schedule | generate')
parser.add_argument('--config',
'-c',
type=str,
default='conf.yml',
help='config .yml path')
parser.add_argument('--local_rank',
type=int,
default=0,
help='process device ID for multi-GPU training')
arg_config = parser.parse_args()
# Parse yaml and define configurations
config = arg_config.__dict__
with open(arg_config.config) as f:
yaml_config = yaml.safe_load(f)
HASH = dict_hash_5char(yaml_config)
for key in yaml_config:
config[key] = yaml_config[key]
config = argparse.Namespace(**config)
# Set random seed for reproducible results
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
torch.cuda.set_device(config.local_rank)
# Check if the command is valid or not
commands = ['train', 'schedule', 'generate']
assert config.command in commands, 'Error: %s command not found.'%(config.command)
# Create/retrieve exp dir
start_exp(config, HASH)
log('Argv: %s' % (' '.join(sys.argv)), config)
try:
if config.command == 'train':
# Create Trainer for training
trainer = trainer.Trainer(config)
trainer.train()
elif config.command == 'schedule':
# Create Sampler for noise scheduling
sampler = sampler.Sampler(config)
sampler.noise_scheduling_without_params()
elif config.command == 'generate':
# Create Sampler for generation
# NOTE: Remember to define "gen_data_dir" in conf.yml before data generation
sampler = sampler.Sampler(config)
sampler.generate()
log('-' * 80, config)
except KeyboardInterrupt:
log('-' * 80, config)
log('Exiting early', config)
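The config hash that names the experiment directory is deterministic and key-order independent, because dict_hash_5char() serializes with `sort_keys=True` before hashing. A self-contained sketch of the same scheme (`config_hash` is an illustrative name):

```python
import hashlib
import json

# Same scheme as dict_hash_5char(): md5 of the sorted-key JSON, truncated to 5 hex chars.
def config_hash(d):
    encoded = json.dumps(d, sort_keys=True).encode()
    return hashlib.md5(encoded).hexdigest()[:5]
```

Because the keys are sorted before serialization, reordering entries in conf.yml does not change the experiment directory, while changing any value does.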
================================================
FILE: egs/lj/schedule.sh
================================================
#!/bin/bash
## example usage (single GPU): sh schedule.sh 0 conf.yml
cuda=$1
CUDA_VISIBLE_DEVICES=$cuda python3 main.py \
--command schedule --config $2
================================================
FILE: egs/lj/train.sh
================================================
#!/bin/bash
## example usage: sh train.sh 0,1,2,3 conf.yml
cuda=$1
comma=${cuda//[^,]}
nproc=$((${#comma}+1))
CUDA_VISIBLE_DEVICES=$cuda python3 -m torch.distributed.launch \
--nproc_per_node $nproc --master_port $RANDOM main.py \
--command train --config $2
================================================
FILE: egs/lj/valid.list
================================================
LJ001-0023.wav
LJ001-0084.wav
LJ001-0144.wav
LJ001-0163.wav
LJ001-0179.wav
LJ002-0042.wav
LJ002-0140.wav
LJ002-0201.wav
LJ002-0245.wav
LJ003-0004.wav
LJ003-0009.wav
LJ003-0143.wav
LJ003-0227.wav
LJ003-0305.wav
LJ003-0316.wav
LJ003-0336.wav
LJ004-0059.wav
LJ004-0098.wav
LJ004-0130.wav
LJ004-0140.wav
LJ004-0173.wav
LJ004-0199.wav
LJ005-0071.wav
LJ005-0108.wav
LJ005-0226.wav
LJ006-0053.wav
LJ006-0055.wav
LJ006-0081.wav
LJ006-0099.wav
LJ006-0103.wav
LJ006-0107.wav
LJ006-0152.wav
LJ007-0169.wav
LJ007-0232.wav
LJ008-0007.wav
LJ008-0053.wav
LJ008-0110.wav
LJ008-0200.wav
LJ008-0263.wav
LJ008-0280.wav
LJ008-0281.wav
LJ009-0068.wav
LJ009-0158.wav
LJ009-0179.wav
LJ009-0184.wav
LJ009-0185.wav
LJ009-0234.wav
LJ009-0235.wav
LJ009-0304.wav
LJ010-0011.wav
LJ010-0057.wav
LJ010-0106.wav
LJ010-0191.wav
LJ010-0282.wav
LJ010-0314.wav
LJ011-0064.wav
LJ011-0076.wav
LJ011-0130.wav
LJ011-0144.wav
LJ011-0249.wav
LJ011-0275.wav
LJ012-0046.wav
LJ012-0061.wav
LJ012-0100.wav
LJ012-0133.wav
LJ012-0145.wav
LJ012-0173.wav
LJ012-0225.wav
LJ012-0288.wav
LJ012-0290.wav
LJ013-0080.wav
LJ013-0096.wav
LJ013-0103.wav
LJ013-0137.wav
LJ013-0170.wav
LJ013-0247.wav
LJ013-0258.wav
LJ014-0063.wav
LJ014-0148.wav
LJ014-0202.wav
LJ014-0205.wav
LJ014-0278.wav
LJ014-0294.wav
LJ014-0309.wav
LJ014-0314.wav
LJ015-0136.wav
LJ015-0177.wav
LJ015-0295.wav
LJ016-0214.wav
LJ016-0237.wav
LJ016-0284.wav
LJ016-0332.wav
LJ017-0047.wav
LJ017-0050.wav
LJ017-0078.wav
LJ017-0243.wav
LJ018-0047.wav
LJ018-0055.wav
LJ018-0142.wav
LJ018-0262.wav
LJ018-0337.wav
LJ018-0361.wav
LJ019-0017.wav
LJ019-0058.wav
LJ019-0074.wav
LJ019-0080.wav
LJ019-0097.wav
LJ019-0123.wav
LJ019-0187.wav
LJ019-0253.wav
LJ019-0274.wav
LJ019-0317.wav
LJ020-0025.wav
LJ021-0132.wav
LJ021-0137.wav
LJ021-0180.wav
LJ022-0004.wav
LJ022-0133.wav
LJ022-0148.wav
LJ022-0152.wav
LJ023-0003.wav
LJ023-0033.wav
LJ024-0139.wav
LJ024-0141.wav
LJ025-0056.wav
LJ025-0104.wav
LJ025-0117.wav
LJ026-0022.wav
LJ027-0063.wav
LJ027-0075.wav
LJ028-0026.wav
LJ028-0083.wav
LJ028-0116.wav
LJ028-0142.wav
LJ028-0293.wav
LJ028-0343.wav
LJ028-0387.wav
LJ028-0488.wav
LJ028-0502.wav
LJ029-0035.wav
LJ029-0060.wav
LJ030-0013.wav
LJ030-0121.wav
LJ030-0154.wav
LJ030-0241.wav
LJ031-0083.wav
LJ031-0123.wav
LJ031-0205.wav
LJ032-0078.wav
LJ032-0119.wav
LJ032-0179.wav
LJ032-0220.wav
LJ033-0035.wav
LJ033-0066.wav
LJ033-0085.wav
LJ033-0143.wav
LJ034-0056.wav
LJ034-0073.wav
LJ034-0178.wav
LJ034-0206.wav
LJ035-0002.wav
LJ035-0079.wav
LJ036-0007.wav
LJ036-0013.wav
LJ036-0015.wav
LJ036-0067.wav
LJ036-0173.wav
LJ037-0056.wav
LJ037-0067.wav
LJ037-0075.wav
LJ037-0085.wav
LJ037-0102.wav
LJ037-0145.wav
LJ037-0163.wav
LJ038-0015.wav
LJ038-0019.wav
LJ038-0071.wav
LJ038-0083.wav
LJ038-0282.wav
LJ038-0294.wav
LJ038-0303.wav
LJ039-0076.wav
LJ039-0088.wav
LJ040-0034.wav
LJ040-0091.wav
LJ040-0123.wav
LJ040-0229.wav
LJ041-0103.wav
LJ041-0113.wav
LJ042-0037.wav
LJ042-0049.wav
LJ042-0060.wav
LJ042-0120.wav
LJ042-0127.wav
LJ042-0147.wav
LJ042-0161.wav
LJ042-0201.wav
LJ043-0010.wav
LJ043-0074.wav
LJ043-0089.wav
LJ043-0128.wav
LJ044-0076.wav
LJ044-0139.wav
LJ044-0140.wav
LJ044-0200.wav
LJ044-0209.wav
LJ045-0184.wav
LJ045-0191.wav
LJ046-0169.wav
LJ046-0203.wav
LJ047-0031.wav
LJ047-0120.wav
LJ047-0132.wav
LJ047-0153.wav
LJ047-0178.wav
LJ047-0192.wav
LJ047-0201.wav
LJ047-0238.wav
LJ048-0046.wav
LJ048-0093.wav
LJ048-0119.wav
LJ048-0139.wav
LJ048-0193.wav
LJ048-0207.wav
LJ048-0270.wav
LJ049-0044.wav
LJ049-0089.wav
LJ049-0140.wav
LJ049-0179.wav
LJ050-0012.wav
LJ050-0049.wav
LJ050-0070.wav
LJ050-0078.wav
LJ050-0136.wav
LJ050-0179.wav
LJ050-0270.wav
================================================
FILE: egs/vctk/DiffWave.pkl
================================================
[File too large to display: 13.4 MB]
================================================
FILE: egs/vctk/DiffWave_GALR.pkl
================================================
[File too large to display: 15.4 MB]
================================================
FILE: egs/vctk/conf.yml
================================================
# General config
exp_dir: './exp'
seed: 0
load: 'DiffWave_GALR.pkl'
# Generation config
sampling_noise_schedule: './noise_schedules/8steps_PESQ3.88_STOI0.987.ns'
gen_data_dir: "./demo_case"
#gen_data_dir: "./test_wavs"
use_ddim_steps: 0 # 0: DDPM sampling; 1: DDIM; M>1: DDIM with M steps (used when no noise schedule is given)
# Diffusion config
T: 200
beta_0: 0.0001
beta_T: 0.02
tau: 25
# Noise scheduling config
N: 8
bddm_search_bins: 9
noise_scheduling_attempts: 10
# Score network config
score_net: 'DiffWave'
num_res_layers: 30
in_channels: 1
res_channels: 128
skip_channels: 128
dilation_cycle: 10
out_channels: 1
diffusion_step_embed_dim_in: 128
diffusion_step_embed_dim_mid: 512
diffusion_step_embed_dim_out: 512
# Schedule network config
schedule_net: 'GALR'
blocks: 1
hidden_dim: 128
input_dim: 128
window_length: 8
segment_size: 64
# Trainer config
resume_training: False
save_fp16: True
steps_per_epoch: 1000
score_net_training_steps: 5000000
schedule_net_training_steps: 5000
lr: 0.00001
freeze_checkpoint_params: False # only used when provided load path
save_generated_dir: 'gen_wav'
batch_size: 1
patience: 100
grad_clip: 5.0
weight_decay: 0.00001
log_period: 1
bmuf_block_momentum: 0.5
bmuf_block_lr: 0.7
bmuf_sync_period: 2
ema_rate: 0.999
# Data config
train_data_dir: "VCTK-Corpus/wav"
valid_data_dir: "val_case"
n_worker: 10
sampling_rate: 22050
seg_len: 16000
fil_len: 1024
hop_len: 256
win_len: 1024
mel_fmin: 0.0
mel_fmax: 8000.0
================================================
FILE: egs/vctk/generate.sh
================================================
#!/bin/bash
## example usage (single GPU): sh generate.sh 0 conf.yml
cuda=$1
CUDA_VISIBLE_DEVICES=$cuda python3 main.py \
--command generate --config $2
================================================
FILE: egs/vctk/main.py
================================================
import os
import sys
sys.path.append('../../')
import json
import shutil
import hashlib
import argparse
import numpy as np
import yaml
import torch
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
from bddm import trainer, sampler
from bddm.utils.log_utils import log
def dict_hash_5char(dictionary):
''' Map a unique dictionary into a 5-character string '''
dhash = hashlib.md5()
encoded = json.dumps(dictionary, sort_keys=True).encode()
dhash.update(encoded)
return dhash.hexdigest()[:5]
def start_exp(config, config_hash):
''' Create experiment directory or set it to an existing directory '''
if config.load != '' and '_nets' in config.load:
config.exp_dir = '/'.join(config.load.split('/')[:-2])
else:
config.exp_dir += '/%s-%s_conf-hash-%s' % (
config.score_net, config.schedule_net, config_hash)
if config.local_rank != 0:
return
log('Experiment directory: %s' % (config.exp_dir), config)
# Backup the config file
shutil.copyfile(config.config, os.path.join(config.exp_dir, 'conf.yml'))
# Create a backup scripts sub-folder
os.makedirs(os.path.join(config.exp_dir, 'backup_scripts'), exist_ok=True)
# Backup all .py files under bddm/
backup_files = []
for root, _, files in os.walk("../../"):
if 'egs' in root:
continue
for f in files:
if f.endswith(".py"):
backup_files.append(os.path.join(root, f))
for src_file in backup_files:
basename = src_file
while '../' in basename:
basename = basename.replace('../', '')
basename = basename.replace('./', '')
dst_file = os.path.join(config.exp_dir, 'backup_scripts', basename)
dst_dir = '/'.join(dst_file.split('/')[:-1])
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
shutil.copyfile(src_file, dst_file)
# Prepare sub-folders for saving model checkpoints
os.makedirs(os.path.join(config.exp_dir, 'score_nets'), exist_ok=True)
os.makedirs(os.path.join(config.exp_dir, 'schedule_nets'), exist_ok=True)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Bilateral Denoising Diffusion Models')
parser.add_argument('--command',
type=str,
default='train',
help='available commands: train | schedule | generate')
parser.add_argument('--config',
'-c',
type=str,
default='conf.yml',
help='config .yml path')
parser.add_argument('--local_rank',
type=int,
default=0,
help='process device ID for multi-GPU training')
arg_config = parser.parse_args()
# Parse yaml and define configurations
config = arg_config.__dict__
with open(arg_config.config) as f:
yaml_config = yaml.safe_load(f)
HASH = dict_hash_5char(yaml_config)
for key in yaml_config:
config[key] = yaml_config[key]
config = argparse.Namespace(**config)
# Set random seed for reproducible results
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
torch.cuda.set_device(config.local_rank)
# Check if the command is valid or not
commands = ['train', 'schedule', 'generate']
assert config.command in commands, 'Error: %s command not found.'%(config.command)
# Create/retrieve exp dir
start_exp(config, HASH)
log('Argv: %s' % (' '.join(sys.argv)), config)
try:
if config.command == 'train':
# Create Trainer for training
trainer = trainer.Trainer(config)
trainer.train()
elif config.command == 'schedule':
# Create Sampler for noise scheduling
sampler = sampler.Sampler(config)
sampler.noise_scheduling_without_params()
elif config.command == 'generate':
# Create Sampler for generation
# NOTE: Remember to define "gen_data_dir" in conf.yml before data generation
sampler = sampler.Sampler(config)
sampler.generate()
log('-' * 80, config)
except KeyboardInterrupt:
log('-' * 80, config)
log('Exiting early', config)
================================================
FILE: egs/vctk/schedule.sh
================================================
#!/bin/bash
## example usage (single GPU): sh schedule.sh 0 conf.yml
cuda=$1
CUDA_VISIBLE_DEVICES=$cuda python3 main.py \
--command schedule --config $2
================================================
FILE: egs/vctk/train.sh
================================================
#!/bin/bash
## example usage: sh train.sh 0,1,2,3 conf.yml
cuda=$1
comma=${cuda//[^,]}
nproc=$((${#comma}+1))
CUDA_VISIBLE_DEVICES=$cuda python3 -m torch.distributed.launch \
--nproc_per_node $nproc --master_port $RANDOM main.py \
--command train --config $2
SYMBOL INDEX (112 symbols across 15 files)
FILE: bddm/loader/dataset.py
class SpectrogramDataset (line 27) | class SpectrogramDataset(data.Dataset):
method __init__ (line 29) | def __init__(self, data_dir, n_gpus, is_sampling, sampling_rate,
method reset (line 78) | def reset(self):
method find_all_wavs_in_dir (line 93) | def find_all_wavs_in_dir(self, data_dir):
method find_all_mels_in_dir (line 107) | def find_all_mels_in_dir(self, data_dir):
method crop_audio_and_mel (line 119) | def crop_audio_and_mel(self, audio, mel_spec):
method __getitem__ (line 149) | def __getitem__(self, index):
method __len__ (line 208) | def __len__(self):
function create_train_and_valid_dataloader (line 220) | def create_train_and_valid_dataloader(config):
function create_generation_dataloader (line 259) | def create_generation_dataloader(config):
FILE: bddm/loader/stft.py
class STFT (line 43) | class STFT(nn.Module):
method __init__ (line 45) | def __init__(self, filter_length=800, hop_length=200, win_length=800,
method transform (line 78) | def transform(self, input_data):
method forward (line 108) | def forward(self, input_data):
class TacotronSTFT (line 114) | class TacotronSTFT(nn.Module):
method __init__ (line 118) | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
method spectral_normalize (line 130) | def spectral_normalize(self, x):
method spectral_de_normalize (line 133) | def spectral_de_normalize(self, x):
method mel_spectrogram (line 136) | def mel_spectrogram(self, y):
FILE: bddm/models/__init__.py
function get_score_network (line 18) | def get_score_network(config):
function get_schedule_network (line 25) | def get_schedule_network(config):
FILE: bddm/models/diffwave.py
function calc_diffusion_step_embedding (line 22) | def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_...
function swish (line 55) | def swish(x):
class Conv (line 61) | class Conv(nn.Module):
method __init__ (line 62) | def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
method forward (line 70) | def forward(self, x):
class ZeroConv1d (line 77) | class ZeroConv1d(nn.Module):
method __init__ (line 78) | def __init__(self, in_channel, out_channel):
method forward (line 84) | def forward(self, x):
class ResidualBlock (line 91) | class ResidualBlock(nn.Module):
method __init__ (line 92) | def __init__(self, res_channels, skip_channels, dilation,
method forward (line 127) | def forward(self, input_data):
class ResidualGroup (line 166) | class ResidualGroup(nn.Module):
method __init__ (line 167) | def __init__(self, res_channels, skip_channels, num_res_layers, dilati...
method forward (line 187) | def forward(self, input_data):
class DiffWave (line 209) | class DiffWave(nn.Module):
method __init__ (line 210) | def __init__(self, in_channels, res_channels, skip_channels, out_chann...
method forward (line 231) | def forward(self, input_data):
FILE: bddm/models/galr.py
class Encoder (line 21) | class Encoder(nn.Module):
method __init__ (line 23) | def __init__(self, enc_dim, win_len):
method forward (line 38) | def forward(self, signals):
class BiLSTMproj (line 54) | class BiLSTMproj(nn.Module):
method __init__ (line 56) | def __init__(self, enc_dim, hid_dim):
method forward (line 75) | def forward(self, intra_segs):
class AttnPositionalEncoding (line 95) | class AttnPositionalEncoding(nn.Module):
method __init__ (line 97) | def __init__(self, enc_dim, attn_max_len=5000):
method forward (line 115) | def forward(self, x):
class GlobalAttnLayer (line 128) | class GlobalAttnLayer(nn.Module):
method __init__ (line 130) | def __init__(self, enc_dim, n_attn_head, attn_dropout):
method forward (line 148) | def forward(self, inter_segs):
class DeepGlobalAttnLayer (line 165) | class DeepGlobalAttnLayer(nn.Module):
method __init__ (line 167) | def __init__(self, enc_dim, n_attn_head, attn_dropout, n_attn_layer=1):
method forward (line 187) | def forward(self, inter_segs):
class GALRBlock (line 205) | class GALRBlock(nn.Module):
method __init__ (line 207) | def __init__(self, enc_dim, hid_dim, seg_len, low_dim=8, n_attn_head=8...
method forward (line 234) | def forward(self, segments):
class _GALR (line 279) | class _GALR(nn.Module):
method __init__ (line 281) | def __init__(self, n_block, enc_dim, hid_dim, win_len, seg_len,
method forward (line 317) | def forward(self, noisy_signals):
method pad_zeros (line 338) | def pad_zeros(self, signals):
method pad_segment (line 360) | def pad_segment(self, frames):
method split_feature (line 383) | def split_feature(self, frames):
class GALR (line 407) | class GALR(nn.Module):
method __init__ (line 409) | def __init__(self, blocks=2, input_dim=128, hidden_dim=128, window_len...
method forward (line 425) | def forward(self, audio, scales):
FILE: bddm/sampler/sampler.py
class Sampler (line 34) | class Sampler(object):
method __init__ (line 38) | def __init__(self, config):
method reset (line 72) | def reset(self):
method draw_reference_data_pair (line 89) | def draw_reference_data_pair(self):
method generate (line 96) | def generate(self):
method noise_scheduling (line 128) | def noise_scheduling(self, ddim=False):
method sampling (line 197) | def sampling(self, schedule=None, condition=None,
method noise_scheduling_with_params (line 274) | def noise_scheduling_with_params(self, alpha_param, beta_param):
method noise_scheduling_without_params (line 336) | def noise_scheduling_without_params(self):
method assess (line 389) | def assess(self, generated_audio, audio_key=None):
FILE: bddm/trainer/bmuf.py
class BmufTrainer (line 21) | class BmufTrainer(object):
method __init__ (line 23) | def __init__(self, master_node, rank, world_size, model, block_momentu...
method check_all_processes_running (line 56) | def check_all_processes_running(self):
method get_average_valid_loss (line 65) | def get_average_valid_loss(self, val_loss):
method update_and_sync (line 83) | def update_and_sync(self, val_loss=None):
method _copy_vec_to_param (line 116) | def _copy_vec_to_param(self, vec):
FILE: bddm/trainer/ema.py
class EMAHelper (line 16) | class EMAHelper(object):
method __init__ (line 18) | def __init__(self, mu=0.999):
method register (line 28) | def register(self, module):
method update (line 41) | def update(self, module):
method ema (line 55) | def ema(self, module):
method ema_copy (line 68) | def ema_copy(self, module):
method state_dict (line 87) | def state_dict(self):
method load_state_dict (line 96) | def load_state_dict(self, state_dict):
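The `EMAHelper` methods indexed above (`register`, `update`, `ema`, `state_dict`, `load_state_dict`) match the usual exponential-moving-average pattern for stabilizing diffusion-model weights. The following framework-free sketch mirrors those method names but operates on a plain dict of floats; the repository's class presumably works on `torch` modules, so treat the bodies as assumptions:

```python
import copy

class EMAHelper:
    """Exponential moving average over named parameters (sketch)."""

    def __init__(self, mu=0.999):
        self.mu = mu      # decay: shadow = mu * shadow + (1 - mu) * value
        self.shadow = {}  # name -> averaged value

    def register(self, params):
        # Snapshot the current parameters as the initial averages.
        self.shadow = copy.deepcopy(params)

    def update(self, params):
        # Blend the latest parameter values into the running averages.
        for name, value in params.items():
            self.shadow[name] = self.mu * self.shadow[name] + (1.0 - self.mu) * value

    def ema(self, params):
        # Overwrite live parameters with the averages (e.g. before eval).
        for name in params:
            params[name] = self.shadow[name]

    def state_dict(self):
        return dict(self.shadow)

    def load_state_dict(self, state_dict):
        self.shadow = dict(state_dict)

# Toy usage: one update with mu=0.9 moves the shadow a tenth of the way.
helper = EMAHelper(mu=0.9)
params = {"w": 1.0}
helper.register(params)
params["w"] = 2.0
helper.update(params)  # shadow["w"] becomes 0.9 * 1.0 + 0.1 * 2.0 = 1.1
```

A high `mu` (the 0.999 default in the index) makes the averaged weights change slowly, which is why EMA copies are typically the ones used for generation.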
FILE: bddm/trainer/loss.py
class ScoreLoss (line 17) | class ScoreLoss(nn.Module):
method __init__ (line 19) | def __init__(self, config, diff_params):
method forward (line 32) | def forward(self, model, mels, audios):
class StepLoss (line 54) | class StepLoss(nn.Module):
method __init__ (line 56) | def __init__(self, config, diff_params):
method forward (line 70) | def forward(self, model, mels, audios):
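`ScoreLoss.forward(model, mels, audios)` above points at the standard denoising score-matching objective used to train DiffWave-style vocoders. A scalar toy version of that objective (the real loss works on batched tensors conditioned on mel-spectrograms; the sampling and MSE form here are assumptions about `loss.py`) is:

```python
import math
import random

def score_loss(model, audio, alpha):
    """Sketch: pick a random diffusion step t, noise the clean audio at
    scale alpha[t], and regress the model's output onto the injected
    noise with a squared error."""
    t = random.randrange(len(alpha))
    noise = random.gauss(0.0, 1.0)
    noisy = alpha[t] * audio + math.sqrt(1.0 - alpha[t] ** 2) * noise
    predicted = model(noisy, t)           # the network estimates the noise
    return (predicted - noise) ** 2

# Toy usage with a dummy "model" that always predicts zero noise.
toy_model = lambda noisy, t: 0.0
loss = score_loss(toy_model, 1.0, [0.9, 0.5])
```

The companion `StepLoss` in the index is BDDM's addition: it trains the schedule network rather than the score network, so its forward pass compares predicted noise-schedule parameters instead of predicted noise.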
FILE: bddm/trainer/trainer.py
class Trainer (line 31) | class Trainer(object):
method __init__ (line 33) | def __init__(self, config):
method reset (line 93) | def reset(self):
method train (line 138) | def train(self):
method _run_one_epoch (line 189) | def _run_one_epoch(self, validate=False):
method serialize (line 251) | def serialize(self):
FILE: bddm/utils/check_utils.py
function check_score_network (line 18) | def check_score_network(config):
function check_schedule_network (line 46) | def check_schedule_network(config):
FILE: bddm/utils/diffusion_utils.py
function compute_diffusion_params (line 16) | def compute_diffusion_params(beta):
function map_noise_scale_to_time_step (line 39) | def map_noise_scale_to_time_step(alpha_infer, alpha):
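The two helpers indexed above do the usual DDPM bookkeeping: derive cumulative noise scales from a `beta` schedule, and map an inference-time noise scale back onto a (fractional) training step. A hedged sketch of that convention follows; the actual `diffusion_utils.py` may return more quantities and handle out-of-range scales differently:

```python
import math

def compute_diffusion_params(beta):
    """Sketch: from a noise schedule beta_t, derive the cumulative
    scales alpha_t = sqrt(prod_{s<=t} (1 - beta_s))."""
    alpha, running = [], 1.0
    for b in beta:
        running *= (1.0 - b)
        alpha.append(math.sqrt(running))
    return {"beta": list(beta), "alpha": alpha}

def map_noise_scale_to_time_step(alpha_infer, alpha):
    """Sketch: find the training step whose noise scale brackets
    alpha_infer (alpha is decreasing) and interpolate linearly."""
    for t in range(len(alpha) - 1):
        if alpha[t + 1] <= alpha_infer <= alpha[t]:
            frac = (alpha[t] - alpha_infer) / (alpha[t] - alpha[t + 1])
            return t + frac
    return len(alpha) - 1 if alpha_infer <= alpha[-1] else 0

# Toy usage: a two-step schedule, then locate an intermediate noise scale.
params = compute_diffusion_params([0.1, 0.2])
t_frac = map_noise_scale_to_time_step(0.9, params["alpha"])
```

This mapping is what lets a short sampling schedule (e.g. the 7- and 12-step `.ns` files under `egs/lj/noise_schedules/`) reuse a score network trained on a much longer schedule.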
FILE: bddm/utils/log_utils.py
function ctime (line 17) | def ctime():
function head (line 27) | def head():
function log (line 37) | def log(msg, config):
FILE: egs/lj/main.py
function dict_hash_5char (line 19) | def dict_hash_5char(dictionary):
function start_exp (line 27) | def start_exp(config, config_hash):
FILE: egs/vctk/main.py
function dict_hash_5char (line 19) | def dict_hash_5char(dictionary):
function start_exp (line 27) | def start_exp(config, config_hash):
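`dict_hash_5char` in both `main.py` files suggests each experiment directory is named by a short hash of its config. One plausible, hedged implementation (the exact hash function and truncation used by the repository are assumptions) is:

```python
import hashlib
import json

def dict_hash_5char(dictionary):
    """Sketch: serialize the config deterministically (sorted keys) and
    keep the first five hex characters of its MD5 digest, e.g. to name
    an experiment directory under exp_dir."""
    payload = json.dumps(dictionary, sort_keys=True).encode("utf-8")
    return hashlib.md5(payload).hexdigest()[:5]

# Toy usage: key order must not change the hash.
h = dict_hash_5char({"lr": 1e-4, "seed": 0})
```

Sorting the keys before hashing is the important detail: without it, two runs with identical configs could land in different experiment directories.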
About this extraction
This page contains the full source code of the tencent-ailab/bddm GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 43 files (70.9 MB, approximately 34.4k tokens) and a symbol index of 112 extracted functions, classes, methods, constants, and types. Extracted by GitExtract.