Full Code of tencent-ailab/bddm for AI

main 2cebe0e6b7fd cached
43 files
70.9 MB
34.4k tokens
112 symbols
1 requests
Download .txt
Repository: tencent-ailab/bddm
Branch: main
Commit: 2cebe0e6b7fd
Files: 43
Total size: 70.9 MB

Directory structure:
gitextract_wr3lh52z/

├── .gitignore
├── LICENSE
├── README.md
├── bddm/
│   ├── loader/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── stft.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── diffwave.py
│   │   └── galr.py
│   ├── sampler/
│   │   ├── __init__.py
│   │   └── sampler.py
│   ├── trainer/
│   │   ├── __init__.py
│   │   ├── bmuf.py
│   │   ├── ema.py
│   │   ├── loss.py
│   │   └── trainer.py
│   └── utils/
│       ├── check_utils.py
│       ├── diffusion_utils.py
│       └── log_utils.py
└── egs/
    ├── lj/
    │   ├── DiffWave.pkl
    │   ├── DiffWave_GALR.pkl
    │   ├── conf.yml
    │   ├── demo_case/
    │   │   ├── LJ001-0001.mel
    │   │   ├── LJ004-0233.mel
    │   │   └── LJ050-0270.mel
    │   ├── eval.list
    │   ├── generate.sh
    │   ├── main.py
    │   ├── noise_schedules/
    │   │   ├── 12steps_PESQ4.07_STOI0.978.ns
    │   │   └── 7steps_PESQ4.05_STOI0.975.ns
    │   ├── schedule.sh
    │   ├── train.sh
    │   └── valid.list
    └── vctk/
        ├── DiffWave.pkl
        ├── DiffWave_GALR.pkl
        ├── conf.yml
        ├── demo_case/
        │   ├── p259_311.mel
        │   └── p287_123.mel
        ├── generate.sh
        ├── main.py
        ├── noise_schedules/
        │   └── 8steps_PESQ3.88_STOI0.987.ns
        ├── schedule.sh
        └── train.sh

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
egs/lj/exp
egs/lj/LJSpeech1.1
*.log
*.pyc
__pycache__/
*/__pycache__/
*.zip
*/*.zip
source/
make.bat
Makefile
*.swp


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2022 Tencent

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Bilateral Denoising Diffusion Models (BDDMs)

[![GitHub Stars](https://img.shields.io/github/stars/tencent-ailab/bddm?style=social)](https://github.com/tencent-ailab/bddm)
![visitors](https://visitor-badge.glitch.me/badge?page_id=tencent-ailab/bddm)
[![arXiv](https://img.shields.io/badge/arXiv-Paper-green.svg)](https://arxiv.org/abs/2203.13508)
[![demo](https://img.shields.io/badge/demo-Samples-orange.svg)](https://bilateral-denoising-diffusion-model.github.io)

This is the official PyTorch implementation of the following paper:

> **BDDM: BILATERAL DENOISING DIFFUSION MODELS FOR FAST AND HIGH-QUALITY SPEECH SYNTHESIS** \
> Max W. Y. Lam, Jun Wang, Dan Su, Dong Yu

> **Abstract**: *Diffusion probabilistic models (DPMs) and their extensions have emerged as competitive generative models yet confront challenges of efficient sampling. We propose a new bilateral denoising diffusion model (BDDM) that parameterizes both the forward and reverse processes with a schedule network and a score network, which can train with a novel bilateral modeling objective. We show that the new surrogate objective can achieve a lower bound of the log marginal likelihood tighter than a conventional surrogate. We also find that BDDM allows inheriting pre-trained score network parameters from any DPMs and consequently enables speedy and stable learning of the schedule network and optimization of a noise schedule for sampling. Our experiments demonstrate that BDDMs can generate high-fidelity audio samples with as few as three sampling steps. Moreover, compared to other state-of-the-art diffusion-based neural vocoders, BDDMs produce comparable or higher quality samples indistinguishable from human speech, notably with only seven sampling steps (143x faster than WaveGrad and 28.6x faster than DiffWave).*

> **Paper**: Published at ICLR 2022 on [OpenReview](https://openreview.net/pdf?id=L7wzpQttNO)

![BDDM](bddm.png)

This implementation supports model training and audio generation, and also provides the pre-trained models for the benchmark [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) and [VCTK](https://datashare.ed.ac.uk/handle/10283/2651) dataset.

Visit our [demo page](https://bilateral-denoising-diffusion-model.github.io) for audio samples.

### Updates:
- **May 20, 2021:** Released our follow-up work [FastDiff](https://github.com/Rongjiehuang/FastDiff) on GitHub, where we futher optimized the speed-and-quality trade-off.
- **May 10, 2021:** Added the experiment configurations and model checkpoints for the VCTK dataset.
- **May 9, 2021:** Added the searched noise schedules for the LJSpeech and VCTK datasets.
- **March 20, 2021:** Released the PyTorch implementation of BDDM with pre-trained models for the LJSpeech dataset.

### Recipes:

- (Option 1) To train the BDDM scheduling network yourself, you can download the pre-trained score network from [philsyn/DiffWave-Vocoder](https://github.com/philsyn/DiffWave-Vocoder/blob/master/exp/ch128_T200_betaT0.02/logs/checkpoint/1000000.pkl) (provided at ```egs/lj/DiffWave.pkl```), and follow the training steps below. **(Start from Step I.)**
- (Option 2) To search for noise schedules using BDDM, we provide a pre-trained BDDM for LJSpeech at ```egs/lj/DiffWave-GALR.pkl``` and for VCTK at ```egs/vctk/DiffWave-GALR.pkl``` . **(Start from Step III.)**
- (Option 3) To directly generate samples using BDDM, we provide the searched schedules for LJSpeech at ```egs/lj/noise_schedules``` and for VCTK at ```egs/vctk/noise_schedules``` (check ```conf.yml``` for the respective configurations). **(Start from Step IV.)**


## Getting Started

We provide an example of how you can generate high-fidelity samples using BDDMs.

To try BDDM on your own dataset, simply clone this repo in your local machine provided with NVIDIA GPU + CUDA cuDNN and follow the below intructions.

### Dependencies

- [pytorch](https://github.com/pytorch/pytorch)>=1.7.1
- [librosa](https://github.com/librosa/librosa)>=0.7.1
- [pystoi](https://github.com/mpariente/pystoi)==0.3.3
- [pypesq](https://github.com/youngjamespark/python-pypesq)==0.2.0

### Step I. Data Preparation and Configuraion ### 

Download the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset.

For training, we first need to setup a file **conf.yml** for configuring the data loader, the score and the schedule networks, the training procedure, the noise scheduling and sampling parameters.

**Note:** Appropriately modify the paths in ```"train_data_dir"``` and ```"valid_data_dir"``` for training; and the path in ```"gen_data_dir"``` for sampling. All dir paths should be link to a directory that store the waveform audios (in **.wav**) or the Mel-spectrogram files (in **.mel**).

### Step II. Training a Schedule Network ###

Suppose that a well-trained score network (theta) is stored at ```$theta_path```, we start by modifying ```"load": $theta_path``` in **conf.yml**.

After modifying the relevant hyperparameters for a schedule network (especially ```"tau"```), we can train the schedule network (f_phi in paper) using:

```bash
# Training on device 0 (supports multi-GPU training)
sh train.sh 0 conf.yml
```

**Note**: In practice, we found that **10K** training steps would be enough to obtain a promising scheduling network. This normally takes no more than half an hour for training with one GPU.

### Step III. Searching for Noise Schedules ###

Given a well-trained BDDM (theta, phi), we can now run the noise scheduling algorithm to find the best schedule (optimizing the trade-off between quality and speed).

First, we set ```"load"``` in ```conf.yml``` to the path of the trained BDDM.

After setting the maximum number of sampling steps in scheduling (```"N"```), we run:

```bash
# Scheduling on device 0 (only supports single-GPU scheduling)
sh schedule.sh 0 conf.yml
```

### Step IV. Evaluation or Generation ###

For evaluation, we set ```"gen_data_dir"``` in ```conf.yml``` to the path of a directory that stores the test set of audios (in ```.wav```).

For generation, we set ```"gen_data_dir"``` in ```conf.yml``` to the path of a directory that stores the Mel-spectrogram (by default in ```.mel``` generated by [TacotronSTFT](https://github.com/NVIDIA/tacotron2/blob/master/layers.py) or by our dataset loader ```bddm/loader/dataset.py```).

Then, we run:

```bash
# Generation/evaluation on device 0 (only supports single-GPU generation)
sh generate.sh 0 conf.yml
```

## Acknowledgements
This implementation uses parts of the code from the following Github repos:\
[Tacotron2](https://github.com/NVIDIA/tacotron2)\
[DiffWave-Vocoder](https://github.com/philsyn/DiffWave-Vocoder)\
as described in our code.

## Citations ##

```
@inproceedings{lam2022bddm,
  title={BDDM: Bilateral Denoising Diffusion Models for Fast and High-Quality Speech Synthesis},
  author={Lam, Max WY and Wang, Jun and Su, Dan and Yu, Dong},
  booktitle={International Conference on Learning Representations},
  year={2022}
}
```

## License ##

Copyright 2022 Tencent

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

## Disclaimer ##

This is not an officially supported Tencent product.


================================================
FILE: bddm/loader/__init__.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Loader
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


================================================
FILE: bddm/loader/dataset.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Dataset and DataLoader for Neural Vocoding
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


import os
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
from scipy.io.wavfile import read

from .stft import TacotronSTFT


MAX_WAV_VALUE = 32768


class SpectrogramDataset(data.Dataset):

    def __init__(self, data_dir, n_gpus, is_sampling, sampling_rate,
            seg_len, fil_len, hop_len, win_len, mel_fmin, mel_fmax):
        """
        A torch.data.Dataset class that loads the audio files for training
            or loads the Mel-spectrogram files for sampling.

        Parameters:
            data_dir (str):      the path to the directory storing .wav/.mel files
            n_gpus (int):        the number of GPUs for training
            is_sampling (bool):  whether the dataset is used for sampling or not
            sampling_rate (int): the sampling rate of audios
            seg_len (int):       the segment length (number of samples) for training
            fil_len (int):       the filter length for computing STFT
            hop_len (int):       the hop length for computing STFT
            win_len (int):       the window length for computing STFT
            mel_fmin (int):      the minimum frequency for computing STFT
            mel_fmax (int):      the maximum frequency for computing STFT
        """
        self.n_gpus = n_gpus
        self.is_sampling = is_sampling
        self.seg_len = seg_len
        self.hop_len = hop_len
        self.n_mels = self.seg_len // self.hop_len
        self.sampling_rate = sampling_rate

        if is_sampling:
            # Find all Mel-spectrogram files in the given data directory
            self.mel_files = self.find_all_mels_in_dir(data_dir)
            if len(self.mel_files) == 0:
                # Find audios when no pre-computed mel spectrograms can be found
                self.audio_files = self.find_all_wavs_in_dir(data_dir)
                # Note that no mel file is loaded for generation
                self.mel_files = None
            else:
                # Note that no audio file is loaded for generation
                self.audio_files = None
        else:
            # Find all audio files in the given data directory
            self.audio_files = self.find_all_wavs_in_dir(data_dir)

        # Use the standard STFT operation defined in Tacotron 2
        self.stft = TacotronSTFT(filter_length=fil_len,
                                 hop_length=hop_len,
                                 win_length=win_len,
                                 sampling_rate=sampling_rate,
                                 mel_fmin=mel_fmin,
                                 mel_fmax=mel_fmax)
        self.reset()

    def reset(self):
        """
        Reset the loader by shuffling the file list
        """
        if self.is_sampling and self.audio_files is None:
            np.random.shuffle(self.mel_files)
        else:
            np.random.shuffle(self.audio_files)
            self.n_mels = self.seg_len // self.hop_len
            self.n_mels = int(self.n_mels * (1 + np.random.rand()))
            # Make sure the number of samples is divisible by n_gpus
            if len(self.audio_files) % self.n_gpus != 0:
                remainder = len(self.audio_files) % self.n_gpus
                self.audio_files = self.audio_files[:-remainder]

    def find_all_wavs_in_dir(self, data_dir):
        """
        Load all .wav files in data_dir

        Parameters:
            data_dir (str):   the path to the directory storing .wav files
        Returns:
            files_list (list): the list of wav file paths
        """
        files = [f for f in Path(data_dir).glob('*.wav')]
        if len(files) == 0:
            files = [f for f in Path(data_dir).glob('*_wav.npy')]
        return files

    def find_all_mels_in_dir(self, data_dir):
        """
        Load all .mel files in data_dir

        Parameters:
            data_dir (str):   the path to the directory storing .mel files
        Returns:
            files_list (list): the list of mel file paths
        """
        files = [f for f in Path(data_dir).glob('*.mel')]
        return files

    def crop_audio_and_mel(self, audio, mel_spec):
        """
        Randomly crop audio and mel_spec into a fixed-length segment

        Parameters:
            audio (tensor):    the full audio
            mel_spec (tensor): the full mel-spectrogram computed using TacotronSTFT
        Returns:
            audio (tensor):    the cropped audio
            mel_spec (tensor): the cropped mel-spectrogram
        """
        n_mels = self.n_mels
        seg_len = n_mels * self.hop_len
        if audio.size(-1) >= seg_len:
            if mel_spec.size(-1) > n_mels:
                max_mel_start = mel_spec.size(-1) - n_mels
                mel_start = np.random.randint(0, max_mel_start)
                mel_spec = mel_spec[..., mel_start:mel_start+n_mels]
                audio_start = mel_start * self.hop_len
                audio = audio[..., audio_start:audio_start+seg_len]
            elif mel_spec.size(-1) == n_mels:
                audio = audio[..., :seg_len]
            else:
                audio = audio[..., :seg_len]
                mel_spec = F.pad(mel_spec, (0, n_mels - mel_spec.size(-1)), 'constant', 0)
        else:
            audio = F.pad(audio, (0, seg_len - audio.size(-1)), 'constant', 0)
            mel_spec = F.pad(mel_spec, (0, n_mels - mel_spec.size(-1)), 'constant', 0)
        return audio, mel_spec

    def __getitem__(self, index):
        """
        Get a pair of data (mel-spectrogram, audio) given an index

        Parameters:
            index (int):       the index for loading one sample
        Returns:
            audio_key (str):   the audio key
            mel_spec (tensor): the mel-spectrogram computed using TacotronSTFT
            audio (tensor):    the ground-truth audio
        """
        if self.is_sampling and self.audio_files is None:
            # Load Mel-spectrogram for sampling
            mel_spec = torch.load(self.mel_files[index], map_location='cpu').float()
            if mel_spec.ndim == 3:
                mel_spec = mel_spec[0]
            audio_key = str(self.mel_files[index])
            # Try to find the paired source audio
            wav_path = str(self.mel_files[index])[:-4]+'.wav'
            if os.path.isfile(wav_path):
                sampling_rate, audio = read(wav_path)
                audio = torch.from_numpy(audio[None]).float() / MAX_WAV_VALUE
                assert sampling_rate == self.sampling_rate
                return audio_key, mel_spec, audio
            else:
                return audio_key, mel_spec, []

        mel_spec = None
        # Load the audio file to torch.FloatTensor
        if str(self.audio_files[index])[-3:] == 'npy':
            audio = np.load(self.audio_files[index])[0]
            mel_spec = np.load(str(self.audio_files[index]).replace('wav', 'mel'))[0]
            # Load into torch.FloatTensor
            audio = torch.from_numpy(audio).float()
            mel_spec = torch.from_numpy(mel_spec).float()
        else:
            sampling_rate, audio = read(self.audio_files[index])
            # Make sure the sampling rate is correctly defined
            assert sampling_rate == self.sampling_rate
            # Normalize the audio into [-1, 1]
            audio = audio / MAX_WAV_VALUE
            # Load into torch.FloatTensor
            audio = torch.from_numpy(audio).float()
        # Compute Mel-spectrogram (shape = [T] -> [filter_length, L])
        if mel_spec is None:
            audio = audio[None]
            mel_spec = self.stft.mel_spectrogram(audio)
            mel_spec = torch.squeeze(mel_spec, 0)

        if not self.is_sampling:
            audio, mel_spec = self.crop_audio_and_mel(audio, mel_spec)

        if self.is_sampling:
            # Save ground-truth Mel-spectrogram into .mel file
            torch.save(mel_spec, str(self.audio_files[index])[:-4]+'.mel')
            return str(self.audio_files[index]), mel_spec, audio

        return mel_spec, audio

    def __len__(self):
        """
        Get the number of data

        Returns:
            data_len (int): the number of .wav/.mel files found in data_dir
        """
        if self.audio_files is None:
            return len(self.mel_files)
        return len(self.audio_files)


def create_train_and_valid_dataloader(config):
    """
    Create two torch.data.DataLoader for training and validation

    Parameters:
        config (namespace):     BDDM Configuration
    Returns:
        tr_loader (DataLoader): the data loader for training
        vl_loader (DataLoader): the data loader for validation
    """
    n_gpus = 1
    if 'WORLD_SIZE' in os.environ.keys():
        n_gpus = int(os.environ['WORLD_SIZE'])
    conf_keys = SpectrogramDataset.__init__.__code__.co_varnames
    data_config = {k: v for k, v in vars(config).items() if k in conf_keys}
    data_config["data_dir"] = config.train_data_dir
    data_config["is_sampling"] = False
    data_config["n_gpus"] = n_gpus
    dataset = SpectrogramDataset(**data_config)
    assert len(dataset) > 0, f"Error: No .wav can be found at {config.train_data_dir} !"
    sampler = data.distributed.DistributedSampler(dataset) if n_gpus > 1 else None
    tr_loader = data.DataLoader(dataset,
                                sampler=sampler,
                                batch_size=config.batch_size,
                                num_workers=config.n_worker,
                                pin_memory=False,
                                drop_last=True)
    data_config["data_dir"] = config.valid_data_dir
    dataset = SpectrogramDataset(**data_config)
    assert len(dataset) > 0, f"Error: No .wav can be found at {config.valid_data_dir} !"
    sampler = data.distributed.DistributedSampler(dataset) if n_gpus > 1 else None
    vl_loader = data.DataLoader(dataset,
                                sampler=sampler,
                                batch_size=1,
                                num_workers=config.n_worker,
                                pin_memory=False)
    return tr_loader, vl_loader


def create_generation_dataloader(config):
    """
    Create a torch.data.DataLoader for generation

    Parameters:
        config (namespace):      BDDM Configuration
    Returns:
        gen_loader (DataLoader): the data loader for generation
    """
    conf_keys = SpectrogramDataset.__init__.__code__.co_varnames
    data_config = {k: v for k, v in vars(config).items() if k in conf_keys}
    data_config["data_dir"] = config.gen_data_dir
    data_config["is_sampling"] = True
    data_config["n_gpus"] = 1
    dataset = SpectrogramDataset(**data_config)
    gen_loader = data.DataLoader(dataset,
                                 batch_size=1,  # variable audio length
                                 num_workers=config.n_worker,
                                 pin_memory=False)
    return gen_loader


if __name__ == "__main__":
    from argparse import Namespace
    data_config = {
        "data_dir": "LJSpeech-1.1/train_wavs",
        "sampling_rate": 22050,
        "training": True,
        "seg_len": 22050,
        "fil_len": 1024,
        "hop_len": 256,
        "win_len": 1024,
        "mel_fmin": 0.0,
        "mel_fmax": 8000.0,
        "extra_key": "extra_val"
    }
    data_config = Namespace(**data_config)
    loader = create_train_and_valid_dataloader(data_config)
    for batch in loader:
        print(len(batch))
        print(batch[0].shape, batch[1].shape)
        break


================================================
FILE: bddm/loader/stft.py
================================================
"""
BSD 3-Clause License

Copyright (c) 2017, Prem Seetharaman
All rights reserved.

* Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice,
  this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this
  list of conditions and the following disclaimer in the
  documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from this
  software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.filters import mel as mel_func
from librosa.util import pad_center


class STFT(nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(
            torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class TacotronSTFT(nn.Module):
    """
    Adapted from https://github.com/NVIDIA/tacotron2/blob/master/layers.py
    """
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = mel_func(
            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, x):
        return torch.log(torch.clamp(x, min=1e-5))

    def spectral_de_normalize(self, x):
        return torch.exp(x)

    def mel_spectrogram(self, y):
        """
        Compute mel-spectrograms from a batch of waves
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, _ = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output


================================================
FILE: bddm/models/__init__.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Globally Attentive Locally Recurrent (GALR) Networks
#  (https://arxiv.org/abs/2101.05014)
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


from .diffwave import DiffWave
from .galr import GALR


def get_score_network(config):
    if config.score_net == 'DiffWave':
        conf_keys = DiffWave.__init__.__code__.co_varnames
        model_config = {k: v for k, v in vars(config).items() if k in conf_keys}
        return DiffWave(**model_config)


def get_schedule_network(config):
    if config.schedule_net == 'GALR':
        conf_keys = GALR.__init__.__code__.co_varnames
        model_config = {k: v for k, v in vars(config).items() if k in conf_keys}
        return GALR(**model_config)


================================================
FILE: bddm/models/diffwave.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  DiffWave: A Versatile Diffusion Model for Audio Synthesis
#  (https://arxiv.org/abs/2009.09761)
#  Modified from https://github.com/philsyn/DiffWave-Vocoder
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    """
    Embed a diffusion step $t$ into a higher dimensional space
        E.g. the embedding vector in the 128-dimensional space is
        [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)),
         cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))]

    Parameters:
        diffusion_steps (torch.long tensor, shape=(batchsize, 1)):
                                    diffusion steps for batch data
        diffusion_step_embed_dim_in (int, default=128):
                                    dimensionality of the embedding space for discrete diffusion steps
    Returns:
        the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)):
    """

    assert diffusion_step_embed_dim_in % 2 == 0

    half_dim = diffusion_step_embed_dim_in // 2
    _embed = np.log(10000) / (half_dim - 1)
    _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
    _embed = diffusion_steps * _embed
    diffusion_step_embed = torch.cat((torch.sin(_embed),
                                      torch.cos(_embed)), 1)
    return diffusion_step_embed


"""
Below scripts were borrowed from
https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
"""


def swish(x):
    return x * torch.sigmoid(x)


# dilated conv layer with kaiming_normal initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              dilation=dilation, padding=self.padding)
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, x):
        out = self.conv(x)
        return out


# conv1x1 layer with zero initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed
class ZeroConv1d(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()

    def forward(self, x):
        out = self.conv(x)
        return out


# every residual block (named residual layer in paper)
# contains one noncausal dilated conv
class ResidualBlock(nn.Module):
    def __init__(self, res_channels, skip_channels, dilation,
                 diffusion_step_embed_dim_out):
        super().__init__()
        self.res_channels = res_channels

        # Use a FC layer for diffusion step embedding
        self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

        # Dilated conv layer
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
                                       kernel_size=3, dilation=dilation)

        # Add mel spectrogram upsampler and conditioner conv1x1 layer
        self.upsample_conv2d = nn.ModuleList()
        for s in [16, 16]:
            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
                                              padding=(1, s // 2),
                                              stride=(1, s))
            conv_trans2d = nn.utils.weight_norm(conv_trans2d)
            nn.init.kaiming_normal_(conv_trans2d.weight)
            self.upsample_conv2d.append(conv_trans2d)

        # 80 is mel bands
        self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1)

        # Residual conv1x1 layer, connect to next residual layer
        self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
        self.res_conv = nn.utils.weight_norm(self.res_conv)
        nn.init.kaiming_normal_(self.res_conv.weight)

        # Skip conv1x1 layer, add to all skip outputs through skip connections
        self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
        self.skip_conv = nn.utils.weight_norm(self.skip_conv)
        nn.init.kaiming_normal_(self.skip_conv.weight)

    def forward(self, input_data):
        x, mel_spec, diffusion_step_embed = input_data
        h = x
        batch_size, n_channels, seq_len = x.shape
        assert n_channels == self.res_channels

        # Add in diffusion step embedding
        part_t = self.fc_t(diffusion_step_embed)
        part_t = part_t.view([batch_size, self.res_channels, 1])
        h += part_t

        # Dilated conv layer
        h = self.dilated_conv_layer(h)

        # Upsample spectrogram to size of audio
        mel_spec = torch.unsqueeze(mel_spec, dim=1)
        mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
        mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
        mel_spec = torch.squeeze(mel_spec, dim=1)

        assert mel_spec.size(2) >= seq_len
        if mel_spec.size(2) > seq_len:
            mel_spec = mel_spec[:, :, :seq_len]

        mel_spec = self.mel_conv(mel_spec)
        h += mel_spec

        # Gated-tanh nonlinearity
        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])

        # Residual and skip outputs
        res = self.res_conv(out)
        assert x.shape == res.shape
        skip = self.skip_conv(out)

        # Normalize for training stability
        return (x + res) * math.sqrt(0.5), skip


class ResidualGroup(nn.Module):
    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
                 diffusion_step_embed_dim_in,
                 diffusion_step_embed_dim_mid,
                 diffusion_step_embed_dim_out):
        super().__init__()
        self.num_res_layers = num_res_layers
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in

        # Use the shared two FC layers for diffusion step embedding
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)

        # Stack all residual blocks with dilations 1, 2, ... , 512, ... , 1, 2, ..., 512
        self.residual_blocks = nn.ModuleList()
        for n in range(self.num_res_layers):
            self.residual_blocks.append(
                ResidualBlock(res_channels, skip_channels,
                               dilation=2 ** (n % dilation_cycle),
                               diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))

    def forward(self, input_data):
        x, mel_spectrogram, diffusion_steps = input_data

        # Embed diffusion step t
        diffusion_step_embed = calc_diffusion_step_embedding(
            diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))

        # Pass all residual layers
        h = x
        skip = 0
        for n in range(self.num_res_layers):
            # Use the output from last residual layer
            h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, diffusion_step_embed))
            # Accumulate all skip outputs
            skip += skip_n

        # Normalize for training stability
        return skip * math.sqrt(1.0 / self.num_res_layers)


class DiffWave(nn.Module):
    def __init__(self, in_channels, res_channels, skip_channels, out_channels,
                 num_res_layers, dilation_cycle,
                 diffusion_step_embed_dim_in,
                 diffusion_step_embed_dim_mid,
                 diffusion_step_embed_dim_out):
        super().__init__()

        # Initial conv1x1 with relu
        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
        # All residual layers
        self.residual_layer = ResidualGroup(res_channels,
                                            skip_channels,
                                            num_res_layers,
                                            dilation_cycle,
                                            diffusion_step_embed_dim_in,
                                            diffusion_step_embed_dim_mid,
                                            diffusion_step_embed_dim_out)
        # Final conv1x1 -> relu -> zeroconv1x1
        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
                                        nn.ReLU(inplace=False), ZeroConv1d(skip_channels, out_channels))

    def forward(self, input_data):
        audio, mel_spectrogram, diffusion_steps = input_data
        x = audio
        x = self.init_conv(x).clone()
        x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
        return self.final_conv(x)


================================================
FILE: bddm/models/galr.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Globally Attentive Locally Recurrent (GALR) Networks
#  (https://arxiv.org/abs/2101.05014)
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class Encoder(nn.Module):

    def __init__(self, enc_dim, win_len):
        """
        1D Convoluation based Waveform Encoder

        Parameters:
            enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
            win_len (int): Window length for processing raw signal samples (e.g. a
                common choice: ``16``). By default, the windows are half-overlapping.
        """
        super().__init__()
        # 1D convolutional layer
        self.enc_conv = nn.Conv1d(1, enc_dim,
                                  kernel_size=win_len,
                                  stride=win_len//2, bias=False)

    def forward(self, signals):
        """
        Non-linearly encode signals from raw waveform to frame sequences.

        Parameters:
            signals (tensor): A batch of signals in shape `[B, T]`, where `T` is the
                maximum length of these `B` signals.
        Returns:
            frames (tensor): A batch of encoded feature (a.k.a. frames) sequences in shape
                `[B, D, L]`, where `B` is the batch size, `D` is the encoded feature
                dimension (enc_dim), and `L` is the length of this feature sequence.
        """
        frames = F.relu(self.enc_conv(signals.unsqueeze(1)))
        return frames


class BiLSTMproj(nn.Module):

    def __init__(self, enc_dim, hid_dim):
        """
        Locally Recurrent Layer (Sec 2.2.1 in https://arxiv.org/abs/2101.05014).
            It consists of a bi-directional LSTM followed by a linear projection.

        Parameters:
            enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
            hid_dim (int): Number of hidden nodes used in the Bi-LSTM.
        """
        super().__init__()
        # Bi-LSTM with learnable (h_0, c_0) state
        self.rnn = nn.LSTM(enc_dim, hid_dim,
                           1, dropout=0, batch_first=True, bidirectional=True)
        self.cell_init = nn.Parameter(torch.rand(1, 1, hid_dim))
        self.hidden_init = nn.Parameter(torch.rand(1, 1, hid_dim))

        # Linear projection layer
        self.proj = nn.Linear(hid_dim * 2, enc_dim)

    def forward(self, intra_segs):
        """
        Process through a locally recurrent layer along the intra-segment
            direction.

        Parameters:
        	frames (tensor): A batch of intra-segments in shape `[B*S, K, D]`, where
                `B` is the batch size, `S` is the number of segments, 'K' is the
                segment length (seg_len) and `D` is the feature dimension (enc_dim).
        Returns:
            lr_output (tensor): A batch of processed segments with the same shape as the input.
        """
        batch_size_seq_len = intra_segs.size(0)
        cell = self.cell_init.repeat(2, batch_size_seq_len, 1)
        hidden = self.hidden_init.repeat(2, batch_size_seq_len, 1)
        rnn_output, _ = self.rnn(intra_segs, (hidden, cell))
        lr_output = self.proj(rnn_output)
        return lr_output


class AttnPositionalEncoding(nn.Module):

    def __init__(self, enc_dim, attn_max_len=5000):
        """
        Positional Encoding for Multi-Head Attention

        Parameters:
            enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
            attn_max_len (int): Maximum length of the sequence to be processed by
                multi-head attention.
        """
        super().__init__()
        pe = torch.zeros(attn_max_len, enc_dim)
        position = torch.arange(0, attn_max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, enc_dim, 2).float() * (-math.log(10000.0)/enc_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Compute positional encoding

        Parameters:
        	x (tensor): the sequence to be addded with the positional encoding vector
        Returns:
            output (tensor): the encoded input
        """
        output = x + self.pe[:x.size(0), :]
        return output


class GlobalAttnLayer(nn.Module):

    def __init__(self, enc_dim, n_attn_head, attn_dropout):
        """
        Globally Attentive Layer (Sec 2.2.2 in
            https://arxiv.org/abs/2101.05014)

        Parameters:
            enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
            n_attn_head (int): Number of heads for the multi-head attention. (e.g.
                choice in paper: ``8``)
            attn_dropout (float): Dropout rate for multi-head attention (e.g. choice in
                paper: ``0.1``).
        """
        super().__init__()
        self.attn = nn.MultiheadAttention(enc_dim,
            n_attn_head, dropout=attn_dropout)
        self.norm = nn.LayerNorm(enc_dim)
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, inter_segs):
        """
        Process through a globally attentive layer along the inter-segment
            direction.

        Parameters:
        	inter_segs (tensor): A batch of inter-segments in shape `[S, B*K, D]`, where
                `B` is the batch size, `S` is the number of segments, 'K' is the
                segment length (seg_len) and `D` is the feature dimension (enc_dim).
        Returns:
            output (tensor): A batch of processed segments with the same shape as the input.
        """
        output, _ = self.attn(inter_segs, inter_segs, inter_segs)
        output = self.norm(output + self.dropout(output))
        return output


class DeepGlobalAttnLayer(nn.Module):

    def __init__(self, enc_dim, n_attn_head, attn_dropout, n_attn_layer=1):
        """
        A Stack of Globally Attentive Layers (Sec 2.2.2 in
            https://arxiv.org/abs/2101.05014)

        Parameters:
            enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
            n_attn_head (int): Number of heads for the multi-head attention. (e.g.
                choice in paper: ``8``)
            attn_dropout (float): Dropout rate for multi-head attention (e.g. choice in
                paper: ``0.1``).
            n_attn_layer (int): Number of globally attentive layers stacked. The
                setting in paper by default is ``1``.
        """
        super().__init__()
        self.attn_in_norm = nn.LayerNorm(enc_dim)
        self.pos_enc = AttnPositionalEncoding(enc_dim)
        self.attn_layer = nn.ModuleList([
            GlobalAttnLayer(enc_dim, n_attn_head, attn_dropout) for _ in range(n_attn_layer)])

    def forward(self, inter_segs):
        """
        Process through a stack of globally attentive layers.

        Parameters:
        	inter_segs (tensor): A batch of inter-segments in shape `[S, B*K, D]`, where
            `B` is the batch size, `S` is the number of segments, 'K' is the
            segment length (seg_len) and `D` is the feature dimension (enc_dim).
        Returns:
        	 output (tensor): A batch of processed segments with the same shape as the input.
        """
        output = self.attn_in_norm(inter_segs)
        output = self.pos_enc(output)
        for block in self.attn_layer:
            output = block(output)
        return output


class GALRBlock(nn.Module):

    def __init__(self, enc_dim, hid_dim, seg_len, low_dim=8, n_attn_head=8, attn_dropout=0.1):
        """
        Globally Attentive Locally Recurrent (GALR) Block (Sec. 2.2 in
            https://arxiv.org/abs/2101.05014)

        Parameters:
            enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
            hid_dim (int): Number of hidden nodes used in the Bi-LSTM.
            seg_len (int): Segment length for processing frame sequence (e.g. a
                ``64`` when win_len is ``16``). By default, the segments are
                half-overlapping.
            low_dim (int): Lower dimension for speeding up GALR. (Sec. 2.2.3)
                (e.g. ``8``  when seg_len is ``64``).
            n_attn_head (int): Number of heads for the multi-head attention. (e.g.
                choice in paper: ``8``)
            attn_dropout (float): Dropout rate for multi-head attention (e.g. choice in
                paper: ``0.1``).
        """
        super().__init__()
        self.low_dim = low_dim
        self.local_rnn = BiLSTMproj(enc_dim, hid_dim)
        self.local_norm = nn.GroupNorm(1, enc_dim)
        self.global_rnn = DeepGlobalAttnLayer(enc_dim, n_attn_head, attn_dropout)
        self.global_norm = nn.GroupNorm(1, enc_dim)
        self.ld_map = nn.Linear(seg_len, low_dim)
        self.ld_inv_map = nn.Linear(low_dim, seg_len)

    def forward(self, segments):
        """
        Process through a GALR block.

        Parameters:
        	segments (tensor): A batch of 3D segments in shape `[B, D, K, S]`, where
                `B` is the batch size, `S` is the number of segments, 'K' is the
                segment length (seg_len) and `D` is the feature dimension (enc_dim).
        Returns:
        	 segments (tensor): A batch of processed segments with the same shape as the input.
        """
        batch_size, feat_dim, seg_len, n_segs = segments.size()
        # Change the sequence direction for intra-segment processing
        local_input = segments.transpose(1, 3).reshape(batch_size * n_segs, seg_len, feat_dim)
        # Process through a locally recurrent layer
        local_output = self.local_rnn(local_input)
        # Reshape to match the dimensionality of the input for residual connection
        local_output = local_output.view(batch_size, n_segs, seg_len, feat_dim).transpose(1, 3).contiguous()
        # Add a layer normalization before the residual connection
        local_output = self.local_norm(local_output)
        # Add residual connection
        segments = segments + local_output

        # Change the sequence direction for intra-segment processing
        global_input = segments.permute(3, 2, 0, 1).contiguous().view(n_segs, seg_len, batch_size, feat_dim)
        # Perform low-dimensional mapping for speeding up GALR
        global_input = self.ld_map(global_input.transpose(1, -1))
        # Reshape for intra-segment processing (sequence, batch size, feature dim)
        global_input = global_input.transpose(1, -1).contiguous().view(n_segs, -1, feat_dim)
        # Process through a globally attentive layer
        global_output = self.global_rnn(global_input)
        # Reshape for low-dimensional inverse mapping
        global_output = global_output.view(n_segs, self.low_dim, -1, feat_dim).transpose(1, -1)
        # Map the low-dimensional features back to the original size
        global_output = self.ld_inv_map(global_output)
        # Reshape to match the dimensionality of the input for residual connection
        global_output = global_output.permute(2, 1, 3, 0).contiguous()
        # Add a layer normalization before the residual connection
        global_output = self.global_norm(global_output)
        # Add residual connection
        segments = segments + global_output

        return segments


class _GALR(nn.Module):

    def __init__(self, n_block, enc_dim, hid_dim, win_len, seg_len,
            low_dim=8, n_attn_head=8, attn_dropout=0.1):
        """
        Globally Attentive Locally Recurrent (GALR) Networks
            (https://arxiv.org/abs/2101.05014)

        Parameters:
            n_block (int): Number of GALR blocks (e.g. choice in paper: ``6``).
            enc_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
            hid_dim (int): Number of hidden nodes used in the Bi-LSTM.
            win_len (int): Window length for processing raw signal samples (e.g. a
                common choice: ``8``). By default, the windows are half-overlapping.
            seg_len (int): Segment length for processing frame sequence (e.g. a
                ``64`` when win_len is ``8``). By default, the segments are half-overlapping.
            low_dim (int): Lower dimension for speeding up GALR. (Sec. 2.2.3)
                (e.g. ``8``  when seg_len is ``64``).
            n_attn_head (int): Number of heads for the multi-head attention. (e.g.
                choice in paper: ``8``)
            attn_dropout (flaot): Dropout rate for multi-head attention (e.g. choice in
                paper: ``0.1``).
        """
        super().__init__()
        self.win_len = win_len
        self.seg_len = seg_len
        self.encoder = Encoder(enc_dim, win_len)
        self.bottleneck = nn.Conv1d(enc_dim, enc_dim, 1, bias=False)
        # GALR blocks
        self.blocks = nn.ModuleList([
            GALRBlock(enc_dim, hid_dim, seg_len, low_dim, n_attn_head, attn_dropout)
                for i in range(n_block)])
        # Many-to-one gated layer applied to GALR's last block
        self.block_gate = nn.Sequential(
            nn.PReLU(),
            nn.Conv2d(enc_dim, 1, 1)
        )

    def forward(self, noisy_signals):
        """
        Process through a GALR network for noise estimation

        Parameters:
        	noisy_signals (tensor): A batch of 1D signals in shape `[B, T]`, where
                `B` is the batch size, `T` is the maximum length of the `B` signals.
        Returns:
        	est_ratios (tensor): A batch of scalar noise scale ratios in `[B, 1]` shape.
        """
        noisy_signals_padded, _ = self.pad_zeros(noisy_signals)
        mix_frames = self.encoder(noisy_signals_padded)
        mix_frames = self.bottleneck(mix_frames)
        block_feature, _ = self.split_feature(mix_frames)
        for block in self.blocks:
            block_feature = block(block_feature)
        est_segments = self.block_gate(block_feature)
        est_ratios = torch.sigmoid(est_segments.mean([2, 3]))
        est_ratios = est_ratios.mean(1, keepdim=True)
        return est_ratios

    def pad_zeros(self, signals):
        """
        Pad a batch of signals with zeros before encoding.

        Parameters:
        	signals (tensor): A batch of 1D signals in shape `[B, T]`, where `B` is
                the batch size, `T` is the maximum length of the `B` signals.
        Returns:
        	signals (tensor): A batch of padded signals and the length of zeros used for
                padding. (in shape `[B, T]`)
            rest (int): the redundant spaces created in padding
        """
        batch_size, sig_len = signals.shape
        self.hop_size = self.win_len // 2
        rest = self.win_len - (self.hop_size + sig_len % self.win_len) % self.win_len
        if rest > 0:
            pad = torch.zeros(batch_size, rest).type(signals.type())
            signals = torch.cat([signals, pad], 1)
        pad_aux = torch.zeros(batch_size, self.hop_size).type(signals.type())
        signals = torch.cat([pad_aux, signals, pad_aux], 1)
        return signals, rest

    def pad_segment(self, frames):
        """
        Pad a batch of frames with zeros before segmentation.

        Parameters:
        	frames (tensor): A batch of 2D frames in shape `[B, D, L]`, where `B` is
                the batch size, `D` is the encoded feature dimension (enc_dim), and
                `L` is the length of this feature sequence.
        Returns:
        	frames (tensor): A batch of padded frames and the length of zeros used for
                padding. (in shape `[B, D, L]`)
            rest (int): the redundant spaces created in padding
        """
        batch_size, feat_dim, seq_len = frames.shape
        stride = self.seg_len // 2
        rest = self.seg_len - (stride + seq_len % self.seg_len) % self.seg_len
        if rest > 0:
            pad = Variable(torch.zeros(batch_size, feat_dim, rest)).type(frames.type())
            frames = torch.cat([frames, pad], 2)
        pad_aux = Variable(torch.zeros(batch_size, feat_dim, stride)).type(frames.type())
        frames = torch.cat([pad_aux, frames, pad_aux], 2)
        return frames, rest

    def split_feature(self, frames):
        """
        Perform segmentation by dividing every K (seg_len) consecutive frames
            into S segments.

        Parameters:
        	frames (tensor): A batch of 2D frames in shape `[B, D, L]`, where `B` is
                the batch size, `D` is the encoded feature Dension (enc_D), and
                `L` is the length of this feature sequence.
        Returns:
        	segments (tensor): A batch of 3D segments and the length of zeros used for
                padding. (in shape `[B, D, K, S]`)
            rest (int): the redundant spaces created in padding
        """
        frames, rest = self.pad_segment(frames)
        batch_size, feat_dim, _ = frames.shape
        stride = self.seg_len // 2
        lsegs = frames[:, :, :-stride].contiguous().view(batch_size, feat_dim, -1, self.seg_len)
        rsegs = frames[:, :, stride:].contiguous().view(batch_size, feat_dim, -1, self.seg_len)
        segments = torch.cat([lsegs, rsegs], -1).view(batch_size, feat_dim, -1, self.seg_len)
        segments = segments.transpose(2, 3).contiguous()
        return segments, rest


class GALR(nn.Module):

    def __init__(self, blocks=2, input_dim=128, hidden_dim=128, window_length=8, segment_size=64):
        """
        GALR Schedule Network

        Parameters:
            blocks (int): Number of GALR blocks (e.g. choice in paper: ``6``).
            input_dim (int): Dimension of each frame (e.g. choice in paper: ``128``).
            hidden_dim (int): Number of hidden nodes used in the Bi-LSTM.
            window_length (int): Window length for processing raw signal samples (e.g. a
                common choice: ``8``). By default, the windows are half-overlapping.
            segment_size (int): Segment length for processing frame sequence (e.g. a
                ``64`` when win_len is ``8``). By default, the segments are half-overlapping.
        """
        super().__init__()
        self.ratio_nn = _GALR(blocks, input_dim, hidden_dim, window_length, segment_size)

    def forward(self, audio, scales):
        """
        Estimate the next beta scale using the GALR schedule network

        Parameters:
        	audio (tensor): A batch of 1D signals in shape `[B, T]`, where
                `B` is the batch size, `T` is the maximum length of the `B` signals.
            scales (list): [beta_nxt, delta] for computing the upper bound
        Returns:
        	beta (tensor): A batch of scalar noise scale ratios in `[B, 1, 1]` shape.
        """
        beta_nxt, delta = scales
        bounds = torch.cat([beta_nxt, delta], 1)
        mu, _ = torch.min(bounds, 1)
        mu = mu[:, None]
        ratio = self.ratio_nn(audio)
        beta = mu * ratio
        return beta.view(audio.size(0), 1, 1)


================================================
FILE: bddm/sampler/__init__.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Sampler
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################

from .sampler import Sampler


================================================
FILE: bddm/sampler/sampler.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  BDDM Sampler (Supports Noise Scheduling and Sampling)
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


from __future__ import absolute_import

import os
import librosa

import torch
import numpy as np
from scipy.io.wavfile import write as wavwrite
from pystoi import stoi
from pypesq import pesq

from bddm.utils.log_utils import log
from bddm.utils.check_utils import check_score_network
from bddm.utils.check_utils import check_schedule_network
from bddm.utils.diffusion_utils import compute_diffusion_params
from bddm.utils.diffusion_utils import map_noise_scale_to_time_step
from bddm.models import get_score_network, get_schedule_network
from bddm.loader.dataset import create_generation_dataloader
from bddm.loader.dataset import create_train_and_valid_dataloader


class Sampler(object):

    metric2index = {"PESQ": 0, "STOI": 1}

    def __init__(self, config):
        """
        Sampler Class, implements the Noise Scheduling and Sampling algorithms in BDDMs

        Parameters:
            config (namespace): BDDM Configuration
        """
        self.config = config
        self.exp_dir = config.exp_dir
        self.clip = config.grad_clip
        self.load = config.load
        self.model = get_score_network(config).cuda().eval()
        self.schedule = None
        # Initialize diffusion parameters using a pre-specified linear schedule
        noise_schedule = torch.linspace(config.beta_0, config.beta_T, config.T).cuda()
        self.diff_params = compute_diffusion_params(noise_schedule)
        if self.config.command != 'train':
            # Find schedule net, if not trained then use DDPM or DDIM sampling mode
            schedule_net_trained, schedule_net_path = check_schedule_network(config)
            if not schedule_net_trained:
                _, score_net_path = check_score_network(config)
                assert score_net_path is not None, 'Error: No score network can be found!'
                self.config.load = score_net_path
            else:
                self.config.load = schedule_net_path
                self.model.schedule_net = get_schedule_network(config).cuda().eval()

        # Perform noise scheduling when noise schedule file (.schedule) is not given
        if self.config.command == 'generate' and self.config.sampling_noise_schedule != '':
            # Generation mode given pre-searched noise schedule
            self.schedule = torch.load(self.config.sampling_noise_schedule, map_location='cpu')

        self.reset()

    def reset(self):
        """
        Reset sampling environment
        """
        if self.config.command != 'train' and self.load != '':
            package = torch.load(self.load, map_location=lambda storage, loc: storage.cuda())
            self.model.load_state_dict(package['model_state_dict'])
            log('Loaded checkpoint %s' % self.load, self.config)

        if self.config.command == 'generate':
            # Given Mel-spectrogram directory for speech vocoding
            self.gen_loader = create_generation_dataloader(self.config)
        else:
            # Sample a reference audio sample for noise scheduling
            _, self.vl_loader = create_train_and_valid_dataloader(self.config)
            self.draw_reference_data_pair()

    def draw_reference_data_pair(self):
        """
        Draw a new input-output pair for noise scheduling
        """
        self.ref_spec, self.ref_audio = next(iter(self.vl_loader))
        self.ref_spec, self.ref_audio = self.ref_spec.cuda(), self.ref_audio.cuda()

    def generate(self):
        """
        Start the generation process
        """
        generate_dir = os.path.join(self.exp_dir, 'generated')
        os.makedirs(generate_dir, exist_ok=True)
        scores = {metric: [] for metric in self.metric2index}
        for filepath, mel_spec, audio in self.gen_loader:
            mel_spec = mel_spec.cuda()
            generated_audio, n_steps = self.sampling(schedule=self.schedule,
                                                     condition=mel_spec,
                                                     ddim=self.config.use_ddim_steps)
            audio_key = '.'.join(filepath[0].split('/')[-1].split('.')[:-1])

            if len(audio) > 0:
                # Assess the generated audio with the reference audio
                self.ref_audio = audio
                score = self.assess(generated_audio, audio_key=audio_key)
                for metric in self.metric2index:
                    scores[metric].append(score[self.metric2index[metric]])

            model_name = 'BDDM' if self.schedule is not None else (
                'DDIM' if self.config.use_ddim_steps else 'DDPM')
            generated_file = os.path.join(generate_dir,
                '%s_by_%s-%d.wav'%(audio_key, model_name, n_steps))
            wavwrite(generated_file, self.config.sampling_rate,
                     generated_audio.squeeze().cpu().numpy())
            log('Generated '+generated_file, self.config)
        log('Avg Scores: PESQ = %.3f +/- %.3f, %.5f +/- %.5f'%(
            np.mean(scores['PESQ']), 1.96 * np.std(scores['PESQ']),
            np.mean(scores['STOI']), 1.96 * np.std(scores['STOI'])), self.config)

    def noise_scheduling(self, ddim=False):
        """
        Start the noise scheduling process

        Parameters:
            ddim (bool): whether to use the DDIM's p_theta for noise scheduling or not
        Returns:
            ts_infer (tensor): the step indices estimated by BDDM
            a_infer (tensor):  the alphas estimated by BDDM
            b_infer (tensor):  the betas estimated by BDDM
            s_infer (tensor):  the std. deviations estimated by BDDM
        """
        max_steps = self.diff_params["N"]
        alpha = self.diff_params["alpha"]
        alpha_param = self.diff_params["alpha_param"]
        beta_param = self.diff_params["beta_param"]
        min_beta = self.diff_params["beta"].min()
        betas = []
        x = torch.normal(0, 1, size=self.ref_audio.shape).cuda()
        with torch.no_grad():
            b_cur = torch.ones(1, 1, 1).cuda() * beta_param
            a_cur = torch.ones(1, 1, 1).cuda() * alpha_param
            for n in range(max_steps - 1, -1, -1):
                step = map_noise_scale_to_time_step(a_cur.squeeze().item(), alpha)
                if step >= 0:
                    betas.append(b_cur.squeeze().item())
                else:
                    break
                ts = (step * torch.ones((1, 1))).cuda()
                e = self.model((x, self.ref_spec.clone(), ts,))
                a_nxt = a_cur / (1 - b_cur).sqrt()
                if ddim:
                    c1 = a_nxt / a_cur
                    c2 = -(1 - a_cur**2.).sqrt() * c1
                    x = c1 * x + c2 * e
                    c3 = (1 - a_nxt**2.).sqrt()
                    x = x + c3 * e
                else:
                    x = x - b_cur / torch.sqrt(1 - a_cur**2.) * e
                    x = x / torch.sqrt(1 - b_cur)
                    if n > 0:
                        z = torch.normal(0, 1, size=self.ref_audio.shape).cuda()
                        x = x + torch.sqrt((1 - a_nxt**2.) / (1 - a_cur**2.) * b_cur) * z
                a_nxt, beta_nxt = a_cur, b_cur
                a_cur = a_nxt / (1 - beta_nxt).sqrt()
                if a_cur > 1:
                    break
                b_cur = self.model.schedule_net(x.squeeze(1),
                    (beta_nxt.view(-1, 1), (1 - a_cur**2.).view(-1, 1)))
                if b_cur.squeeze().item() < min_beta:
                    break
        b_infer = torch.FloatTensor(betas[::-1]).cuda()
        a_infer = 1 - b_infer
        s_infer = b_infer + 0
        for n in range(1, len(b_infer)):
            a_infer[n] *= a_infer[n-1]
            s_infer[n] *= (1 - a_infer[n-1]) / (1 - a_infer[n])
        a_infer = torch.sqrt(a_infer)
        s_infer = torch.sqrt(s_infer)

        # Mapping noise scales to time steps
        ts_infer = []
        for n in range(len(b_infer)):
            step = map_noise_scale_to_time_step(a_infer[n], alpha)
            if step >= 0:
                ts_infer.append(step)
        ts_infer = torch.FloatTensor(ts_infer)
        return ts_infer, a_infer, b_infer, s_infer

    def sampling(self, schedule=None, condition=None,
                 ddim=0, return_sequence=False, audio_size=None):
        """
        Perform the sampling algorithm

        Parameters:
            schedule (list):        the [ts_infer, a_infer, b_infer, s_infer]
                                    returned by the noise scheduling algorithm
            condition (tensor):     the condition for computing scores
            ddim (bool):            whether to use the DDIM for sampling or not
            return_sequence (bool): whether returning all steps' samples or not
            audio_size (list):      the shape of the audio to be sampled
        Returns:
            xs (list):              (if return_sequence) the list of generated audios
            x (tensor):             the generated audio(s) in shape=audio_size
            N (int):                the number of sampling steps
        """

        n_steps = self.diff_params["T"]
        if condition is None:
            condition = self.ref_spec

        if audio_size is None:
            audio_length = condition.size(-1) * self.config.hop_len
            audio_size = (1, 1, audio_length)

        if schedule is None:
            if ddim > 1:
                # Use DDIM (linear) for sampling ({ddim} steps)
                ts_infer = torch.linspace(0, n_steps - 1, ddim).long()
                a_infer = self.diff_params["alpha"].index_select(0, ts_infer.cuda())
                b_infer = self.diff_params["beta"].index_select(0, ts_infer.cuda())
                s_infer = self.diff_params["sigma"].index_select(0, ts_infer.cuda())
            else:
                # Use DDPM for sampling (complete T steps)
                # P.S. if ddim = 1, run DDIM reverse process for T steps
                ts_infer = torch.linspace(0, n_steps - 1, n_steps)
                a_infer = self.diff_params["alpha"]
                b_infer = self.diff_params["beta"]
                s_infer = self.diff_params["sigma"]
        else:
            ts_infer, a_infer, b_infer, s_infer = schedule

        sampling_steps = len(ts_infer)

        x = torch.normal(0, 1, size=audio_size).cuda()
        if return_sequence:
            xs = []
        with torch.no_grad():
            for n in range(sampling_steps - 1, -1, -1):
                if sampling_steps > 50 and (sampling_steps - n) % 50 == 0:
                    # Log progress per 50 steps when sampling_steps is large
                    log('\tComputed %d / %d steps !'%(
                        sampling_steps - n, sampling_steps), self.config)
                ts = (ts_infer[n] * torch.ones((1, 1))).cuda()
                e = self.model((x, condition, ts,))
                if ddim:
                    if n > 0:
                        a_nxt = a_infer[n - 1]
                    else:
                        a_nxt = a_infer[n] / (1 - b_infer[n]).sqrt()
                    c1 = a_nxt / a_infer[n]
                    c2 = -(1 - a_infer[n]**2.).sqrt() * c1
                    c3 = (1 - a_nxt**2.).sqrt()
                    x = c1 * x + (c2 + c3) * e
                else:
                    x = x - b_infer[n] / torch.sqrt(1 - a_infer[n]**2.) * e
                    x = x / torch.sqrt(1 - b_infer[n])
                    if n > 0:
                        z = torch.normal(0, 1, size=audio_size).cuda()
                        x = x + s_infer[n] * z
                if return_sequence:
                    xs.append(x)
        if return_sequence:
            return xs
        return x, sampling_steps

    def noise_scheduling_with_params(self, alpha_param, beta_param):
        """
        Run noise scheduling for once given the (alpha_param, beta_param) pair

        Parameters:
            alpha_param (float): a hyperparameter defining the alpha_hat value at step N
            beta_param (float):  a hyperparameter defining the beta_hat value at step N
        """
        log('TRY alpha_param=%.2f, beta_param=%.2f:'%(
            alpha_param, beta_param), self.config)
        # Define the pair key
        key = '%.2f,%.2f' % (alpha_param, beta_param)
        # Set alpha_param and beta_param in self.diff_params
        self.diff_params['alpha_param'] = alpha_param
        self.diff_params['beta_param'] = beta_param
        # Use DDPM reverse process for noise scheduling
        ddpm_schedule = self.noise_scheduling(ddim=False)
        log("\tSearched a %d-step schedule using DDPM reverse process" % (
            len(ddpm_schedule[0])), self.config)
        generated_audio, _ = self.sampling(schedule=ddpm_schedule)
        # Compute objective scores
        ddpm_score = self.assess(generated_audio)
        # Get the number of sampling steps with this schedule
        steps = len(ddpm_schedule[0])
        # Compare the performance with previous same-step schedule using the metric
        if steps not in self.steps2score:
            # Save the first schedule with this number of steps
            self.steps2score[steps] = [key, ] + ddpm_score
            self.steps2schedule[steps] = ddpm_schedule
            log('\tFound the first %d-step schedule: (PESQ, STOI) = (%.2f, %.3f)'%(
                steps, ddpm_score[0], ddpm_score[1]), self.config)
        elif ddpm_score[0] > self.steps2score[steps][1] and ddpm_score[1] > self.steps2score[steps][2]:
            # Found a better same-step schedule achieving a higher score
            log('\tFound a better %d-step schedule: (PESQ, STOI) = (%.2f, %.3f) -> (%.2f, %.3f)'%(
                steps, self.steps2score[steps][1], self.steps2score[steps][2],
                ddpm_score[0], ddpm_score[1]), self.config)
            self.steps2score[steps] = [key, ] + ddpm_score
            self.steps2schedule[steps] = ddpm_schedule
        # Use DDIM reverse process for noise scheduling
        ddim_schedule = self.noise_scheduling(ddim=True)
        log("\tSearched a %d-step schedule using DDIM reverse process" % (
            len(ddim_schedule[0])), self.config)
        generated_audio, _ = self.sampling(schedule=ddim_schedule)
        # Compute objective scores
        ddim_score = self.assess(generated_audio)
        # Get the number of sampling steps with this schedule
        steps = len(ddim_schedule[0])
        # Compare the performance with previous same-step schedule using the metric
        if steps not in self.steps2score:
            # Save the first schedule with this number of steps
            self.steps2score[steps] = [key, ] + ddim_score
            self.steps2schedule[steps] = ddim_schedule
            log('\tFound the first %d-step schedule: (PESQ, STOI) = (%.2f, %.3f)'%(
                steps, ddim_score[0], ddim_score[1]), self.config)
        elif ddim_score[0] > self.steps2score[steps][1] and ddim_score[1] > self.steps2score[steps][2]:
            # Found a better same-step schedule achieving a higher score
            log('\tFound a better %d-step schedule: (PESQ, STOI) = (%.2f, %.3f) -> (%.2f, %.3f)'%(
                steps, self.steps2score[steps][1], self.steps2score[steps][2],
                ddim_score[0], ddim_score[1]), self.config)
            self.steps2score[steps] = [key, ] + ddim_score
            self.steps2schedule[steps] = ddim_schedule

    def noise_scheduling_without_params(self):
        """
        Search for the best noise scheduling hyperparameters: (alpha_param, beta_param)
        """
        # Noise scheduling mode, given N
        self.reverse_process = 'BDDM'
        assert 'N' in vars(self.config).keys(), 'Error: N is undefined for BDDM!'
        self.diff_params["N"] = self.config.N
        # Init search result dictionaries
        self.steps2schedule, self.steps2score = {}, {}
        search_bins = int(self.config.bddm_search_bins)
        # Define search range of alpha_param
        alpha_last = self.diff_params["alpha"][-1].squeeze().item()
        alpha_first = self.diff_params["alpha"][0].squeeze().item()
        alpha_diff = (alpha_first - alpha_last) / (search_bins + 1)
        alpha_param_list = [alpha_last + alpha_diff * (i + 1) for i in range(search_bins)]
        # Define search range of beta_param
        beta_diff = 1. / (search_bins + 1)
        beta_param_list = [beta_diff * (i + 1) for i in range(search_bins)]
        # Search for beta_param and alpha_param, take O(search_bins^2)
        for beta_param in beta_param_list:
            for alpha_param in alpha_param_list:
                if alpha_param > (1 - beta_param) ** 0.5:
                    # Invalid range
                    continue
                # Update the scores and noise schedules with (alpha_param, beta_param)
                self.noise_scheduling_with_params(alpha_param, beta_param)
        # Lastly, repeat the random starting point (x_hat_N) and choose the best schedule
        noise_schedule_dir = os.path.join(self.exp_dir, 'noise_schedules')
        os.makedirs(noise_schedule_dir, exist_ok=True)
        steps_list = list(self.steps2score.keys())
        for steps in steps_list:
            log("-"*80, self.config)
            log("Select the best out of %d x_hat_N ~ N(0,I) for %d steps:"%(
                self.config.noise_scheduling_attempts, steps), self.config)
            # Get current best pair
            key = self.steps2score[steps][0]
            # Get back the best (alpha_param, beta_param) pair for a given steps
            alpha_param, beta_param = list(map(float, key.split(',')))
            # Repeat K times for a given number of steps
            for _ in range(self.config.noise_scheduling_attempts):
                # Random +/- 5%
                _alpha_param = alpha_param * (0.95 + np.random.rand() * 0.1)
                _beta_param = beta_param * (0.95 + np.random.rand() * 0.1)
                # Update the scores and noise schedules with (alpha_param, beta_param)
                self.noise_scheduling_with_params(_alpha_param, _beta_param)
        # Save the best searched noise schedule ({N}steps_{key}_{metric}{best_score}.ns)
        for steps in sorted(self.steps2score.keys(), reverse=True):
            filepath = os.path.join(noise_schedule_dir, '%dsteps_PESQ%.2f_STOI%.3f.ns'%(
                steps, self.steps2score[steps][1], self.steps2score[steps][2]))
            torch.save(self.steps2schedule[steps], filepath)
            log("Saved searched schedule: %s" % filepath, self.config)

    def assess(self, generated_audio, audio_key=None):
        """
        Assess the generated audio using objective metrics: PESQ and STOI.
            P.S. Users should first install pypesq and pystoi using pip

        Parameters:
            generated_audio (tensor): the generated audio to be assessed
            audio_key (str):          the key of the respective audio
        Returns:
            pesq_score (float):       the PESQ score (the higher the better)
            stoi_score (float):       the STOI score (the higher the better)
        """
        est_audio = generated_audio.squeeze().cpu().numpy()
        ref_audio = self.ref_audio.squeeze().cpu().numpy()
        if est_audio.shape[-1] > ref_audio.shape[-1]:
            est_audio = est_audio[..., :ref_audio.shape[-1]]
        else:
            ref_audio = ref_audio[..., :est_audio.shape[-1]]
        # Compute STOI using PySTOI
        # PySTOI (https://github.com/mpariente/pystoi)
        stoi_score = stoi(ref_audio, est_audio, self.config.sampling_rate, extended=False)
        # Resample audio to 16K Hz to compute PESQ using PyPESQ (supports only 8K / 16K)
        # PyPESQ (https://github.com/vBaiCai/python-pesq)
        if self.config.sampling_rate != 16000:
            est_audio_16k = librosa.resample(est_audio, self.config.sampling_rate, 16000)
            ref_audio_16k = librosa.resample(ref_audio, self.config.sampling_rate, 16000)
            pesq_score = pesq(ref_audio_16k, est_audio_16k, 16000)
        else:
            pesq_score = pesq(ref_audio, est_audio, 16000)
        # Log scores
        log('\t%sScores: PESQ = %.3f, STOI = %.5f'%(
            '' if audio_key is None else audio_key+' ', pesq_score, stoi_score), self.config)
        # Return scores: the higher the better
        return [pesq_score, stoi_score]


================================================
FILE: bddm/trainer/__init__.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Trainer
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


from .trainer import Trainer


================================================
FILE: bddm/trainer/bmuf.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  BMUF Multi-GPU Training Method
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


import os
import sys
import psutil
import torch
import torch.nn as nn
import torch.distributed as dist


class BmufTrainer(object):

    def __init__(self, master_node, rank, world_size, model, block_momentum, block_lr):
        """
        Basic BMUF Trainer Class, implements Nesterov Block Momentum

        Parameters:
            master_node (int):      master node index, zero in most cases
            rank (int):             local rank, eg, 0-7 if 8GPUs are used
            world_size (int):       total number of workers
            model (nn.Module):      PyTorch model
            block_momentum (float): block momentum value
            block_lr (float):       block learning rate
        """
        assert torch.cuda.is_available(), "Distributed mode requires CUDA."
        self.master_node = 0  # By default, use device 0 as master node
        self.model = model
        self.rank = rank
        self.world_size = world_size
        self.block_momentum = block_momentum
        self.block_lr = block_lr
        dist.init_process_group(backend="nccl", init_method="env://")
        param_vec = nn.utils.parameters_to_vector(model.parameters())
        self.param = param_vec.data.clone()
        dist.broadcast(tensor=self.param, src=self.master_node, async_op=False)
        num_param = self.param.numel()
        self.delta_prev = torch.FloatTensor([0]*num_param).to(self.param.device)
        if self.rank == self.master_node:
            self.delta_prev = torch.FloatTensor([0]*num_param).cuda(self.master_node)
        else:
            self.delta_prev = None
            self._copy_vec_to_param(self.param)
        # for fining the best model among different ranks
        self.min_val_loss = torch.FloatTensor([float('inf')]).to(self.param.device)

    def check_all_processes_running(self):
        """
        Check whether all processes are healthy
        """
        pid = os.getpid()
        for p in range(pid-self.rank, pid-self.rank+self.world_size):
            if not psutil.pid_exists(p):
                sys.exit(-1)

    def get_average_valid_loss(self, val_loss):
        """
        Get average validatin loss  through all reduce operations

        Parameters:
            val_loss (Tensor): calculated validation loss in each process
        Returns:
            save_or_not (bool): whether save the current validated model or not
        """
        save_or_not = False
        dist.all_reduce(tensor=val_loss)
        val_loss = val_loss / float(self.world_size)
        # save the min valid loss among all gpus
        if 'min_val_loss' not in self.__dict__ or val_loss < self.min_val_loss:
            self.min_val_loss = val_loss.clone()
            save_or_not = True
        return save_or_not

    def update_and_sync(self, val_loss=None):
        """
        Update and synchronize block gradients

        Parameters:
            val_loss (tensor): calculated validation loss in each process
        Returns:
            save_or_not (bool): whether save the current validated model or not
        """
        self.check_all_processes_running()
        save_or_not = False
        if val_loss is not None:
            save_or_not = self.get_average_valid_loss(val_loss)

        cur_param_vec = nn.utils.parameters_to_vector(self.model.parameters()).data
        delta = self.param - cur_param_vec
        # Gather block gradients into delta
        dist.reduce(tensor=delta, dst=self.master_node)

        # Check if model params are still healthy
        if torch.isnan(delta).sum().item():
            print('Found nan, exit!', flush=True)
            sys.exit(-1)
        if self.rank == self.master_node:
            # Local rank is master node
            delta = delta / float(self.world_size)
            self.delta_prev = self.block_momentum * self.delta_prev + \
                              (self.block_lr * (1 - self.block_momentum) * delta)
            self.param -= (1+self.block_momentum) * self.delta_prev
        dist.broadcast(tensor=self.param, src=self.master_node, async_op=False)
        self._copy_vec_to_param(self.param)
        return save_or_not

    def _copy_vec_to_param(self, vec):
        """
        Copy a vectorized array to the model parameters

        Parameters:
            vec (tensor): a single vector represents the parameters of a model.
        """
        # Ensure vec of type Tensor
        if not isinstance(vec, torch.Tensor):
            raise TypeError('expected torch.Tensor, but got: {}'
                            .format(torch.typename(vec)))
        # Pointer for slicing the vector for each parameter
        pointer = 0
        for param in self.model.parameters():
            # The length of the parameter
            num_param = param.numel()
            # Slice the vector, reshape it, and replace the old data of the parameter
            param.data = param.data.copy_(vec[pointer:pointer + num_param]
                                          .view_as(param).data)
            # Increment the pointer
            pointer += num_param


================================================
FILE: bddm/trainer/ema.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  EMA Helper Class
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


import torch.nn as nn


class EMAHelper(object):

    def __init__(self, mu=0.999):
        """
        Exponential Moving Average Training Helper Class

        Parameters:
            mu (float): decaying rate
        """
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        """
        Register module by copying all learnable parameters to self.shadow

        Parameters:
            module (nn.Module): model to be trained
        """
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        """
        Update self.shadow using the module parameters

        Parameters:
            module (nn.Module): model in training
        """
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (
                    1. - self.mu) * param.data + self.mu * self.shadow[name].data

    def ema(self, module):
        """
        Copy self.shadow to the module parameters

        Parameters:
            module (nn.Module): model in training
        """
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        """
        Initialize a new module using self.shadow as the parameters

        Parameters:
            module (nn.Module): model in training
        """
        if isinstance(module, nn.DataParallel):
            inner_module = module.module
            module_copy = type(inner_module)(
                inner_module.config).to(inner_module.config.device)
            module_copy.load_state_dict(inner_module.state_dict())
            module_copy = nn.DataParallel(module_copy)
        else:
            module_copy = type(module)(module.config).to(module.config.device)
            module_copy.load_state_dict(module.state_dict())
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        """
        Get self.shadow as the state dict

        Returns:
            shadow (dict): state dict
        """
        return self.shadow

    def load_state_dict(self, state_dict):
        """
        Load a state dict to self.shadow

        Parameters:
            state dict (dict): state dict to be copied to self.shadow
        """
        self.shadow = state_dict


================================================
FILE: bddm/trainer/loss.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Implements Training Losses for BDDMs
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


import torch
import torch.nn as nn


class ScoreLoss(nn.Module):

    def __init__(self, config, diff_params):
        """
        Score Loss Class, implements DDPM's simplified loss (see Eq. 5 in BDDM's paper)

        Parameters:
            config (namespace): BDDM Configuration
            diff_params (dict): Dictionary that stores pre-computed diffusion parameters
        """
        super().__init__()
        self.config = config
        self.num_steps = diff_params["T"]
        self.alpha = diff_params["alpha"]

    def forward(self, model, mels, audios):
        """
        Compute the training loss for learning theta

        Parameters:
            model (nn.Module):   the score network
            mels (tensor):       shape=(batch size, frames, spectrogram dim)
            audios (tensor):     shape=(batch size, 1, length of audio)
        Returns:
            score loss (tensor): shape=(batch size,)
        """
        batch_size = audios.size(0)
        ts = torch.randint(low=0, high=self.num_steps, size=(batch_size, 1, 1)).cuda()
        noise_scales = self.alpha[ts]
        z = torch.normal(0, 1, size=audios.shape).cuda()
        noisy_audios = noise_scales * audios + (1 - noise_scales**2.).sqrt() * z
        e = model((noisy_audios, mels, ts.view(batch_size, 1),))
        # Use WaveGrad's L1Loss for speech generation
        theta_loss = nn.L1Loss(reduction='none')(e, z[:, :, :e.size(-1)]).mean([1, 2])
        return theta_loss


class StepLoss(nn.Module):

    def __init__(self, config, diff_params):
        """
        Step Loss Class, implements BDDM's step loss (see Eq. 14 in BDDM's paper)

        Parameters:
            config (namespace): BDDM Configuration
            diff_params (dict): Dictionary that stores pre-computed diffusion parameters
        """
        super().__init__()
        self.config = config
        self.num_steps = diff_params["T"]
        self.alpha = diff_params["alpha"]
        self.tau = diff_params["tau"]

    def forward(self, model, mels, audios):
        """
        Compute the training loss for learning phi

        Parameters:
            model (nn.Module):  the score network & the schedule network
            mels (tensor):      shape=(batch size, frames, spectrogram dim)
            audios (tensor):    shape=(batch size, 1, length of audio)
        Returns:
            step loss (tensor): shape=(batch size,)
        """
        batch_size = audios.size(0)
        ts = torch.randint(self.tau, self.num_steps-self.tau, size=(batch_size,)).cuda()
        alpha_cur = self.alpha.index_select(0, ts).view(batch_size, 1, 1)
        alpha_nxt = self.alpha.index_select(0, ts+self.tau).view(batch_size, 1, 1)
        b_nxt = 1 - (alpha_nxt / alpha_cur)**2.
        delta = (1 - alpha_cur**2.).sqrt()
        z = torch.normal(0, 1, size=audios.shape).cuda()
        noisy_audios = alpha_cur * audios + delta * z
        e = model((noisy_audios, mels, ts.view(batch_size, 1),))
        beta_bounds = (b_nxt.view(batch_size, 1), delta.view(batch_size, 1)**2.)
        b_hat = model.schedule_net(noisy_audios.squeeze(1), beta_bounds)
        delta, b_hat, z, e = delta.squeeze(1), b_hat.squeeze(1), z.squeeze(1), e.squeeze(1)
        phi_loss = delta**2. / (2. * (delta**2. - b_hat))
        phi_loss = phi_loss * (z - b_hat / (delta**2.) * e).square()
        phi_loss = phi_loss + torch.log(1e-8 + delta**2. / (b_hat + 1e-8)) / 4.
        phi_loss = phi_loss.sum(-1) + (b_hat / delta**2 - 1) / 2. * audios.size(-1)
        return phi_loss


================================================
FILE: bddm/trainer/trainer.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  BDDM Trainer (Support Multi-GPU Training using BMUF method)
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


from __future__ import absolute_import

import os
import time
import copy
import torch

from bddm.sampler import Sampler
from bddm.trainer.ema import EMAHelper
from bddm.trainer.bmuf import BmufTrainer
from bddm.trainer.loss import ScoreLoss, StepLoss
from bddm.utils.log_utils import log
from bddm.utils.check_utils import check_score_network
from bddm.utils.diffusion_utils import compute_diffusion_params
from bddm.models import get_score_network, get_schedule_network
from bddm.loader.dataset import create_train_and_valid_dataloader


class Trainer(object):

    def __init__(self, config):
        """
        Trainer Class, implements a general multi-GPU training framework in PyTorch

        Parameters:
            config (namespace): BDDM Configuration
        """
        self.config = config
        self.exp_dir = config.exp_dir
        self.clip = config.grad_clip
        self.load = config.load
        self.model = get_score_network(config).cuda()
        # Define training target
        if self.config.resume_training and self.load != '':
            self.training_target = 'score_nets'
        else:
            score_net_trained, score_net_path = check_score_network(config)
            self.training_target = 'schedule_nets' if score_net_trained else 'score_nets'
        torch.autograd.set_detect_anomaly(True)
        # Initialize diffusion parameters using a pre-specified linear schedule
        noise_schedule = torch.linspace(config.beta_0, config.beta_T, config.T).cuda()
        self.diff_params = compute_diffusion_params(noise_schedule)
        if self.training_target == 'schedule_nets':
            if self.load == '':
                self.load = score_net_path
            self.diff_params["tau"] = config.tau
            for p in self.model.parameters():
                p.requires_grad = False
            # Define the schedule net as a sub-module of the score net for convenience
            self.model.schedule_net = get_schedule_network(config).cuda()
            self.loss_func = StepLoss(config, self.diff_params)
            self.n_training_steps = config.schedule_net_training_steps
            # In practice using batch size = 1 would lead to much lower step loss
            config.batch_size = 1
            model_to_train = self.model.schedule_net
        else:
            self.loss_func = ScoreLoss(config, self.diff_params)
            self.n_training_steps = config.score_net_training_steps
            model_to_train = self.model
        # Define optimizer
        self.optimizer = torch.optim.AdamW(model_to_train.parameters(),
            lr=config.lr, weight_decay=config.weight_decay, amsgrad=True)
        # Define EMA training helper
        self.ema_helper = EMAHelper(mu=config.ema_rate)
        self.ema_helper.register(model_to_train)
        # Initialize BMUF trainer
        self.world_size = int(os.environ['WORLD_SIZE'])
        self.bmuf_trainer = BmufTrainer(0, config.local_rank, self.world_size,
            model_to_train, config.bmuf_block_momentum, config.bmuf_block_lr)
        self.sync_period = config.bmuf_sync_period
        self.device = torch.device("cuda:{}".format(config.local_rank))
        self.local_rank = config.local_rank
        # Get data loaders
        self.tr_loader, self.vl_loader = create_train_and_valid_dataloader(config)
        if self.training_target == 'score_nets':
            # Define a Sampler for quality validation (should be added after BMUF)
            self.valid_sampler = Sampler(config)
            self.valid_sampler.model = get_score_network(config).cuda()
        self.reset()

    def reset(self):
        """
        Reset training environment
        """
        self.tr_loss, self.vl_loss = [], []
        self.training_step = 0
        if self.load != '':
            package = torch.load(self.load, map_location=lambda storage, loc: storage.cuda())
            init_state_dict = self.model.state_dict()
            mismatch_params = set()
            # Remove the checkpoint params that are not found in model
            for key in list(package['model_state_dict'].keys()):
                if key not in init_state_dict.keys():
                    param = copy.deepcopy(package['model_state_dict'][key])
                    del package['model_state_dict'][key]
                    log('ignored: %s in checkpoint not found in model'%key, self.config)
                elif package['model_state_dict'][key].size() != init_state_dict[key].size():
                    log(package['model_state_dict'][key].size(), self.config)
                    log(init_state_dict[key].size(), self.config)
                    log('ignored: %s in checkpoint size mismatched'%key, self.config)
                    del package['model_state_dict'][key]
            # Replace the ignored checkpoint params by the init params
            for key in list(init_state_dict.keys()):
                if key not in package['model_state_dict'].keys():
                    mismatch_params.add(key)
                    log('ignored: %s in model not found in checkpoint'%key, self.config)
                    package['model_state_dict'][key] = init_state_dict[key]
            self.model.load_state_dict(package['model_state_dict'])
            if self.config.resume_training and len(mismatch_params) == 0:
                # Load steps to resume training
                if self.training_target == 'score_nets' and 'score_net_training_step' in package:
                    self.training_step = package['score_net_training_step']
                elif self.training_target == 'schedule_nets' and 'schedule_net_training_step' in package:
                    self.training_step = package['schedule_net_training_step']
            if self.config.freeze_checkpoint_params and len(mismatch_params) > 0:
                # Only update new parameters defined in model
                for key, param in self.model.named_parameters():
                    if key not in mismatch_params:
                        param.requires_grad = False
            log('Loaded checkpoint %s' % self.load, self.config)
        # Create save folder
        os.makedirs(self.exp_dir, exist_ok=True)
        self.prev_val_loss, self.min_val_loss = float("inf"), float("inf")
        self.val_no_impv, self.halving = 0, 0

    def train(self):
        """
        Start the main training process
        """
        best_state = copy.deepcopy(self.model.state_dict())
        while self.training_step < self.n_training_steps:
            # Train one epoch
            log("Start training %s from step %d ......."%(
                self.training_target, self.training_step), self.config)
            self.bmuf_trainer.check_all_processes_running()
            self.model.train()
            start = time.time()
            tr_avg_loss = self._run_one_epoch(validate=False)
            log('-' * 85, self.config)
            log('Train Summary | Step {} | Time {:.2f}s | Train Loss {:.5f}'.format(
                self.training_step, time.time()-start, tr_avg_loss), self.config)
            log('-' * 85, self.config)
            # Start validation
            log('Start validation ......', self.config)
            self.model.eval()
            start = time.time()
            with torch.no_grad():
                val_loss = self._run_one_epoch(validate=True)
            log('-' * 85, self.config)
            log('Valid Summary | Step {} | Time {:.2f}s | Valid Loss {:.5f}'.format(
                self.training_step, time.time()-start, val_loss.item()), self.config)
            log('-' * 85, self.config)
            save_or_not = self.bmuf_trainer.update_and_sync(val_loss=val_loss)
            if save_or_not:
                self.val_no_impv = 0
                if self.bmuf_trainer.rank == self.bmuf_trainer.master_node:
                    model_serialized = self.serialize()
                    file_path = os.path.join(self.exp_dir, self.training_target,
                                             '%d.pkl' % self.training_step)
                    torch.save(model_serialized, file_path)
                    log("Found better model, saved to %s" % file_path, self.config)
            if val_loss >= self.min_val_loss:
                # LR decays
                self.val_no_impv += 1
                if self.val_no_impv == self.config.patience:
                    log("No imporvement for %d epochs, early stopped!"%(
                        self.config.patience), self.config)
                    self.bmuf_trainer.kill_all_processes()
                    break
                if self.val_no_impv >= self.config.patience // 2:
                    self.model.load_state_dict(best_state)
            else:
                self.val_no_impv = 0
                self.min_val_loss = val_loss
                best_state = copy.deepcopy(self.ema_helper.state_dict())

    def _run_one_epoch(self, validate=False):
        """
        Run one epoch

        Parameters:
            validate (bool):      whether to run a valiation epoch or a training epoch
        Returns:
            average loss (float): the average training/validation loss
        """
        start = time.time()
        total_loss, total_cnt = 0, 0
        if validate and self.training_target == 'score_nets':
            # Use EMA state dict for validation
            self.valid_sampler.model.load_state_dict(self.ema_helper.state_dict())
            # To validate score nets, we use Sampler to test sample quality instead of loss
            generated_audio, _ = self.valid_sampler.sampling()
            # Compute objective scores (score = PESQ)
            quality_score = self.valid_sampler.assess(generated_audio)[0]
            return - torch.FloatTensor([quality_score])[0].cuda()
        data_loader = self.vl_loader if validate else self.tr_loader
        data_loader.dataset.reset()
        start_step = self.training_step
        for i, batch in enumerate(data_loader):
            mels, audios = list(map(lambda x: x.cuda(), batch))
            loss = self.loss_func(self.model, mels, audios)
            total_loss += loss.detach().sum()
            total_cnt += len(loss)
            avg_loss = loss.mean()
            if not validate:
                self.optimizer.zero_grad()
                avg_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
                self.optimizer.step()
                # Apply block momentum and sync parameters
                self.bmuf_trainer.update_and_sync()
                self.bmuf_trainer.check_all_processes_running()
                # n_gpus * batch_size
                self.training_step += len(loss) * self.bmuf_trainer.world_size
                if i % self.config.log_period == 0:
                    log('Train Step {} | Avg. Loss {:.5f} | New Loss {:.5f} | {:.2f}s/step'.format(
                        self.training_step, total_loss / total_cnt, avg_loss,
                        (time.time() - start) / (self.training_step - start_step)), self.config)
                if self.training_target == 'schedule_nets':
                    self.ema_helper.update(self.model.schedule_net)
                else:
                    self.ema_helper.update(self.model)
                if self.training_step >= self.n_training_steps or\
                        max(i, self.training_step - start_step) >= self.config.steps_per_epoch:
                    # Release grad memory
                    self.optimizer.zero_grad()
                    return total_loss / total_cnt
            else:
                if i % self.config.log_period == 0:
                    log('Valid Step {} | Avg. Loss {:.5f} | New Loss {:.5f} | {:.2f}s/step'.format(
                        i + 1, total_loss/total_cnt, avg_loss,
                        (time.time() - start) / (i + 1)), self.config)
        if not validate:
            # Release grad memory
            self.optimizer.zero_grad()
        torch.cuda.empty_cache()
        return total_loss / total_cnt

    def serialize(self):
        """
        Pack the model and configurations into a dictionary

        Returns:
            package (dict): the serialized package to be saved
        """
        if self.training_target == 'schedule_nets':
            model_state = copy.deepcopy(self.model.state_dict())
            ema_state = copy.deepcopy(self.ema_helper.state_dict())
            for p in self.ema_helper.state_dict():
                model_state['schedule_net.'+p] =  ema_state[p]
        else:
            model_state = copy.deepcopy(self.ema_helper.state_dict())
        if self.config.save_fp16:
            for p in model_state:
                model_state[p] = model_state[p].half()
        package = {
            # hyper-parameter
            'config': self.config,
            # state
            'model_state_dict': model_state
        }
        if self.training_target == 'score_nets':
            package['score_net_training_step'] = self.training_step
            package['schedule_net_training_step'] = 0
        else:
            package['score_net_training_step'] = self.config.score_net_training_steps
            package['schedule_net_training_step'] = self.training_step
        return package


================================================
FILE: bddm/utils/check_utils.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Check Utils: Find Checkpoints
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


import os
import glob
import torch


def check_score_network(config):
    """
    Check if the score network is trained by searching the ${exp_dir}/score_nets

    Parameters:
        config (namespace): the configuration given by the user
    Returns:
        result (bool):      a boolean to determine if the score network is trained
        path (str):         the path to the score network checkpoint if trained
    """
    if config.load != '':
        ckpt = torch.load(config.load)
        if 'score_net_training_step' in ckpt.keys():
            if ckpt['score_net_training_step'] >= config.score_net_training_steps:
                return True, config.load
        else:
            # We suppose that an external checkpoint is already well-trained
            return True, config.load
    max_training_steps = 0
    for path in glob.glob(os.path.join(config.exp_dir, 'score_nets', '[0-9]*.pkl')):
        n_training_steps = int(path.split('/')[-1].split('.')[0])
        max_training_steps = max(max_training_steps, n_training_steps)
    if max_training_steps == 0:
        return False, None
    load_path = os.path.join(config.exp_dir, 'score_nets', '%d.pkl'%(max_training_steps))
    return True, load_path


def check_schedule_network(config):
    """
    Check if the schedule network is trained by searching the ${exp_dir}/schedule_nets

    Parameters:
        config (namespace): the configuration given by the user
    Returns:
        result (bool):      a boolean to determine if the schedule network is trained
        path (str):         the path to the BDDM checkpoint if trained
    """
    if config.load != '':
        ckpt = torch.load(config.load)
        if 'schedule_net_training_step' in ckpt.keys():
            return True, config.load
        else:
            return False, config.load
    max_training_steps = 0
    for path in glob.glob(os.path.join(config.exp_dir, 'schedule_nets', '[0-9]*.pkl')):
        n_training_steps = int(path.split('/')[-1].split('.')[0])
        max_training_steps = max(max_training_steps, n_training_steps)
    if max_training_steps == 0:
        # We suppose that an external checkpoint only contains a well-trained score network
        return False, config.load
    schedule_net_load_path = os.path.join(
        config.exp_dir, 'schedule_nets', '%d.pkl'%(max_training_steps))
    return True, schedule_net_load_path


================================================
FILE: bddm/utils/diffusion_utils.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Diffusion Utils: Pre-compute Variables
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################


import torch


def compute_diffusion_params(beta):
    """
    Compute the diffusion parameters defined in BDDMs

    Parameters:
        beta (tensor):      the beta schedule
    Returns:
        diff_params (dict): a dictionary of diffusion hyperparameters including:
            T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, ))
            These cpu tensors are changed to cuda tensors on each individual gpu
    """

    alpha = 1 - beta
    sigma = beta + 0
    for t in range(1, len(beta)):
        alpha[t] *= alpha[t-1]
        sigma[t] *= (1-alpha[t-1]) / (1-alpha[t])
    alpha = torch.sqrt(alpha)
    sigma = torch.sqrt(sigma)
    diff_params = {"T": len(beta), "beta": beta, "alpha": alpha, "sigma": sigma}
    return diff_params


def map_noise_scale_to_time_step(alpha_infer, alpha):
    """
    Map an alpha_infer to an approx. time step in the alpha tensor.
        (Modified from Algorithm 3 Fast Sampling in DiffWave)

    Parameters:
        alpha_infer (float): noise scale at time `n` for inference
        alpha (tensor):      noise scales used in training, shape=(T, )
    Returns:
        t (float):           approximated time step in alpha tensor
    """

    if alpha_infer < alpha[-1]:
        return len(alpha) - 1
    if alpha_infer > alpha[0]:
        return 0
    for t in range(len(alpha) - 1):
        if alpha[t+1] <= alpha_infer <= alpha[t]:
            step_diff = alpha[t] - alpha_infer
            step_diff /= alpha[t] - alpha[t+1]
            return t + step_diff.item()
    return -1


================================================
FILE: bddm/utils/log_utils.py
================================================
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
#  Log Utils: Print Log with Time/PID
#
#  Author: Max W. Y. Lam (maxwylam@tencent.com)
#  Copyright (c) 2021Tencent. All Rights Reserved
#
########################################################################

import os
import sys
import time


def ctime():
    """
    Get time now

    Returns:
        time_string (str): current time in string
    """
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))


def head():
    """
    Get header for logging: time and PID of the current process

    Returns:
        header_string (str): the header string for logging
    """
    return "%s %d" % (ctime(), os.getpid())


def log(msg, config):
    """
    Save log to the device[id].log & print device0.log to STDOUT

    Parameters:
        msg (sting):        message to be logged
        config (namespace): BDDM Configuration
    """
    if config.local_rank == 0:
        sys.stdout.write("[%s] %s\n" % (head(), msg))
        sys.stdout.flush()
    os.makedirs(config.exp_dir, exist_ok=True)
    with open(os.path.join(config.exp_dir, 'device%d.log'%(config.local_rank)), 'a') as f:
        f.write("[%s] %s\n" % (head(), msg))
        f.flush()


================================================
FILE: egs/lj/DiffWave.pkl
================================================
[File too large to display: 26.6 MB]

================================================
FILE: egs/lj/DiffWave_GALR.pkl
================================================
[File too large to display: 15.4 MB]

================================================
FILE: egs/lj/conf.yml
================================================
# General config
exp_dir: './exp'
seed: 0
load: 'DiffWave_GALR.pkl'
# Generation config
sampling_noise_schedule: './noise_schedules/12steps_PESQ4.07_STOI0.978.ns'
gen_data_dir: "./demo_case"
#gen_data_dir: "./test_wavs"
use_ddim_steps: 0 # options: 0-DDPM, 1-DDIM, M-DDIM M steps (when noise schedule is not given)
# Diffusion config
T: 200
beta_0: 0.0001
beta_T: 0.02
tau: 66
# Noise scheduling config
N: 12
bddm_search_bins: 9
noise_scheduling_attempts: 10
# Score network config
score_net: 'DiffWave'
num_res_layers: 30
in_channels: 1
res_channels: 128
skip_channels: 128
dilation_cycle: 10
out_channels: 1
diffusion_step_embed_dim_in: 128
diffusion_step_embed_dim_mid: 512
diffusion_step_embed_dim_out: 512
# Schedule network config
schedule_net: 'GALR'
blocks: 1
hidden_dim: 128
input_dim: 128
window_length: 8
segment_size: 64
# Trainer config
resume_training: False
save_fp16: True
steps_per_epoch: 1000
score_net_training_steps: 5000000
schedule_net_training_steps: 10000
lr: 0.00001
freeze_checkpoint_params: False # only used when provided load path
save_generated_dir: 'gen_wav'
batch_size: 1
patience: 100
grad_clip: 5.0
weight_decay: 0.00001
log_period: 1
bmuf_block_momentum: 0.5
bmuf_block_lr: 0.7
bmuf_sync_period: 2
ema_rate: 0.999
# Data config
train_data_dir: "LJSpeech-1.1/train_wavs"
valid_data_dir: "LJSpeech-1.1/val_wavs"
n_worker: 10
sampling_rate: 22050
seg_len: 16000
fil_len: 1024
hop_len: 256
win_len: 1024
mel_fmin: 0.0
mel_fmax: 8000.0


================================================
FILE: egs/lj/eval.list
================================================
LJ001-0001.wav
LJ001-0047.wav
LJ003-0140.wav
LJ003-0145.wav
LJ003-0235.wav
LJ003-0240.wav
LJ004-0195.wav
LJ005-0059.wav
LJ005-0206.wav
LJ005-0269.wav
LJ006-0141.wav
LJ006-0301.wav
LJ007-0104.wav
LJ007-0106.wav
LJ007-0192.wav
LJ007-0208.wav
LJ008-0043.wav
LJ008-0265.wav
LJ009-0071.wav
LJ009-0190.wav
LJ010-0056.wav
LJ010-0082.wav
LJ011-0053.wav
LJ011-0062.wav
LJ011-0221.wav
LJ012-0005.wav
LJ012-0015.wav
LJ013-0004.wav
LJ013-0060.wav
LJ015-0063.wav
LJ015-0254.wav
LJ016-0070.wav
LJ016-0318.wav
LJ017-0049.wav
LJ017-0251.wav
LJ018-0029.wav
LJ018-0071.wav
LJ018-0105.wav
LJ019-0329.wav
LJ019-0350.wav
LJ019-0358.wav
LJ020-0086.wav
LJ020-0102.wav
LJ022-0164.wav
LJ022-0185.wav
LJ023-0117.wav
LJ024-0103.wav
LJ025-0143.wav
LJ026-0030.wav
LJ027-0020.wav
LJ028-0130.wav
LJ028-0155.wav
LJ028-0302.wav
LJ028-0432.wav
LJ028-0465.wav
LJ028-0497.wav
LJ028-0499.wav
LJ029-0018.wav
LJ030-0088.wav
LJ031-0072.wav
LJ032-0010.wav
LJ032-0057.wav
LJ035-0061.wav
LJ035-0124.wav
LJ035-0200.wav
LJ036-0093.wav
LJ037-0061.wav
LJ037-0186.wav
LJ037-0195.wav
LJ037-0237.wav
LJ038-0008.wav
LJ038-0051.wav
LJ038-0092.wav
LJ039-0072.wav
LJ040-0210.wav
LJ040-0216.wav
LJ041-0075.wav
LJ042-0092.wav
LJ042-0159.wav
LJ042-0181.wav
LJ042-0224.wav
LJ042-0248.wav
LJ043-0026.wav
LJ043-0150.wav
LJ045-0123.wav
LJ045-0128.wav
LJ045-0147.wav
LJ045-0216.wav
LJ046-0026.wav
LJ046-0043.wav
LJ046-0162.wav
LJ046-0186.wav
LJ047-0058.wav
LJ047-0114.wav
LJ047-0131.wav
LJ048-0022.wav
LJ049-0155.wav
LJ049-0165.wav
LJ050-0043.wav
LJ050-0134.wav


================================================
FILE: egs/lj/generate.sh
================================================
#!/bin/bash

## example usage (single GPU): sh generate.sh 0 conf.yml

cuda=$1

CUDA_VISIBLE_DEVICES=$cuda python3 main.py \
	--command generate --config $2


================================================
FILE: egs/lj/main.py
================================================
import os
import sys
sys.path.append('../../')
import json
import shutil
import hashlib
import argparse
import numpy as np
import yaml

import torch
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

from bddm import trainer, sampler
from bddm.utils.log_utils import log


def dict_hash_5char(dictionary):
    ''' Map a unique dictionary into a 5-character string '''
    dhash = hashlib.md5()
    encoded = json.dumps(dictionary, sort_keys=True).encode()
    dhash.update(encoded)
    return dhash.hexdigest()[:5]


def start_exp(config, config_hash):
    ''' Create experiment directory or set it to an existing directory '''
    if config.load != '' and '_nets' in config.load:
        config.exp_dir = '/'.join(config.load.split('/')[:-2])
    else:
        config.exp_dir += '/%s-%s_conf-hash-%s' % (
            config.score_net, config.schedule_net, config_hash)
    if config.local_rank != 0:
        return
    log('Experiment directory: %s' % (config.exp_dir), config)
    # Backup the config file
    shutil.copyfile(config.config, os.path.join(config.exp_dir, 'conf.yml'))
    # Create a backup scripts sub-folder
    os.makedirs(os.path.join(config.exp_dir, 'backup_scripts'), exist_ok=True)
    # Backup all .py files under bddm/
    backup_files = []
    for root, _, files in os.walk("../../"):
        if 'egs' in root:
            continue
        for f in files:
            if f.endswith(".py"):
                backup_files.append(os.path.join(root, f))
    for src_file in backup_files:
        basename = src_file
        while '../' in basename:
            basename = basename.replace('../', '')
        basename = basename.replace('./', '')
        dst_file = os.path.join(config.exp_dir, 'backup_scripts', basename)
        dst_dir = '/'.join(dst_file.split('/')[:-1])
        if not os.path.exists(dst_dir):
            os.makedirs(dst_dir)
        shutil.copyfile(src_file, dst_file)
    # Prepare sub-folders for saving model checkpoints
    os.makedirs(os.path.join(config.exp_dir, 'score_nets'), exist_ok=True)
    os.makedirs(os.path.join(config.exp_dir, 'schedule_nets'), exist_ok=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Bilateral Denoising Diffusion Models')
    parser.add_argument('--command',
                        type=str,
                        default='train',
                        help='available commands: train | search | generate')
    parser.add_argument('--config',
                        '-c',
                        type=str,
                        default='conf.yml',
                        help='config .yml path')
    parser.add_argument('--local_rank',
                        type=int,
                        default=0,
                        help='process device ID for multi-GPU training')

    arg_config = parser.parse_args()

    # Parse yaml and define configurations
    config = arg_config.__dict__
    with open(arg_config.config) as f:
        yaml_config = yaml.safe_load(f)
    HASH = dict_hash_5char(yaml_config)
    for key in yaml_config:
        config[key] = yaml_config[key]
    config = argparse.Namespace(**config)

    # Set random seed for reproducible results
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.cuda.set_device(config.local_rank)

    # Check if the command is valid or not
    commands = ['train', 'schedule', 'generate']
    assert config.command in commands, 'Error: %s command not found.'%(config.command)

    # Create/retrieve exp dir
    start_exp(config, HASH)
    log('Argv: %s' % (' '.join(sys.argv)), config)

    try:
        if config.command == 'train':
            # Create Trainer for training
            trainer = trainer.Trainer(config)
            trainer.train()
        elif config.command == 'schedule':
            # Create Sampler for noise scheduling
            sampler = sampler.Sampler(config)
            sampler.noise_scheduling_without_params()
        elif config.command == 'generate':
            # Create Sampler for generation
            # NOTE: Remember to define "gen_data_dir" in conf.yml before data generation
            sampler = sampler.Sampler(config)
            sampler.generate()
        log('-' * 80, config)

    except KeyboardInterrupt:
        log('-' * 80, config)
        log('Exiting early', config)


================================================
FILE: egs/lj/schedule.sh
================================================
#!/bin/bash

## example usage (single GPU): sh schedule.sh 0 conf.yml

cuda=$1

CUDA_VISIBLE_DEVICES=$cuda python3 main.py \
	--command schedule --config $2


================================================
FILE: egs/lj/train.sh
================================================
#!/bin/bash

## example usage: sh train.sh 0,1,2,3 conf.yml

cuda=$1
comma=${cuda//[^,]}
nproc=$((${#comma}+1))

CUDA_VISIBLE_DEVICES=$cuda python3 -m torch.distributed.launch \
	--nproc_per_node $nproc --master_port $RANDOM main.py \
	--command train --config $2


================================================
FILE: egs/lj/valid.list
================================================
LJ001-0023.wav
LJ001-0084.wav
LJ001-0144.wav
LJ001-0163.wav
LJ001-0179.wav
LJ002-0042.wav
LJ002-0140.wav
LJ002-0201.wav
LJ002-0245.wav
LJ003-0004.wav
LJ003-0009.wav
LJ003-0143.wav
LJ003-0227.wav
LJ003-0305.wav
LJ003-0316.wav
LJ003-0336.wav
LJ004-0059.wav
LJ004-0098.wav
LJ004-0130.wav
LJ004-0140.wav
LJ004-0173.wav
LJ004-0199.wav
LJ005-0071.wav
LJ005-0108.wav
LJ005-0226.wav
LJ006-0053.wav
LJ006-0055.wav
LJ006-0081.wav
LJ006-0099.wav
LJ006-0103.wav
LJ006-0107.wav
LJ006-0152.wav
LJ007-0169.wav
LJ007-0232.wav
LJ008-0007.wav
LJ008-0053.wav
LJ008-0110.wav
LJ008-0200.wav
LJ008-0263.wav
LJ008-0280.wav
LJ008-0281.wav
LJ009-0068.wav
LJ009-0158.wav
LJ009-0179.wav
LJ009-0184.wav
LJ009-0185.wav
LJ009-0234.wav
LJ009-0235.wav
LJ009-0304.wav
LJ010-0011.wav
LJ010-0057.wav
LJ010-0106.wav
LJ010-0191.wav
LJ010-0282.wav
LJ010-0314.wav
LJ011-0064.wav
LJ011-0076.wav
LJ011-0130.wav
LJ011-0144.wav
LJ011-0249.wav
LJ011-0275.wav
LJ012-0046.wav
LJ012-0061.wav
LJ012-0100.wav
LJ012-0133.wav
LJ012-0145.wav
LJ012-0173.wav
LJ012-0225.wav
LJ012-0288.wav
LJ012-0290.wav
LJ013-0080.wav
LJ013-0096.wav
LJ013-0103.wav
LJ013-0137.wav
LJ013-0170.wav
LJ013-0247.wav
LJ013-0258.wav
LJ014-0063.wav
LJ014-0148.wav
LJ014-0202.wav
LJ014-0205.wav
LJ014-0278.wav
LJ014-0294.wav
LJ014-0309.wav
LJ014-0314.wav
LJ015-0136.wav
LJ015-0177.wav
LJ015-0295.wav
LJ016-0214.wav
LJ016-0237.wav
LJ016-0284.wav
LJ016-0332.wav
LJ017-0047.wav
LJ017-0050.wav
LJ017-0078.wav
LJ017-0243.wav
LJ018-0047.wav
LJ018-0055.wav
LJ018-0142.wav
LJ018-0262.wav
LJ018-0337.wav
LJ018-0361.wav
LJ019-0017.wav
LJ019-0058.wav
LJ019-0074.wav
LJ019-0080.wav
LJ019-0097.wav
LJ019-0123.wav
LJ019-0187.wav
LJ019-0253.wav
LJ019-0274.wav
LJ019-0317.wav
LJ020-0025.wav
LJ021-0132.wav
LJ021-0137.wav
LJ021-0180.wav
LJ022-0004.wav
LJ022-0133.wav
LJ022-0148.wav
LJ022-0152.wav
LJ023-0003.wav
LJ023-0033.wav
LJ024-0139.wav
LJ024-0141.wav
LJ025-0056.wav
LJ025-0104.wav
LJ025-0117.wav
LJ026-0022.wav
LJ027-0063.wav
LJ027-0075.wav
LJ028-0026.wav
LJ028-0083.wav
LJ028-0116.wav
LJ028-0142.wav
LJ028-0293.wav
LJ028-0343.wav
LJ028-0387.wav
LJ028-0488.wav
LJ028-0502.wav
LJ029-0035.wav
LJ029-0060.wav
LJ030-0013.wav
LJ030-0121.wav
LJ030-0154.wav
LJ030-0241.wav
LJ031-0083.wav
LJ031-0123.wav
LJ031-0205.wav
LJ032-0078.wav
LJ032-0119.wav
LJ032-0179.wav
LJ032-0220.wav
LJ033-0035.wav
LJ033-0066.wav
LJ033-0085.wav
LJ033-0143.wav
LJ034-0056.wav
LJ034-0073.wav
LJ034-0178.wav
LJ034-0206.wav
LJ035-0002.wav
LJ035-0079.wav
LJ036-0007.wav
LJ036-0013.wav
LJ036-0015.wav
LJ036-0067.wav
LJ036-0173.wav
LJ037-0056.wav
LJ037-0067.wav
LJ037-0075.wav
LJ037-0085.wav
LJ037-0102.wav
LJ037-0145.wav
LJ037-0163.wav
LJ038-0015.wav
LJ038-0019.wav
LJ038-0071.wav
LJ038-0083.wav
LJ038-0282.wav
LJ038-0294.wav
LJ038-0303.wav
LJ039-0076.wav
LJ039-0088.wav
LJ040-0034.wav
LJ040-0091.wav
LJ040-0123.wav
LJ040-0229.wav
LJ041-0103.wav
LJ041-0113.wav
LJ042-0037.wav
LJ042-0049.wav
LJ042-0060.wav
LJ042-0120.wav
LJ042-0127.wav
LJ042-0147.wav
LJ042-0161.wav
LJ042-0201.wav
LJ043-0010.wav
LJ043-0074.wav
LJ043-0089.wav
LJ043-0128.wav
LJ044-0076.wav
LJ044-0139.wav
LJ044-0140.wav
LJ044-0200.wav
LJ044-0209.wav
LJ045-0184.wav
LJ045-0191.wav
LJ046-0169.wav
LJ046-0203.wav
LJ047-0031.wav
LJ047-0120.wav
LJ047-0132.wav
LJ047-0153.wav
LJ047-0178.wav
LJ047-0192.wav
LJ047-0201.wav
LJ047-0238.wav
LJ048-0046.wav
LJ048-0093.wav
LJ048-0119.wav
LJ048-0139.wav
LJ048-0193.wav
LJ048-0207.wav
LJ048-0270.wav
LJ049-0044.wav
LJ049-0089.wav
LJ049-0140.wav
LJ049-0179.wav
LJ050-0012.wav
LJ050-0049.wav
LJ050-0070.wav
LJ050-0078.wav
LJ050-0136.wav
LJ050-0179.wav
LJ050-0270.wav


================================================
FILE: egs/vctk/DiffWave.pkl
================================================
[File too large to display: 13.4 MB]

================================================
FILE: egs/vctk/DiffWave_GALR.pkl
================================================
[File too large to display: 15.4 MB]

================================================
FILE: egs/vctk/conf.yml
================================================
# General config
exp_dir: './exp'
seed: 0
load: 'DiffWave_GALR.pkl'
# Generation config
sampling_noise_schedule: './noise_schedules/8steps_PESQ3.88_STOI0.987.ns'
gen_data_dir: "./demo_case"
#gen_data_dir: "./test_wavs"
use_ddim_steps: 0 # options: 0-DDPM, 1-DDIM, M-DDIM M steps (when noise schedule is not given)
# Diffusion config
T: 200
beta_0: 0.0001
beta_T: 0.02
tau: 25
# Noise scheduling config
N: 8
bddm_search_bins: 9
noise_scheduling_attempts: 10
# Score network config
score_net: 'DiffWave'
num_res_layers: 30
in_channels: 1
res_channels: 128
skip_channels: 128
dilation_cycle: 10
out_channels: 1
diffusion_step_embed_dim_in: 128
diffusion_step_embed_dim_mid: 512
diffusion_step_embed_dim_out: 512
# Schedule network config
schedule_net: 'GALR'
blocks: 1
hidden_dim: 128
input_dim: 128
window_length: 8
segment_size: 64
# Trainer config
resume_training: False
save_fp16: True
steps_per_epoch: 1000
score_net_training_steps: 5000000
schedule_net_training_steps: 5000
lr: 0.00001
freeze_checkpoint_params: False # only used when provided load path
save_generated_dir: 'gen_wav'
batch_size: 1
patience: 100
grad_clip: 5.0
weight_decay: 0.00001
log_period: 1
bmuf_block_momentum: 0.5
bmuf_block_lr: 0.7
bmuf_sync_period: 2
ema_rate: 0.999
# Data config
train_data_dir: "VCTK-Corpus/wav"
valid_data_dir: "val_case"
n_worker: 10
sampling_rate: 22050
seg_len: 16000
fil_len: 1024
hop_len: 256
win_len: 1024
mel_fmin: 0.0
mel_fmax: 8000.0


================================================
FILE: egs/vctk/generate.sh
================================================
#!/bin/bash

## example usage (single GPU): sh generate.sh 0 conf.yml

cuda=$1

CUDA_VISIBLE_DEVICES=$cuda python3 main.py \
	--command generate --config $2


================================================
FILE: egs/vctk/main.py
================================================
import os
import sys
sys.path.append('../../')
import json
import shutil
import hashlib
import argparse
import numpy as np
import yaml

import torch
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

from bddm import trainer, sampler
from bddm.utils.log_utils import log


def dict_hash_5char(dictionary):
    ''' Map a unique dictionary into a 5-character string '''
    dhash = hashlib.md5()
    encoded = json.dumps(dictionary, sort_keys=True).encode()
    dhash.update(encoded)
    return dhash.hexdigest()[:5]


def start_exp(config, config_hash):
    ''' Create experiment directory or set it to an existing directory '''
    if config.load != '' and '_nets' in config.load:
        config.exp_dir = '/'.join(config.load.split('/')[:-2])
    else:
        config.exp_dir += '/%s-%s_conf-hash-%s' % (
            config.score_net, config.schedule_net, config_hash)
    if config.local_rank != 0:
        return
    log('Experiment directory: %s' % (config.exp_dir), config)
    # Backup the config file
    shutil.copyfile(config.config, os.path.join(config.exp_dir, 'conf.yml'))
    # Create a backup scripts sub-folder
    os.makedirs(os.path.join(config.exp_dir, 'backup_scripts'), exist_ok=True)
    # Backup all .py files under bddm/
    backup_files = []
    for root, _, files in os.walk("../../"):
        if 'egs' in root:
            continue
        for f in files:
            if f.endswith(".py"):
                backup_files.append(os.path.join(root, f))
    for src_file in backup_files:
        basename = src_file
        while '../' in basename:
            basename = basename.replace('../', '')
        basename = basename.replace('./', '')
        dst_file = os.path.join(config.exp_dir, 'backup_scripts', basename)
        dst_dir = '/'.join(dst_file.split('/')[:-1])
        if not os.path.exists(dst_dir):
            os.makedirs(dst_dir)
        shutil.copyfile(src_file, dst_file)
    # Prepare sub-folders for saving model checkpoints
    os.makedirs(os.path.join(config.exp_dir, 'score_nets'), exist_ok=True)
    os.makedirs(os.path.join(config.exp_dir, 'schedule_nets'), exist_ok=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Bilateral Denoising Diffusion Models')
    parser.add_argument('--command',
                        type=str,
                        default='train',
                        help='available commands: train | search | generate')
    parser.add_argument('--config',
                        '-c',
                        type=str,
                        default='conf.yml',
                        help='config .yml path')
    parser.add_argument('--local_rank',
                        type=int,
                        default=0,
                        help='process device ID for multi-GPU training')

    arg_config = parser.parse_args()

    # Parse yaml and define configurations
    config = arg_config.__dict__
    with open(arg_config.config) as f:
        yaml_config = yaml.safe_load(f)
    HASH = dict_hash_5char(yaml_config)
    for key in yaml_config:
        config[key] = yaml_config[key]
    config = argparse.Namespace(**config)

    # Set random seed for reproducible results
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.cuda.set_device(config.local_rank)

    # Check if the command is valid or not
    commands = ['train', 'schedule', 'generate']
    assert config.command in commands, 'Error: %s command not found.'%(config.command)

    # Create/retrieve exp dir
    start_exp(config, HASH)
    log('Argv: %s' % (' '.join(sys.argv)), config)

    try:
        if config.command == 'train':
            # Create Trainer for training
            trainer = trainer.Trainer(config)
            trainer.train()
        elif config.command == 'schedule':
            # Create Sampler for noise scheduling
            sampler = sampler.Sampler(config)
            sampler.noise_scheduling_without_params()
        elif config.command == 'generate':
            # Create Sampler for generation
            # NOTE: Remember to define "gen_data_dir" in conf.yml before data generation
            sampler = sampler.Sampler(config)
            sampler.generate()
        log('-' * 80, config)

    except KeyboardInterrupt:
        log('-' * 80, config)
        log('Exiting early', config)


================================================
FILE: egs/vctk/schedule.sh
================================================
#!/bin/bash

## example usage (single GPU): sh schedule.sh 0 conf.yml

cuda=$1

CUDA_VISIBLE_DEVICES=$cuda python3 main.py \
	--command schedule --config $2


================================================
FILE: egs/vctk/train.sh
================================================
#!/bin/bash

## example usage: sh train.sh 0,1,2,3 conf.yml

cuda=$1
comma=${cuda//[^,]}
nproc=$((${#comma}+1))

CUDA_VISIBLE_DEVICES=$cuda python3 -m torch.distributed.launch \
	--nproc_per_node $nproc --master_port $RANDOM main.py \
	--command train --config $2
Download .txt
gitextract_wr3lh52z/

├── .gitignore
├── LICENSE
├── README.md
├── bddm/
│   ├── loader/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── stft.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── diffwave.py
│   │   └── galr.py
│   ├── sampler/
│   │   ├── __init__.py
│   │   └── sampler.py
│   ├── trainer/
│   │   ├── __init__.py
│   │   ├── bmuf.py
│   │   ├── ema.py
│   │   ├── loss.py
│   │   └── trainer.py
│   └── utils/
│       ├── check_utils.py
│       ├── diffusion_utils.py
│       └── log_utils.py
└── egs/
    ├── lj/
    │   ├── DiffWave.pkl
    │   ├── DiffWave_GALR.pkl
    │   ├── conf.yml
    │   ├── demo_case/
    │   │   ├── LJ001-0001.mel
    │   │   ├── LJ004-0233.mel
    │   │   └── LJ050-0270.mel
    │   ├── eval.list
    │   ├── generate.sh
    │   ├── main.py
    │   ├── noise_schedules/
    │   │   ├── 12steps_PESQ4.07_STOI0.978.ns
    │   │   └── 7steps_PESQ4.05_STOI0.975.ns
    │   ├── schedule.sh
    │   ├── train.sh
    │   └── valid.list
    └── vctk/
        ├── DiffWave.pkl
        ├── DiffWave_GALR.pkl
        ├── conf.yml
        ├── demo_case/
        │   ├── p259_311.mel
        │   └── p287_123.mel
        ├── generate.sh
        ├── main.py
        ├── noise_schedules/
        │   └── 8steps_PESQ3.88_STOI0.987.ns
        ├── schedule.sh
        └── train.sh
Download .txt
SYMBOL INDEX (112 symbols across 15 files)

FILE: bddm/loader/dataset.py
  class SpectrogramDataset (line 27) | class SpectrogramDataset(data.Dataset):
    method __init__ (line 29) | def __init__(self, data_dir, n_gpus, is_sampling, sampling_rate,
    method reset (line 78) | def reset(self):
    method find_all_wavs_in_dir (line 93) | def find_all_wavs_in_dir(self, data_dir):
    method find_all_mels_in_dir (line 107) | def find_all_mels_in_dir(self, data_dir):
    method crop_audio_and_mel (line 119) | def crop_audio_and_mel(self, audio, mel_spec):
    method __getitem__ (line 149) | def __getitem__(self, index):
    method __len__ (line 208) | def __len__(self):
  function create_train_and_valid_dataloader (line 220) | def create_train_and_valid_dataloader(config):
  function create_generation_dataloader (line 259) | def create_generation_dataloader(config):

FILE: bddm/loader/stft.py
  class STFT (line 43) | class STFT(nn.Module):
    method __init__ (line 45) | def __init__(self, filter_length=800, hop_length=200, win_length=800,
    method transform (line 78) | def transform(self, input_data):
    method forward (line 108) | def forward(self, input_data):
  class TacotronSTFT (line 114) | class TacotronSTFT(nn.Module):
    method __init__ (line 118) | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
    method spectral_normalize (line 130) | def spectral_normalize(self, x):
    method spectral_de_normalize (line 133) | def spectral_de_normalize(self, x):
    method mel_spectrogram (line 136) | def mel_spectrogram(self, y):

FILE: bddm/models/__init__.py
  function get_score_network (line 18) | def get_score_network(config):
  function get_schedule_network (line 25) | def get_schedule_network(config):

FILE: bddm/models/diffwave.py
  function calc_diffusion_step_embedding (line 22) | def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_...
  function swish (line 55) | def swish(x):
  class Conv (line 61) | class Conv(nn.Module):
    method __init__ (line 62) | def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
    method forward (line 70) | def forward(self, x):
  class ZeroConv1d (line 77) | class ZeroConv1d(nn.Module):
    method __init__ (line 78) | def __init__(self, in_channel, out_channel):
    method forward (line 84) | def forward(self, x):
  class ResidualBlock (line 91) | class ResidualBlock(nn.Module):
    method __init__ (line 92) | def __init__(self, res_channels, skip_channels, dilation,
    method forward (line 127) | def forward(self, input_data):
  class ResidualGroup (line 166) | class ResidualGroup(nn.Module):
    method __init__ (line 167) | def __init__(self, res_channels, skip_channels, num_res_layers, dilati...
    method forward (line 187) | def forward(self, input_data):
  class DiffWave (line 209) | class DiffWave(nn.Module):
    method __init__ (line 210) | def __init__(self, in_channels, res_channels, skip_channels, out_chann...
    method forward (line 231) | def forward(self, input_data):

FILE: bddm/models/galr.py
  class Encoder (line 21) | class Encoder(nn.Module):
    method __init__ (line 23) | def __init__(self, enc_dim, win_len):
    method forward (line 38) | def forward(self, signals):
  class BiLSTMproj (line 54) | class BiLSTMproj(nn.Module):
    method __init__ (line 56) | def __init__(self, enc_dim, hid_dim):
    method forward (line 75) | def forward(self, intra_segs):
  class AttnPositionalEncoding (line 95) | class AttnPositionalEncoding(nn.Module):
    method __init__ (line 97) | def __init__(self, enc_dim, attn_max_len=5000):
    method forward (line 115) | def forward(self, x):
  class GlobalAttnLayer (line 128) | class GlobalAttnLayer(nn.Module):
    method __init__ (line 130) | def __init__(self, enc_dim, n_attn_head, attn_dropout):
    method forward (line 148) | def forward(self, inter_segs):
  class DeepGlobalAttnLayer (line 165) | class DeepGlobalAttnLayer(nn.Module):
    method __init__ (line 167) | def __init__(self, enc_dim, n_attn_head, attn_dropout, n_attn_layer=1):
    method forward (line 187) | def forward(self, inter_segs):
  class GALRBlock (line 205) | class GALRBlock(nn.Module):
    method __init__ (line 207) | def __init__(self, enc_dim, hid_dim, seg_len, low_dim=8, n_attn_head=8...
    method forward (line 234) | def forward(self, segments):
  class _GALR (line 279) | class _GALR(nn.Module):
    method __init__ (line 281) | def __init__(self, n_block, enc_dim, hid_dim, win_len, seg_len,
    method forward (line 317) | def forward(self, noisy_signals):
    method pad_zeros (line 338) | def pad_zeros(self, signals):
    method pad_segment (line 360) | def pad_segment(self, frames):
    method split_feature (line 383) | def split_feature(self, frames):
  class GALR (line 407) | class GALR(nn.Module):
    method __init__ (line 409) | def __init__(self, blocks=2, input_dim=128, hidden_dim=128, window_len...
    method forward (line 425) | def forward(self, audio, scales):

FILE: bddm/sampler/sampler.py
  class Sampler (line 34) | class Sampler(object):
    method __init__ (line 38) | def __init__(self, config):
    method reset (line 72) | def reset(self):
    method draw_reference_data_pair (line 89) | def draw_reference_data_pair(self):
    method generate (line 96) | def generate(self):
    method noise_scheduling (line 128) | def noise_scheduling(self, ddim=False):
    method sampling (line 197) | def sampling(self, schedule=None, condition=None,
    method noise_scheduling_with_params (line 274) | def noise_scheduling_with_params(self, alpha_param, beta_param):
    method noise_scheduling_without_params (line 336) | def noise_scheduling_without_params(self):
    method assess (line 389) | def assess(self, generated_audio, audio_key=None):

FILE: bddm/trainer/bmuf.py
  class BmufTrainer (line 21) | class BmufTrainer(object):
    method __init__ (line 23) | def __init__(self, master_node, rank, world_size, model, block_momentu...
    method check_all_processes_running (line 56) | def check_all_processes_running(self):
    method get_average_valid_loss (line 65) | def get_average_valid_loss(self, val_loss):
    method update_and_sync (line 83) | def update_and_sync(self, val_loss=None):
    method _copy_vec_to_param (line 116) | def _copy_vec_to_param(self, vec):

FILE: bddm/trainer/ema.py
  class EMAHelper (line 16) | class EMAHelper(object):
    method __init__ (line 18) | def __init__(self, mu=0.999):
    method register (line 28) | def register(self, module):
    method update (line 41) | def update(self, module):
    method ema (line 55) | def ema(self, module):
    method ema_copy (line 68) | def ema_copy(self, module):
    method state_dict (line 87) | def state_dict(self):
    method load_state_dict (line 96) | def load_state_dict(self, state_dict):

FILE: bddm/trainer/loss.py
  class ScoreLoss (line 17) | class ScoreLoss(nn.Module):
    method __init__ (line 19) | def __init__(self, config, diff_params):
    method forward (line 32) | def forward(self, model, mels, audios):
  class StepLoss (line 54) | class StepLoss(nn.Module):
    method __init__ (line 56) | def __init__(self, config, diff_params):
    method forward (line 70) | def forward(self, model, mels, audios):

FILE: bddm/trainer/trainer.py
  class Trainer (line 31) | class Trainer(object):
    method __init__ (line 33) | def __init__(self, config):
    method reset (line 93) | def reset(self):
    method train (line 138) | def train(self):
    method _run_one_epoch (line 189) | def _run_one_epoch(self, validate=False):
    method serialize (line 251) | def serialize(self):

FILE: bddm/utils/check_utils.py
  function check_score_network (line 18) | def check_score_network(config):
  function check_schedule_network (line 46) | def check_schedule_network(config):

FILE: bddm/utils/diffusion_utils.py
  function compute_diffusion_params (line 16) | def compute_diffusion_params(beta):
  function map_noise_scale_to_time_step (line 39) | def map_noise_scale_to_time_step(alpha_infer, alpha):

FILE: bddm/utils/log_utils.py
  function ctime (line 17) | def ctime():
  function head (line 27) | def head():
  function log (line 37) | def log(msg, config):

FILE: egs/lj/main.py
  function dict_hash_5char (line 19) | def dict_hash_5char(dictionary):
  function start_exp (line 27) | def start_exp(config, config_hash):

FILE: egs/vctk/main.py
  function dict_hash_5char (line 19) | def dict_hash_5char(dictionary):
  function start_exp (line 27) | def start_exp(config, config_hash):
Condensed preview — 43 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (143K chars).
[
  {
    "path": ".gitignore",
    "chars": 116,
    "preview": "egs/lj/exp\negs/lj/LJSpeech1.1\n*.log\n*.pyc\n__pycache__/\n*/__pycache__/\n*.zip\n*/*.zip\nsource/\nmake.bat\nMakefile\n*.swp\n"
  },
  {
    "path": "LICENSE",
    "chars": 11337,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 7601,
    "preview": "# Bilateral Denoising Diffusion Models (BDDMs)\n\n[![GitHub Stars](https://img.shields.io/github/stars/tencent-ailab/bddm?"
  },
  {
    "path": "bddm/loader/__init__.py",
    "chars": 302,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/loader/dataset.py",
    "chars": 11893,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/loader/stft.py",
    "chars": 5812,
    "preview": "\"\"\"\nBSD 3-Clause License\n\nCopyright (c) 2017, Prem Seetharaman\nAll rights reserved.\n\n* Redistribution and use in source "
  },
  {
    "path": "bddm/models/__init__.py",
    "chars": 940,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/models/diffwave.py",
    "chars": 9454,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/models/galr.py",
    "chars": 19247,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/sampler/__init__.py",
    "chars": 333,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/sampler/sampler.py",
    "chars": 20624,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/trainer/__init__.py",
    "chars": 334,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/trainer/bmuf.py",
    "chars": 5379,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/trainer/ema.py",
    "chars": 3040,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/trainer/loss.py",
    "chars": 3862,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/trainer/trainer.py",
    "chars": 13572,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/utils/check_utils.py",
    "chars": 2713,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/utils/diffusion_utils.py",
    "chars": 1874,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "bddm/utils/log_utils.py",
    "chars": 1293,
    "preview": "#!/bin/env python\n# -*- coding: utf-8 -*-\n########################################################################\n#\n#  "
  },
  {
    "path": "egs/lj/conf.yml",
    "chars": 1466,
    "preview": "# General config\nexp_dir: './exp'\nseed: 0\nload: 'DiffWave_GALR.pkl'\n# Generation config\nsampling_noise_schedule: './nois"
  },
  {
    "path": "egs/lj/eval.list",
    "chars": 1500,
    "preview": "LJ001-0001.wav\nLJ001-0047.wav\nLJ003-0140.wav\nLJ003-0145.wav\nLJ003-0235.wav\nLJ003-0240.wav\nLJ004-0195.wav\nLJ005-0059.wav\n"
  },
  {
    "path": "egs/lj/generate.sh",
    "chars": 157,
    "preview": "#!/bin/bash\n\n## example usage (single GPU): sh generate.sh 0 conf.yml\n\ncuda=$1\n\nCUDA_VISIBLE_DEVICES=$cuda python3 main."
  },
  {
    "path": "egs/lj/main.py",
    "chars": 4403,
    "preview": "import os\nimport sys\nsys.path.append('../../')\nimport json\nimport shutil\nimport hashlib\nimport argparse\nimport numpy as "
  },
  {
    "path": "egs/lj/schedule.sh",
    "chars": 157,
    "preview": "#!/bin/bash\n\n## example usage (single GPU): sh schedule.sh 0 conf.yml\n\ncuda=$1\n\nCUDA_VISIBLE_DEVICES=$cuda python3 main."
  },
  {
    "path": "egs/lj/train.sh",
    "chars": 264,
    "preview": "#!/bin/bash\n\n## example usage: sh train.sh 0,1,2,3 conf.yml\n\ncuda=$1\ncomma=${cuda//[^,]}\nnproc=$((${#comma}+1))\n\nCUDA_VI"
  },
  {
    "path": "egs/lj/valid.list",
    "chars": 3540,
    "preview": "LJ001-0023.wav\nLJ001-0084.wav\nLJ001-0144.wav\nLJ001-0163.wav\nLJ001-0179.wav\nLJ002-0042.wav\nLJ002-0140.wav\nLJ002-0201.wav\n"
  },
  {
    "path": "egs/vctk/conf.yml",
    "chars": 1442,
    "preview": "# General config\nexp_dir: './exp'\nseed: 0\nload: 'DiffWave_GALR.pkl'\n# Generation config\nsampling_noise_schedule: './nois"
  },
  {
    "path": "egs/vctk/generate.sh",
    "chars": 157,
    "preview": "#!/bin/bash\n\n## example usage (single GPU): sh generate.sh 0 conf.yml\n\ncuda=$1\n\nCUDA_VISIBLE_DEVICES=$cuda python3 main."
  },
  {
    "path": "egs/vctk/main.py",
    "chars": 4403,
    "preview": "import os\nimport sys\nsys.path.append('../../')\nimport json\nimport shutil\nimport hashlib\nimport argparse\nimport numpy as "
  },
  {
    "path": "egs/vctk/schedule.sh",
    "chars": 157,
    "preview": "#!/bin/bash\n\n## example usage (single GPU): sh schedule.sh 0 conf.yml\n\ncuda=$1\n\nCUDA_VISIBLE_DEVICES=$cuda python3 main."
  },
  {
    "path": "egs/vctk/train.sh",
    "chars": 264,
    "preview": "#!/bin/bash\n\n## example usage: sh train.sh 0,1,2,3 conf.yml\n\ncuda=$1\ncomma=${cuda//[^,]}\nnproc=$((${#comma}+1))\n\nCUDA_VI"
  }
]

// ... and 12 more files (download for full content)

About this extraction

This page contains the full source code of the tencent-ailab/bddm GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 43 files (70.9 MB), approximately 34.4k tokens, and a symbol index with 112 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!