Repository: tugstugi/pytorch-dc-tts Branch: master Commit: f892fd27c768 Files: 21 Total size: 61.2 KB Directory structure: gitextract__69wshca/ ├── .gitignore ├── LICENSE ├── README.md ├── audio.py ├── datasets/ │ ├── .gitignore │ ├── __init__.py │ ├── data_loader.py │ ├── lj_speech.py │ └── mb_speech.py ├── dl_and_preprop_dataset.py ├── hparams.py ├── logger.py ├── models/ │ ├── __init__.py │ ├── layers.py │ ├── ssrn.py │ └── text2mel.py ├── requirements.txt ├── synthesize.py ├── train-ssrn.py ├── train-text2mel.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea .DS_Store __pycache__ .ipynb_checkpoints *.ipynb logdir/ samples *.npy *.tar.bz2 ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2019 Erdene-Ochir Tuguldur Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ PyTorch implementation of [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention](https://arxiv.org/abs/1710.08969) based partially on the following projects: * https://github.com/Kyubyong/dc_tts (audio preprocessing) * https://github.com/r9y9/deepvoice3_pytorch (data loader sampler) ## Online Text-To-Speech Demo The following notebooks are executable on [https://colab.research.google.com](https://colab.research.google.com): * [Mongolian Male Voice TTS Demo](https://colab.research.google.com/github/tugstugi/pytorch-dc-tts/blob/master/notebooks/MongolianTTS.ipynb) * [English Female Voice TTS Demo (LJ-Speech)](https://colab.research.google.com/github/tugstugi/pytorch-dc-tts/blob/master/notebooks/EnglishTTS.ipynb) For audio samples and pretrained models, visit the above notebook links. ## Training/Synthesizing English Text-To-Speech The English TTS uses the [LJ-Speech](https://keithito.com/LJ-Speech-Dataset/) dataset. 1. Download the dataset: `python dl_and_preprop_dataset.py --dataset=ljspeech` 2. Train the Text2Mel model: `python train-text2mel.py --dataset=ljspeech` 3. Train the SSRN model: `python train-ssrn.py --dataset=ljspeech` 4. Synthesize sentences: `python synthesize.py --dataset=ljspeech` * The WAV files are saved in the `samples` folder. ## Training/Synthesizing Mongolian Text-To-Speech The Mongolian text-to-speech uses 5 hours of audio from the [Mongolian Bible](https://www.bible.com/mn/versions/1590-2013-ariun-bibli-2013). 1. Download the dataset: `python dl_and_preprop_dataset.py --dataset=mbspeech` 2. 
def griffin_lim(spectrogram):
    '''Estimate a signal from a magnitude spectrogram with Griffin-Lim iterations.

    Each of the `hp.n_iter` rounds inverts the current complex estimate,
    re-analyses the result with an STFT, keeps only the phase of that STFT
    and re-attaches it to the magnitudes given in `spectrogram`.

    Args:
      spectrogram: Magnitude spectrogram in the layout accepted by
        `invert_spectrogram`.

    Returns:
      The real part of the signal obtained by inverting the final estimate.
    '''
    estimate = copy.deepcopy(spectrogram)
    for _ in range(hp.n_iter):
        waveform = invert_spectrogram(estimate)
        analysed = librosa.stft(waveform, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
        # Floor the magnitude so the phase division never divides by zero.
        magnitude = np.maximum(1e-8, np.abs(analysed))
        estimate = spectrogram * (analysed / magnitude)
    return np.real(invert_spectrogram(estimate))
def save_to_wav(mag, filename):
    """Reconstruct audio from a linear spectrogram and write it to `filename`.

    The waveform is produced by `spectrogram2wav` (Griffin-Lim based) and is
    written with the sampling rate `hp.sr`.
    """
    scipy.io.wavfile.write(filename, hp.sr, spectrogram2wav(mag))
class Text2MelDataLoader(DataLoader):
    """DataLoader that feeds Text2Mel training or validation batches.

    The dataset is narrowed in place through its ``slice(start, end)`` method
    before being handed to ``DataLoader``:

    * ``mode='train'`` keeps everything except the last ``batch_size`` samples.
    * ``mode='valid'`` keeps exactly those last ``batch_size`` samples.

    Args:
        text2mel_dataset: dataset exposing ``slice(start, end)`` and ``__len__``.
        batch_size: number of samples per batch, and size of the held-out
            validation tail.
        mode: either ``'train'`` or ``'valid'``.
        num_workers: number of worker processes used by the ``DataLoader``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'valid'``.
    """

    def __init__(self, text2mel_dataset, batch_size, mode='train', num_workers=8):
        if mode == 'train':
            text2mel_dataset.slice(0, -batch_size)
        elif mode == 'valid':
            total = len(text2mel_dataset)
            # End at `total`, not -1: an end index of -1 silently drops the
            # last sample and leaves only batch_size - 1 validation items.
            # Clamp the start so a dataset smaller than batch_size is kept whole.
            text2mel_dataset.slice(max(0, total - batch_size), total)
        else:
            raise ValueError("mode must be either 'train' or 'valid'")
        super().__init__(text2mel_dataset,
                         batch_size=batch_size,
                         num_workers=num_workers,
                         collate_fn=collate_fn,
                         shuffle=True)
class PartiallyRandomizedSimilarTimeLengthSampler(Sampler):
    """Copied from: https://github.com/r9y9/deepvoice3_pytorch/blob/master/train.py.

    Partially randomized sampler that keeps samples of similar length together:

    1. Sort the sample indices by length.
    2. Shuffle the indices inside each group of ``batch_group_size`` samples.
    3. Optionally permute whole mini-batches.
    4. Shuffle whatever is left after the last complete group.

    Args:
        lengths: sequence with one length per sample.
        data_source: forwarded to ``Sampler.__init__``; not used otherwise.
        batch_size: mini-batch size used for the batch permutation.
        batch_group_size: number of samples shuffled together. Defaults to
            ``min(batch_size * 32, len(lengths))`` rounded down to a multiple
            of ``batch_size``. Must be a multiple of ``batch_size``.
        permutate: whether whole mini-batches are permuted.
    """

    def __init__(self, lengths, data_source, batch_size=16, batch_group_size=None, permutate=True):
        try:
            super().__init__(data_source)
        except TypeError:
            # torch has deprecated the `data_source` argument and announced its
            # removal; fall back to the no-argument call when it is rejected.
            super().__init__()
        self.lengths, self.sorted_indices = torch.sort(torch.LongTensor(lengths))
        self.batch_size = batch_size
        if batch_group_size is None:
            batch_group_size = min(batch_size * 32, len(self.lengths))
            if batch_group_size % batch_size != 0:
                batch_group_size -= batch_group_size % batch_size
        self.batch_group_size = batch_group_size
        assert batch_group_size % batch_size == 0
        self.permutate = permutate

    def __iter__(self):
        # Work on a plain Python list. `random.shuffle` on a tensor slice swaps
        # through views of the same storage, which duplicates some indices and
        # loses others.
        indices = self.sorted_indices.tolist()
        group_size = self.batch_group_size
        batch_size = self.batch_size

        # Shuffle inside every complete group. A group size of 0 happens when
        # the dataset is smaller than one batch; then everything is "the tail".
        grouped_end = 0
        if group_size > 0:
            grouped_end = (len(indices) // group_size) * group_size
            for start in range(0, grouped_end, group_size):
                chunk = indices[start:start + group_size]
                random.shuffle(chunk)
                indices[start:start + group_size] = chunk

        # Permutate batches
        if self.permutate and grouped_end > 0:
            batches = [indices[k:k + batch_size] for k in range(0, grouped_end, batch_size)]
            random.shuffle(batches)
            indices[:grouped_end] = [idx for batch in batches for idx in batch]

        # Handle last elements
        tail = indices[grouped_end:]
        random.shuffle(tail)
        indices[grouped_end:] = tail

        return iter(indices)

    def __len__(self):
        return len(self.sorted_indices)
def read_metadata(metadata_file):
    """Read a '|'-separated metadata file and encode its transcripts.

    Every non-blank line must have exactly three fields: the file name, an
    ignored field, and the text. The text is passed through `text_normalize`,
    terminated with the EOS marker "E" and mapped to indices with `char2idx`.

    Args:
        metadata_file: path of the UTF-8 encoded metadata file.

    Returns:
        A tuple ``(fnames, text_lengths, texts)`` of parallel lists: file
        names, encoded text lengths, and the encoded texts as 1-D
        ``np.longlong`` arrays.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
        KeyError: if a normalized text contains a character missing from
            `char2idx`.
    """
    fnames, text_lengths, texts = [], [], []
    # The context manager closes the handle; previously it was left open.
    with codecs.open(metadata_file, 'r', 'utf-8') as transcript:
        lines = transcript.readlines()
    for line in lines:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        fname, _, text = line.split("|")
        fnames.append(fname)

        text = text_normalize(text) + "E"  # E: EOS
        text = [char2idx[char] for char in text]
        text_lengths.append(len(text))
        texts.append(np.array(text, np.longlong))

    return fnames, text_lengths, texts
def text_normalize(text):
    """Lowercase `text` and map its punctuation onto '-' and ','.

    Hyphens, em dashes and colons become '-'. Parentheses, straight quotes,
    apostrophes, guillemets and curly double quotes become ','. Every other
    character is left as produced by `str.lower`.
    """
    table = {ord(char): "-" for char in "-—:"}
    table.update({ord(char): "," for char in "()\"«»“”'"})
    return text.lower().translate(table)
def number2word(number):
    """Spell out a string of decimal digits in Mongolian.

    Inputs of one to three digits are delegated directly to the digit helpers.
    Longer inputs are split into groups of three digits from the right; each
    non-zero group is spelled out and followed by the name of its scale
    (thousand, million, ...). Groups consisting only of zeros contribute
    nothing, so "1000000" no longer picks up a spurious thousand word, and the
    parts are joined with single spaces without leading or trailing blanks.

    Args:
        number: string of decimal digits.

    Returns:
        The spelled-out number, or '' for an empty string or a multi-group
        input that consists only of zeros.

    Raises:
        KeyError: if `number` contains a non-digit character, or a non-zero
            group beyond the largest named scale (more than 18 digits).
    """
    digit_len = len(number)
    digit_name = {1: '', 2: 'мянга', 3: 'сая', 4: 'тэрбум', 5: 'их наяд', 6: 'тунамал'}

    if digit_len == 0:
        return ''
    if digit_len == 1:
        return _last_digit_2_str(number)
    if digit_len == 2:
        return _2_digits_2_str(number)
    if digit_len == 3:
        return _3_digits_to_str(number)

    # Split from the right into groups of three; the leading group may be shorter.
    head = digit_len % 3 or 3
    groups = [number[:head]] + [number[k:k + 3] for k in range(head, digit_len, 3)]
    count = len(groups)

    words = []
    for position, group in enumerate(groups):
        if position == count - 1:
            # Only the last group uses the final word forms.
            spelled = _3_digits_to_str(group)
        else:
            spelled = _3_digits_to_str(group, False)
            if spelled:
                # Name the scale only when the group is non-zero.
                spelled += ' ' + digit_name[count - position]
        if spelled:
            words.append(spelled)
    return ' '.join(words)
def _3_digits_to_str(digit, is_fina=True):
    """Spell out a group of up to three digits.

    Leading zeros are ignored: an all-zero group gives '', one remaining digit
    goes to `_1_digit_2_str` and two remaining digits go to `_2_digits_2_str`.
    For three significant digits the hundreds word is followed either by the
    spelled-out last two digits or, for exact hundreds, by 'зуу' when
    `is_fina` is true and 'зуун' otherwise.

    Args:
        digit: string holding the digit group.
        is_fina: forwarded to `_2_digits_2_str`; also selects the exact-hundreds
            suffix.
    """
    significant = digit.lstrip('0')
    significant_len = len(significant)
    if significant_len == 0:
        return ''
    if significant_len == 1:
        return _1_digit_2_str(significant)
    if significant_len == 2:
        return _2_digits_2_str(significant, is_fina)

    hundreds = _1_digit_2_str(digit[0])
    if digit[-2:] != '00':
        return hundreds + ' зуун ' + _2_digits_2_str(digit[-2:], is_fina)
    suffix = ' зуу' if is_fina else ' зуун'
    return hundreds + suffix
Supported datasets are: * English female: LJSpeech (https://keithito.com/LJ-Speech-Dataset/) * Mongolian male: MBSpeech (Mongolian Bible) """ __author__ = 'Erdene-Ochir Tuguldur' import os import sys import csv import time import argparse import fnmatch import librosa import pandas as pd from hparams import HParams as hp from zipfile import ZipFile from audio import preprocess from utils import download_file from datasets.mb_speech import MBSpeech from datasets.lj_speech import LJSpeech parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--dataset", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name') args = parser.parse_args() if args.dataset == 'ljspeech': dataset_file_name = 'LJSpeech-1.1.tar.bz2' datasets_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets') dataset_path = os.path.join(datasets_path, 'LJSpeech-1.1') if os.path.isdir(dataset_path) and False: print("LJSpeech dataset folder already exists") sys.exit(0) else: dataset_file_path = os.path.join(datasets_path, dataset_file_name) if not os.path.isfile(dataset_file_path): url = "http://data.keithito.com/data/speech/%s" % dataset_file_name download_file(url, dataset_file_path) else: print("'%s' already exists" % dataset_file_name) print("extracting '%s'..." 
% dataset_file_name) os.system('cd %s; tar xvjf %s' % (datasets_path, dataset_file_name)) # pre process print("pre processing...") lj_speech = LJSpeech([]) preprocess(dataset_path, lj_speech) elif args.dataset == 'mbspeech': dataset_name = 'MBSpeech-1.0' datasets_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets') dataset_path = os.path.join(datasets_path, dataset_name) if os.path.isdir(dataset_path) and False: print("MBSpeech dataset folder already exists") sys.exit(0) else: bible_books = ['01_Genesis', '02_Exodus', '03_Leviticus'] for bible_book_name in bible_books: bible_book_file_name = '%s.zip' % bible_book_name bible_book_file_path = os.path.join(datasets_path, bible_book_file_name) if not os.path.isfile(bible_book_file_path): url = "https://s3.us-east-2.amazonaws.com/bible.davarpartners.com/Mongolian/" + bible_book_file_name download_file(url, bible_book_file_path) else: print("'%s' already exists" % bible_book_file_name) print("extracting '%s'..." % bible_book_file_name) zipfile = ZipFile(bible_book_file_path) zipfile.extractall(datasets_path) dataset_csv_file_path = os.path.join(datasets_path, '%s-csv.zip' % dataset_name) dataset_csv_extracted_path = os.path.join(datasets_path, '%s-csv' % dataset_name) if not os.path.isfile(dataset_csv_file_path): url = "https://www.dropbox.com/s/dafueq0w278lbz6/%s-csv.zip?dl=1" % dataset_name download_file(url, dataset_csv_file_path) else: print("'%s' already exists" % dataset_csv_file_path) print("extracting '%s'..." 
% dataset_csv_file_path) zipfile = ZipFile(dataset_csv_file_path) zipfile.extractall(datasets_path) sample_rate = 44100 # original sample rate total_duration_s = 0 if not os.path.isdir(dataset_path): os.mkdir(dataset_path) wavs_path = os.path.join(dataset_path, 'wavs') if not os.path.isdir(wavs_path): os.mkdir(wavs_path) metadata_csv = open(os.path.join(dataset_path, 'metadata.csv'), 'w') metadata_csv_writer = csv.writer(metadata_csv, delimiter='|') def _normalize(s): """remove leading '-'""" s = s.strip() if s[0] == '—' or s[0] == '-': s = s[1:].strip() return s def _get_mp3_file(book_name, chapter): book_download_path = os.path.join(datasets_path, book_name) wildcard = "*%02d - DPI.mp3" % chapter for file_name in os.listdir(book_download_path): if fnmatch.fnmatch(file_name, wildcard): return os.path.join(book_download_path, file_name) return None def _convert_mp3_to_wav(book_name, book_nr): global total_duration_s chapter = 1 while True: try: i = 0 chapter_csv_file_name = os.path.join(dataset_csv_extracted_path, "%s_%02d.csv" % (book_name, chapter)) df = pd.read_csv(chapter_csv_file_name, sep="|") print("processing %s..." % chapter_csv_file_name) mp3_file = _get_mp3_file(book_name, chapter) print("processing %s..." 
class HParams:
    """Hyper parameters shared by audio processing, the models and the logger.

    All values are plain class attributes and are read as ``hp.<name>`` by the
    modules that import this class as ``hp``.
    """

    disable_progress_bar = False  # set True if you don't want the progress bar in the console
    logdir = "logdir"  # log dir where the checkpoints and tensorboard files are saved

    # audio.py options, these values are from https://github.com/Kyubyong/dc_tts/blob/master/hyperparams.py
    reduction_rate = 4  # melspectrogram reduction rate, don't change because SSRN is using this rate
    n_fft = 2048  # fft points (samples)
    n_mels = 80  # Number of Mel banks to generate
    power = 1.5  # Exponent for amplifying the predicted magnitude
    n_iter = 50  # Number of inversion iterations
    preemphasis = .97  # pre-emphasis coefficient: y[t] - preemphasis * y[t-1]
    max_db = 100  # dB span used to scale spectrograms into [0, 1]
    ref_db = 20  # reference dB level subtracted before scaling
    sr = 22050  # Sampling rate
    frame_shift = 0.0125  # seconds
    frame_length = 0.05  # seconds
    hop_length = int(sr * frame_shift)  # samples; int() truncates 275.625 to 275
    win_length = int(sr * frame_length)  # samples; int() truncates 1102.5 to 1102
    max_N = 180  # Maximum number of characters.
    max_T = 210  # Maximum number of mel frames.

    e = 128  # embedding dimension
    d = 256  # Text2Mel hidden unit dimension
    c = 512+128  # SSRN hidden unit dimension

    dropout_rate = 0.05  # dropout

    # Text2Mel network options
    text2mel_lr = 0.005  # learning rate
    text2mel_max_iteration = 300000  # max train step
    text2mel_weight_init = 'none'  # 'kaiming', 'xavier' or 'none'
    text2mel_normalization = 'layer'  # 'layer', 'weight' or 'none'
    text2mel_basic_block = 'gated_conv'  # 'highway', 'gated_conv' or 'residual'

    # SSRN network options
    ssrn_lr = 0.0005  # learning rate
    ssrn_max_iteration = 150000  # max train step
    ssrn_weight_init = 'kaiming'  # 'kaiming', 'xavier' or 'none'
    ssrn_normalization = 'weight'  # 'layer', 'weight' or 'none'
    ssrn_basic_block = 'residual'  # 'highway', 'gated_conv' or 'residual'
class LayerNorm(nn.LayerNorm):
    """`nn.LayerNorm` applied along axis 1 of a three-dimensional input."""

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        """Layer Norm."""
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        # nn.LayerNorm normalizes the trailing axis, so swap axes 1 and 2,
        # normalize, and swap back to the incoming layout.
        swapped = x.permute(0, 2, 1)
        normalized = super().forward(swapped)
        return normalized.permute(0, 2, 1)
class C(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation, causal=False, weight_init='none',
                 normalization='weight', nonlinearity='linear'):
        """1D convolution with optional causality, normalization and ReLU.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            dilation: convolution dilation.
            causal: if True, pad by ``(kernel_size - 1) * dilation`` and trim
                the same number of trailing frames in `forward`; otherwise pad
                by half that amount.
            weight_init: 'kaiming', 'xavier' or 'none' (keep the default
                initialization).
            normalization: 'weight' (weight norm), 'layer' (layer norm applied
                to the output) or 'none'.
            nonlinearity: 'relu' applies a ReLU in `forward`; the value is also
                passed to the weight initializers.
        """
        super(C, self).__init__()
        self.causal = causal
        if causal:
            self.padding = (kernel_size - 1) * dilation
        else:
            self.padding = (kernel_size - 1) * dilation // 2

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=1,  # paper: 'The stride of convolution is always 1.'
                              padding=self.padding, dilation=dilation)

        # Initialize BEFORE wrapping with weight_norm. Afterwards `conv.weight`
        # is a tensor derived from weight_g/weight_v on every forward pass, so
        # initializing it in place would leave the trained parameters untouched.
        if weight_init == 'kaiming':
            nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity=nonlinearity)
        elif weight_init == 'xavier':
            nn.init.xavier_uniform_(self.conv.weight, nn.init.calculate_gain(nonlinearity))

        if normalization == 'weight':
            self.conv = nn.utils.weight_norm(self.conv)
        elif normalization == 'layer':
            self.layer_norm = LayerNorm(out_channels)

        self.nonlinearity = nonlinearity

    def forward(self, x):
        """Convolve `x`, then apply the optional layer norm, dropout and ReLU."""
        y = self.conv(x)
        padding = self.padding
        if self.causal and padding > 0:
            # Drop the frames produced by the right-hand padding so that no
            # output frame depends on later input frames.
            y = y[:, :, :-padding]

        if hasattr(self, 'layer_norm'):
            y = self.layer_norm(y)
        y = F.dropout(y, p=hp.dropout_rate, training=self.training, inplace=True)
        if self.nonlinearity == 'relu':
            y = F.relu(y, inplace=True)
        return y
class GatedConvBlock(nn.Module):
    def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight'):
        """Gated convolution with a skip connection: https://arxiv.org/abs/1612.08083

        A `C` convolution widens the input to ``2 * d`` channels, a GLU over
        the channel axis brings it back to ``d`` channels, and the block input
        is added to the result, so input and output shapes match.

        Args:
            d: input channel
            k: kernel size
            delta: dilation
            causal: causal convolution or not
        """
        super().__init__()
        conv_options = dict(kernel_size=k, dilation=delta, causal=causal,
                            weight_init=weight_init, normalization=normalization)
        self.C = C(in_channels=d, out_channels=2 * d, **conv_options)
        self.glu = nn.GLU(dim=1)

    def forward(self, x):
        gated = self.glu(self.C(x))
        return gated + x
def Conv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'):
    """Non-causal convolution layer configured with the SSRN hyper parameters."""
    ssrn_options = dict(weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization)
    return C(in_channels, out_channels, kernel_size, dilation,
             causal=False, nonlinearity=nonlinearity, **ssrn_options)


def DeConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'):
    """D layer (imported from .layers) configured with the SSRN hyper parameters."""
    ssrn_options = dict(weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization)
    return D(in_channels, out_channels, kernel_size, dilation,
             nonlinearity=nonlinearity, **ssrn_options)
class SSRN(nn.Module):
    def __init__(self, c=hp.c, f=hp.n_mels, f_prime=(1 + hp.n_fft // 2)):
        """Spectrogram super-resolution network.
        Args:
            c: SSRN dim
            f: Number of mel bins
            f_prime: full spectrogram dim
        Input:
            Y: (B, f, T) predicted melspectrograms
        Outputs:
            Z_logit: logit of Z
            Z: (B, f_prime, 4*T) full spectrograms
        """
        super(SSRN, self).__init__()
        self.layers = nn.Sequential(
            Conv(f, c, 1, 1),

            BasicBlock(c, 3, 1), BasicBlock(c, 3, 3),

            DeConv(c, c, 2, 1), BasicBlock(c, 3, 1), BasicBlock(c, 3, 3),
            DeConv(c, c, 2, 1), BasicBlock(c, 3, 1), BasicBlock(c, 3, 3),

            Conv(c, 2 * c, 1, 1),

            BasicBlock(2 * c, 3, 1), BasicBlock(2 * c, 3, 1),

            Conv(2 * c, f_prime, 1, 1),

            # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'),
            # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'),
            BasicBlock(f_prime, 1, 1),

            Conv(f_prime, f_prime, 1, 1)
        )

    def forward(self, x):
        """Return the logits and the sigmoid-activated spectrogram for the mel input x."""
        Z_logit = self.layers(x)
        # tensor method instead of the deprecated F.sigmoid; needs no extra import
        Z = Z_logit.sigmoid()
        return Z_logit, Z
""" __author__ = 'Erdene-Ochir Tuguldur' __all__ = ['Text2Mel'] import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from hparams import HParams as hp from .layers import E, C, HighwayBlock, GatedConvBlock, ResidualBlock def Conv(in_channels, out_channels, kernel_size, dilation, causal=False, nonlinearity='linear'): return C(in_channels, out_channels, kernel_size, dilation, causal=causal, weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization, nonlinearity=nonlinearity) def BasicBlock(d, k, delta, causal=False): if hp.text2mel_basic_block == 'gated_conv': return GatedConvBlock(d, k, delta, causal=causal, weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization) elif hp.text2mel_basic_block == 'highway': return HighwayBlock(d, k, delta, causal=causal, weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization) else: return ResidualBlock(d, k, delta, causal=causal, weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization, widening_factor=2) def CausalConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'): return Conv(in_channels, out_channels, kernel_size, dilation, causal=True, nonlinearity=nonlinearity) def CausalBasicBlock(d, k, delta): return BasicBlock(d, k, delta, causal=True) class TextEnc(nn.Module): def __init__(self, vocab, e=hp.e, d=hp.d): """Text encoder network. 
class AudioEnc(nn.Module):
    def __init__(self, d=hp.d, f=hp.n_mels):
        """Audio encoder network.
        Args:
            d: Text2Mel dim
            f: Number of mel bins
        Input:
            S: (B, f, T) melspectrograms
        Output:
            Q: (B, d, T) queries
        """
        super(AudioEnc, self).__init__()
        # three causal 1x1 convolutions lift the mel bins to d channels
        stem = [
            CausalConv(f, d, 1, 1, nonlinearity='relu'),
            CausalConv(d, d, 1, 1, nonlinearity='relu'),
            CausalConv(d, d, 1, 1),
        ]
        # followed by causal basic blocks of kernel size 3 with these dilations
        dilations = (1, 3, 9, 27, 1, 3, 9, 27, 3, 3)
        blocks = [CausalBasicBlock(d, 3, delta) for delta in dilations]
        self.layers = nn.Sequential(*(stem + blocks))

    def forward(self, x):
        return self.layers(x)
class Text2Mel(nn.Module):
    def __init__(self, vocab, d=hp.d):
        """Text to melspectrogram network.
        Args:
            vocab: vocabulary
            d: Text2Mel dim
        Input:
            L: (B, N) text inputs
            S: (B, f, T) melspectrograms
        Outputs:
            Y_logit: logit of Y
            Y: predicted melspectrograms
            A: (B, N, T) attention matrix
        """
        super(Text2Mel, self).__init__()
        self.d = d
        # forward `d` so that the sub-networks and the attention scaling in forward()
        # always agree on the key/query dimension
        self.text_enc = TextEnc(vocab, d=d)
        self.audio_enc = AudioEnc(d=d)
        self.audio_dec = AudioDec(d=d)

    def forward(self, L, S, monotonic_attention=False):
        """Predict the mel frames for text L given the mel frames S seen so far.

        With monotonic_attention=True the attention scores are edited in place before
        the softmax: whenever the best text position of a frame is more than one step
        behind or more than three steps ahead of the previous frame's position, the
        frame is forced to attend to the position right after the previous one.
        """
        K, V = self.text_enc(L)
        Q = self.audio_enc(S)
        A = torch.bmm(K.permute(0, 2, 1), Q) / np.sqrt(self.d)

        if monotonic_attention:
            # TODO: vectorize instead of loops
            B, N, T = A.size()
            for i in range(B):
                prva = -1  # previous attention
                for t in range(T):
                    n = torch.max(A[i, :, t], 0)[1].item()
                    if not (-1 <= n - prva <= 3):
                        A[i, :, t] = -2 ** 20  # some small numbers
                        A[i, min(N - 1, prva + 1), t] = 1
                    prva = torch.max(A[i, :, t], 0)[1].item()

        A = F.softmax(A, dim=1)
        R = torch.bmm(V, A)
        R_prime = torch.cat((R, Q), 1)
        Y_logit = self.audio_dec(R_prime)
        Y = torch.sigmoid(Y_logit)  # F.sigmoid is deprecated in favour of torch.sigmoid
        return Y_logit, Y, A
Tuguldur' import os import sys import argparse from tqdm import * import numpy as np import torch from models import Text2Mel, SSRN from hparams import HParams as hp from audio import save_to_wav from utils import get_last_checkpoint_file_name, load_checkpoint, save_to_png parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--dataset", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name') args = parser.parse_args() if args.dataset == 'ljspeech': from datasets.lj_speech import vocab, get_test_data SENTENCES = [ "The birch canoe slid on the smooth planks.", "Glue the sheet to the dark blue background.", "It's easy to tell the depth of a well.", "These days a chicken leg is a rare dish.", "Rice is often served in round bowls.", "The juice of lemons makes fine punch.", "The box was thrown beside the parked truck.", "The hogs were fed chopped corn and garbage.", "Four hours of steady work faced us.", "Large size in stockings is hard to sell.", "The boy was there when the sun rose.", "A rod is used to catch pink salmon.", "The source of the huge river is the clear spring.", "Kick the ball straight and follow through.", "Help the woman get back to her feet.", "A pot of tea helps to pass the evening.", "Smoky fires lack flame and heat.", "The soft cushion broke the man's fall.", "The salt breeze came across from the sea.", "The girl at the booth sold fifty bonds." 
] else: from datasets.mb_speech import vocab, get_test_data SENTENCES = [ "Нийслэлийн прокурорын газраас төрийн өндөр албан тушаалтнуудад холбогдох зарим эрүүгийн хэргүүдийг шүүхэд шилжүүлэв.", "Мөнх тэнгэрийн хүчин дор Монгол Улс цэцэглэн хөгжих болтугай.", "Унасан хүлгээ түрүү магнай, аман хүзүүнд уралдуулж, айрагдуулсан унаач хүүхдүүдэд бэлэг гардууллаа.", "Албан ёсоор хэлэхэд “Монгол Улсын хэрэг эрхлэх газрын гэгээнтэн” гэж нэрлээд байгаа зүйл огт байхгүй.", "Сайн чанарын бохирын хоолой зарна.", "Хараа тэглэх мэс заслын дараа хараа дахин муудах магадлал бага.", "Ер нь бол хараа тэглэх мэс заслыг гоо сайхны мэс засалтай адилхан гэж зүйрлэж болно.", "Хашлага даван, зүлэг гэмтээсэн жолоочийн эрхийг хоёр жилээр хасжээ.", "Монгол хүн бидний сэтгэлийг сорсон орон. Энэ бол миний төрсөн нутаг. Монголын сайхан орон.", "Постройка крейсера затягивалась из-за проектных неувязок, необходимости." ] torch.set_grad_enabled(False) text2mel = Text2Mel(vocab).eval() last_checkpoint_file_name = get_last_checkpoint_file_name(os.path.join(hp.logdir, '%s-text2mel' % args.dataset)) # last_checkpoint_file_name = 'logdir/%s-text2mel/step-020K.pth' % args.dataset if last_checkpoint_file_name: print("loading text2mel checkpoint '%s'..." % last_checkpoint_file_name) load_checkpoint(last_checkpoint_file_name, text2mel, None) else: print("text2mel not exits") sys.exit(1) ssrn = SSRN().eval() last_checkpoint_file_name = get_last_checkpoint_file_name(os.path.join(hp.logdir, '%s-ssrn' % args.dataset)) # last_checkpoint_file_name = 'logdir/%s-ssrn/step-005K.pth' % args.dataset if last_checkpoint_file_name: print("loading ssrn checkpoint '%s'..." % last_checkpoint_file_name) load_checkpoint(last_checkpoint_file_name, ssrn, None) else: print("ssrn not exits") sys.exit(1) # synthetize by one by one because there is a batch processing bug! 
# `samples` is git-ignored, so the output folder may not exist on a fresh checkout
os.makedirs('samples', exist_ok=True)

eos_index = vocab.index('E')  # EOS

# synthetize by one by one because there is a batch processing bug!
for i, sentence in enumerate(SENTENCES):
    L = torch.from_numpy(get_test_data([sentence], len(sentence)))
    zeros = torch.from_numpy(np.zeros((1, hp.n_mels, 1), np.float32))
    Y = zeros
    A = None

    # feed the frames predicted so far back into text2mel, prefixed with a zero frame
    for t in tqdm(range(hp.max_T)):
        _, Y_t, A = text2mel(L, Y, monotonic_attention=True)
        Y = torch.cat((zeros, Y_t), -1)
        _, attention = torch.max(A[0, :, -1], 0)
        attention = attention.item()
        # stop as soon as the newest frame attends to the EOS character
        if L[0, attention] == eos_index:
            break

    _, Z = ssrn(Y)

    Y = Y.cpu().detach().numpy()
    A = A.cpu().detach().numpy()
    Z = Z.cpu().detach().numpy()

    save_to_png('samples/%d-att.png' % (i + 1), A[0, :, :])
    save_to_png('samples/%d-mel.png' % (i + 1), Y[0, :, :])
    save_to_png('samples/%d-mag.png' % (i + 1), Z[0, :, :])
    save_to_wav(Z[0, :, :].T, 'samples/%d-wav.wav' % (i + 1))
def get_lr():
    """Return the learning rate of the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return first_group['lr']


def lr_decay(step, warmup_steps=1000):
    """Set the learning rate of the first parameter group for the given step.

    The rate is hp.ssrn_lr * sqrt(warmup_steps) * min(n * warmup_steps^-1.5, n^-0.5)
    with n = step + 1: it grows linearly up to hp.ssrn_lr at n == warmup_steps
    and decays with 1/sqrt(n) afterwards.
    """
    n = step + 1
    schedule = min(n * warmup_steps ** -1.5, n ** -0.5)
    new_lr = hp.ssrn_lr * warmup_steps ** 0.5 * schedule
    optimizer.param_groups[0]['lr'] = new_lr
l1_loss}, {'mags-true': M[:1, :, :], 'mags-pred': Z[:1, :, :], 'mels': S[:1, :, :]}) if global_step % 5000 == 0: # checkpoint at every 5000th step save_checkpoint(logger.logdir, train_epoch, global_step, ssrn, optimizer) epoch_loss = running_loss / it epoch_l1_loss = running_l1_loss / it logger.log_epoch(phase, global_step, {'loss_l1': epoch_l1_loss}) return epoch_loss since = time.time() epoch = start_epoch while True: train_epoch_loss = train(epoch, phase='train') time_elapsed = time.time() - since time_str = 'total time elapsed: {:.0f}h {:.0f}m {:.0f}s '.format(time_elapsed // 3600, time_elapsed % 3600 // 60, time_elapsed % 60) print("train epoch loss %f, step=%d, %s" % (train_epoch_loss, global_step, time_str)) valid_epoch_loss = train(epoch, phase='valid') print("valid epoch loss %f" % valid_epoch_loss) epoch += 1 if global_step >= hp.ssrn_max_iteration: print("max step %d (current step %d) reached, exiting..." % (hp.ssrn_max_iteration, global_step)) sys.exit(0) ================================================ FILE: train-text2mel.py ================================================ #!/usr/bin/env python """Train the Text2Mel network. 
def get_lr():
    """Return the learning rate of the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return first_group['lr']


def lr_decay(step, warmup_steps=4000):
    """Set the learning rate of the first parameter group for the given step.

    The rate is hp.text2mel_lr * sqrt(warmup_steps) * min(n * warmup_steps^-1.5, n^-0.5)
    with n = step + 1: it grows linearly up to hp.text2mel_lr at n == warmup_steps
    and decays with 1/sqrt(n) afterwards.
    """
    n = step + 1
    schedule = min(n * warmup_steps ** -1.5, n ** -0.5)
    new_lr = hp.text2mel_lr * warmup_steps ** 0.5 * schedule
    optimizer.param_groups[0]['lr'] = new_lr
train(train_epoch, phase='train'): global global_step lr_decay(global_step) print("epoch %3d with lr=%.02e" % (train_epoch, get_lr())) text2mel.train() if phase == 'train' else text2mel.eval() torch.set_grad_enabled(True) if phase == 'train' else torch.set_grad_enabled(False) data_loader = train_data_loader if phase == 'train' else valid_data_loader it = 0 running_loss = 0.0 running_l1_loss = 0.0 running_att_loss = 0.0 pbar = tqdm(data_loader, unit="audios", unit_scale=data_loader.batch_size, disable=hp.disable_progress_bar) for batch in pbar: L, S, gates = batch['texts'], batch['mels'], batch['mel_gates'] S = S.permute(0, 2, 1) # TODO: because of pre processing B, N = L.size() # batch size and text count _, n_mels, T = S.size() # number of melspectrogram bins and time assert gates.size(0) == B # TODO: later remove assert gates.size(1) == T S_shifted = torch.cat((S[:, :, 1:], torch.zeros(B, n_mels, 1)), 2) S.requires_grad = False S_shifted.requires_grad = False gates.requires_grad = False def W_nt(_, n, t, g=0.2): return 1.0 - np.exp(-((n / float(N) - t / float(T)) ** 2) / (2 * g ** 2)) W = np.fromfunction(W_nt, (B, N, T), dtype=np.float32) W = torch.from_numpy(W) L = L.cuda() S = S.cuda() S_shifted = S_shifted.cuda() W = W.cuda() gates = gates.cuda() Y_logit, Y, A = text2mel(L, S) l1_loss = F.l1_loss(Y, S_shifted) masks = gates.reshape(B, 1, T).float() att_loss = (A * W * masks).mean() loss = l1_loss + att_loss if phase == 'train': lr_decay(global_step) optimizer.zero_grad() loss.backward() optimizer.step() global_step += 1 it += 1 loss, l1_loss, att_loss = loss.item(), l1_loss.item(), att_loss.item() running_loss += loss running_l1_loss += l1_loss running_att_loss += att_loss if phase == 'train': # update the progress bar pbar.set_postfix({ 'l1': "%.05f" % (running_l1_loss / it), 'att': "%.05f" % (running_att_loss / it) }) logger.log_step(phase, global_step, {'loss_l1': l1_loss, 'loss_att': att_loss}, {'mels-true': S[:1, :, :], 'mels-pred': Y[:1, :, :], 
def _checkpoint_order(path):
    """Sort key ordering checkpoint files by the step number at the end of their name."""
    import re  # local import: utils.py does not import `re` at module level

    stem = os.path.splitext(os.path.basename(path))[0]
    match = re.search(r'(\d+)K?$', stem)
    # files without a trailing step number sort before all numbered checkpoints
    step = int(match.group(1)) if match else -1
    return step, path


def get_last_checkpoint_file_name(logdir):
    """Returns the last checkpoint file name in the given log dir path.

    Checkpoints are compared by their numeric step, not alphabetically, so that
    'step-1000K.pth' is newer than 'step-999K.pth'. Returns None if the directory
    contains no '*.pth' file.
    """
    checkpoints = glob.glob(os.path.join(logdir, '*.pth'))
    if not checkpoints:
        return None
    return max(checkpoints, key=_checkpoint_order)


def load_checkpoint(checkpoint_file_name, model, optimizer):
    """Loads the checkpoint into the given model and optimizer.

    Args:
        checkpoint_file_name: path of a file written by save_checkpoint
        model: module receiving the saved 'state_dict'
        optimizer: optimizer receiving the saved state, or None to skip it
    Returns:
        (start_epoch, global_step) stored in the checkpoint, 0 for a missing entry
    """
    # Map everything to the CPU first: the checkpoints are written by the CUDA training
    # scripts but are also loaded on CPU-only machines (synthesize.py). load_state_dict
    # copies the values to whatever device the model/optimizer parameters live on.
    checkpoint = torch.load(checkpoint_file_name, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.float()
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint.get('epoch', 0)
    global_step = checkpoint.get('global_step', 0)
    del checkpoint
    print("loaded checkpoint epoch=%d step=%d" % (start_epoch, global_step))
    return start_epoch, global_step


def save_checkpoint(logdir, epoch, global_step, model, optimizer):
    """Saves the training state into the given log dir path.

    The file is named after the step in thousands ('step-005K.pth') and stores
    epoch + 1, i.e. the epoch to continue with after loading.
    """
    checkpoint_file_name = os.path.join(logdir, 'step-%03dK.pth' % (global_step // 1000))
    print("saving the checkpoint file '%s'..." % checkpoint_file_name)
    checkpoint = {
        'epoch': epoch + 1,
        'global_step': global_step,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    # Write to a temporary name and rename afterwards, so an interrupted save never
    # leaves a truncated '*.pth' file that would be picked up as the last checkpoint.
    temp_file_name = checkpoint_file_name + '.tmp'
    torch.save(checkpoint, temp_file_name)
    os.replace(temp_file_name, checkpoint_file_name)
    del checkpoint


def download_file(url, file_path):
    """Downloads a file from the given URL.

    Exits the process with status 1 if the server answers with an error status or
    if fewer/more bytes than announced in 'content-length' were received.
    """
    print("downloading %s..." % url)
    r = requests.get(url, stream=True)
    try:
        if not r.ok:
            # do not store an HTTP error page as if it were the requested file
            print("downloading failed")
            sys.exit(1)
        total_size = int(r.headers.get('content-length', 0))
        block_size = 1024 * 1024
        # round up so a partial last block is counted; unknown size -> no total
        total_blocks = math.ceil(total_size / block_size) if total_size else None
        wrote = 0
        with open(file_path, 'wb') as f:
            for data in tqdm(r.iter_content(block_size), total=total_blocks, unit='MB'):
                wrote = wrote + len(data)
                f.write(data)
    finally:
        r.close()
    if total_size != 0 and wrote != total_size:
        print("downloading failed")
        sys.exit(1)


def save_to_png(file_name, array):
    """Save the given numpy array as a PNG file."""
    # from skimage._shared._warnings import expected_warnings
    # with expected_warnings(['precision']):
    imsave(file_name, img_as_ubyte(array))