[
  {
    "path": ".gitignore",
    "content": ".idea\n.DS_Store\n__pycache__\n.ipynb_checkpoints\n*.ipynb\nlogdir/\nsamples\n*.npy\n*.tar.bz2\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 Erdene-Ochir Tuguldur\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "PyTorch implementation of\n[Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention](https://arxiv.org/abs/1710.08969)\nbased partially on the following projects:\n* https://github.com/Kyubyong/dc_tts (audio preprocessing)\n* https://github.com/r9y9/deepvoice3_pytorch (data loader sampler)\n\n## Online Text-To-Speech Demo\nThe following notebooks are executable on [https://colab.research.google.com](https://colab.research.google.com):\n* [Mongolian Male Voice TTS Demo](https://colab.research.google.com/github/tugstugi/pytorch-dc-tts/blob/master/notebooks/MongolianTTS.ipynb)\n* [English Female Voice TTS Demo (LJ-Speech)](https://colab.research.google.com/github/tugstugi/pytorch-dc-tts/blob/master/notebooks/EnglishTTS.ipynb)\n\nFor audio samples and pretrained models, visit the above notebook links.\n\n## Training/Synthesizing English Text-To-Speech\nThe English TTS uses the [LJ-Speech](https://keithito.com/LJ-Speech-Dataset/) dataset.\n1. Download and preprocess the dataset: `python dl_and_preprop_dataset.py --dataset=ljspeech`\n2. Train the Text2Mel model: `python train-text2mel.py --dataset=ljspeech`\n3. Train the SSRN model: `python train-ssrn.py --dataset=ljspeech`\n4. Synthesize sentences: `python synthesize.py --dataset=ljspeech`\n   * The WAV files are saved in the `samples` folder.\n\n## Training/Synthesizing Mongolian Text-To-Speech\nThe Mongolian text-to-speech uses 5 hours of audio from the [Mongolian Bible](https://www.bible.com/mn/versions/1590-2013-ariun-bibli-2013).\n1. Download and preprocess the dataset: `python dl_and_preprop_dataset.py --dataset=mbspeech`\n2. Train the Text2Mel model: `python train-text2mel.py --dataset=mbspeech`\n3. Train the SSRN model: `python train-ssrn.py --dataset=mbspeech`\n4. Synthesize sentences: `python synthesize.py --dataset=mbspeech`\n   * The WAV files are saved in the `samples` folder.\n"
  },
  {
    "path": "audio.py",
    "content": "\"\"\"These methods are copied from https://github.com/Kyubyong/dc_tts/\"\"\"\n\nimport os\nimport copy\nimport librosa\nimport scipy.io.wavfile\nimport numpy as np\n\nfrom tqdm import tqdm\nfrom scipy import signal\nfrom hparams import HParams as hp\n\n\ndef spectrogram2wav(mag):\n    '''# Generate wave file from linear magnitude spectrogram\n    Args:\n      mag: A numpy array of (T, 1+n_fft//2)\n    Returns:\n      wav: A 1-D numpy array.\n    '''\n    # transpose\n    mag = mag.T\n\n    # de-noramlize\n    mag = (np.clip(mag, 0, 1) * hp.max_db) - hp.max_db + hp.ref_db\n\n    # to amplitude\n    mag = np.power(10.0, mag * 0.05)\n\n    # wav reconstruction\n    wav = griffin_lim(mag ** hp.power)\n\n    # de-preemphasis\n    wav = signal.lfilter([1], [1, -hp.preemphasis], wav)\n\n    # trim\n    wav, _ = librosa.effects.trim(wav)\n\n    return wav.astype(np.float32)\n\n\ndef griffin_lim(spectrogram):\n    '''Applies Griffin-Lim's raw.'''\n    X_best = copy.deepcopy(spectrogram)\n    for i in range(hp.n_iter):\n        X_t = invert_spectrogram(X_best)\n        est = librosa.stft(X_t, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)\n        phase = est / np.maximum(1e-8, np.abs(est))\n        X_best = spectrogram * phase\n    X_t = invert_spectrogram(X_best)\n    y = np.real(X_t)\n\n    return y\n\n\ndef invert_spectrogram(spectrogram):\n    '''Applies inverse fft.\n    Args:\n      spectrogram: [1+n_fft//2, t]\n    '''\n    return librosa.istft(spectrogram, hop_length=hp.hop_length, win_length=hp.win_length, window=\"hann\")\n\n\ndef get_spectrograms(fpath):\n    '''Parse the wave file in `fpath` and\n    Returns normalized melspectrogram and linear spectrogram.\n    Args:\n      fpath: A string. 
The full path of a sound file.\n    Returns:\n      mel: A 2d array of shape (T, n_mels) and dtype of float32.\n      mag: A 2d array of shape (T, 1+n_fft/2) and dtype of float32.\n    '''\n    # Loading sound file\n    y, sr = librosa.load(fpath, sr=hp.sr)\n\n    # Trimming\n    y, _ = librosa.effects.trim(y)\n\n    # Preemphasis\n    y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1])\n\n    # stft\n    linear = librosa.stft(y=y,\n                          n_fft=hp.n_fft,\n                          hop_length=hp.hop_length,\n                          win_length=hp.win_length)\n\n    # magnitude spectrogram\n    mag = np.abs(linear)  # (1+n_fft//2, T)\n\n    # mel spectrogram\n    mel_basis = librosa.filters.mel(sr=hp.sr, n_fft=hp.n_fft, n_mels=hp.n_mels)  # (n_mels, 1+n_fft//2)\n    mel = np.dot(mel_basis, mag)  # (n_mels, t)\n\n    # to decibel\n    mel = 20 * np.log10(np.maximum(1e-5, mel))\n    mag = 20 * np.log10(np.maximum(1e-5, mag))\n\n    # normalize\n    mel = np.clip((mel - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)\n    mag = np.clip((mag - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)\n\n    # Transpose\n    mel = mel.T.astype(np.float32)  # (T, n_mels)\n    mag = mag.T.astype(np.float32)  # (T, 1+n_fft//2)\n\n    return mel, mag\n\n\ndef save_to_wav(mag, filename):\n    \"\"\"Generate and save an audio file from the given linear spectrogram using Griffin-Lim.\"\"\"\n    wav = spectrogram2wav(mag)\n    scipy.io.wavfile.write(filename, hp.sr, wav)\n\n\ndef preprocess(dataset_path, speech_dataset):\n    \"\"\"Preprocess the given dataset.\"\"\"\n    wavs_path = os.path.join(dataset_path, 'wavs')\n    mels_path = os.path.join(dataset_path, 'mels')\n    if not os.path.isdir(mels_path):\n        os.mkdir(mels_path)\n    mags_path = os.path.join(dataset_path, 'mags')\n    if not os.path.isdir(mags_path):\n        os.mkdir(mags_path)\n\n    for fname in tqdm(speech_dataset.fnames):\n        mel, mag = get_spectrograms(os.path.join(wavs_path, '%s.wav' % 
fname))\n\n        t = mel.shape[0]\n        # Marginal padding for reduction shape sync.\n        num_paddings = hp.reduction_rate - (t % hp.reduction_rate) if t % hp.reduction_rate != 0 else 0\n        mel = np.pad(mel, [[0, num_paddings], [0, 0]], mode=\"constant\")\n        mag = np.pad(mag, [[0, num_paddings], [0, 0]], mode=\"constant\")\n        # Reduction\n        mel = mel[::hp.reduction_rate, :]\n\n        np.save(os.path.join(mels_path, '%s.npy' % fname), mel)\n        np.save(os.path.join(mags_path, '%s.npy' % fname), mag)\n"
  },
  {
    "path": "datasets/.gitignore",
    "content": "LJSpeech-1.1/\nMBSpeech-1.0/\n*.tar.gz\n"
  },
  {
    "path": "datasets/__init__.py",
    "content": ""
  },
  {
    "path": "datasets/data_loader.py",
    "content": "import random\n\nimport numpy as np\nimport torch\nfrom torch.utils.data.dataloader import default_collate, DataLoader\nfrom torch.utils.data.sampler import Sampler\n\n__all__ = ['Text2MelDataLoader', 'SSRNDataLoader']\n\n\nclass Text2MelDataLoader(DataLoader):\n    def __init__(self, text2mel_dataset, batch_size, mode='train', num_workers=8):\n        if mode == 'train':\n            text2mel_dataset.slice(0, -batch_size)\n        elif mode == 'valid':\n            text2mel_dataset.slice(len(text2mel_dataset) - batch_size, -1)\n        else:\n            raise ValueError(\"mode must be either 'train' or 'valid'\")\n        super().__init__(text2mel_dataset,\n                         batch_size=batch_size,\n                         num_workers=num_workers,\n                         collate_fn=collate_fn,\n                         shuffle=True)\n\n\nclass SSRNDataLoader(DataLoader):\n    def __init__(self, ssrn_dataset, batch_size, mode='train', num_workers=8):\n        if mode == 'train':\n            ssrn_dataset.slice(0, -batch_size)\n            super().__init__(ssrn_dataset,\n                             batch_size=batch_size,\n                             num_workers=num_workers,\n                             collate_fn=collate_fn,\n                             sampler=PartiallyRandomizedSimilarTimeLengthSampler(lengths=ssrn_dataset.text_lengths,\n                                                                                 data_source=None,\n                                                                                 batch_size=batch_size))\n        elif mode == 'valid':\n            ssrn_dataset.slice(len(ssrn_dataset) - batch_size, -1)\n            super().__init__(ssrn_dataset,\n                             batch_size=batch_size,\n                             num_workers=num_workers,\n                             collate_fn=collate_fn,\n                             shuffle=True)\n        else:\n            raise ValueError(\"mode 
must be either 'train' or 'valid'\")\n\n\ndef collate_fn(batch):\n    keys = batch[0].keys()\n    max_lengths = {key: 0 for key in keys}\n    collated_batch = {key: [] for key in keys}\n\n    # find out the max lengths\n    for row in batch:\n        for key in keys:\n            max_lengths[key] = max(max_lengths[key], row[key].shape[0])\n\n    # pad to the max lengths\n    for row in batch:\n        for key in keys:\n            array = row[key]\n            dim = len(array.shape)\n            assert dim == 1 or dim == 2\n            # TODO: because of pre processing, later we want to have (n_mels, T)\n            if dim == 1:\n                padded_array = np.pad(array, (0, max_lengths[key] - array.shape[0]), mode='constant')\n            else:\n                padded_array = np.pad(array, ((0, max_lengths[key] - array.shape[0]), (0, 0)), mode='constant')\n            collated_batch[key].append(padded_array)\n\n    # use the default_collate to convert to tensors\n    for key in keys:\n        collated_batch[key] = default_collate(collated_batch[key])\n    return collated_batch\n\n\nclass PartiallyRandomizedSimilarTimeLengthSampler(Sampler):\n    \"\"\"Copied from: https://github.com/r9y9/deepvoice3_pytorch/blob/master/train.py.\n    Partially randomized sampler\n    1. Sort by lengths\n    2. Pick a small patch and randomize it\n    3. 
Permutate mini-batches\n    \"\"\"\n\n    def __init__(self, lengths, data_source, batch_size=16, batch_group_size=None, permutate=True):\n        super().__init__(data_source)\n        self.lengths, self.sorted_indices = torch.sort(torch.LongTensor(lengths))\n        self.batch_size = batch_size\n        if batch_group_size is None:\n            batch_group_size = min(batch_size * 32, len(self.lengths))\n            if batch_group_size % batch_size != 0:\n                batch_group_size -= batch_group_size % batch_size\n\n        self.batch_group_size = batch_group_size\n        assert batch_group_size % batch_size == 0\n        self.permutate = permutate\n\n    def __iter__(self):\n        indices = self.sorted_indices.clone()\n        batch_group_size = self.batch_group_size\n        s, e = 0, 0\n        for i in range(len(indices) // batch_group_size):\n            s = i * batch_group_size\n            e = s + batch_group_size\n            random.shuffle(indices[s:e])\n\n        # Permutate batches\n        if self.permutate:\n            perm = np.arange(len(indices[:e]) // self.batch_size)\n            random.shuffle(perm)\n            indices[:e] = indices[:e].view(-1, self.batch_size)[perm, :].view(-1)\n\n        # Handle last elements\n        s += batch_group_size\n        if s < len(indices):\n            random.shuffle(indices[s:])\n\n        return iter(indices)\n\n    def __len__(self):\n        return len(self.sorted_indices)\n"
  },
  {
    "path": "datasets/lj_speech.py",
    "content": "\"\"\"Data loader for the LJSpeech dataset. See: https://keithito.com/LJ-Speech-Dataset/\"\"\"\nimport os\nimport re\nimport codecs\nimport unicodedata\nimport numpy as np\n\nfrom torch.utils.data import Dataset\n\nvocab = \"PE abcdefghijklmnopqrstuvwxyz'.?\"  # P: Padding, E: EOS.\nchar2idx = {char: idx for idx, char in enumerate(vocab)}\nidx2char = {idx: char for idx, char in enumerate(vocab)}\n\n\ndef text_normalize(text):\n    text = ''.join(char for char in unicodedata.normalize('NFD', text)\n                   if unicodedata.category(char) != 'Mn')  # Strip accents\n\n    text = text.lower()\n    text = re.sub(\"[^{}]\".format(vocab), \" \", text)\n    text = re.sub(\"[ ]+\", \" \", text)\n    return text\n\n\ndef read_metadata(metadata_file):\n    fnames, text_lengths, texts = [], [], []\n    transcript = os.path.join(metadata_file)\n    lines = codecs.open(transcript, 'r', 'utf-8').readlines()\n    for line in lines:\n        fname, _, text = line.strip().split(\"|\")\n\n        fnames.append(fname)\n\n        text = text_normalize(text) + \"E\"  # E: EOS\n        text = [char2idx[char] for char in text]\n        text_lengths.append(len(text))\n        texts.append(np.array(text, np.longlong))\n\n    return fnames, text_lengths, texts\n\n\ndef get_test_data(sentences, max_n):\n    normalized_sentences = [text_normalize(line).strip() + \"E\" for line in sentences]  # text normalization, E: EOS\n    texts = np.zeros((len(normalized_sentences), max_n + 1), np.longlong)\n    for i, sent in enumerate(normalized_sentences):\n        texts[i, :len(sent)] = [char2idx[char] for char in sent]\n    return texts\n\n\nclass LJSpeech(Dataset):\n    def __init__(self, keys, dir_name='LJSpeech-1.1'):\n        self.keys = keys\n        self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_name)\n        self.fnames, self.text_lengths, self.texts = read_metadata(os.path.join(self.path, 'metadata.csv'))\n\n    def slice(self, start, end):\n 
       self.fnames = self.fnames[start:end]\n        self.text_lengths = self.text_lengths[start:end]\n        self.texts = self.texts[start:end]\n\n    def __len__(self):\n        return len(self.fnames)\n\n    def __getitem__(self, index):\n        data = {}\n        if 'texts' in self.keys:\n            data['texts'] = self.texts[index]\n        if 'mels' in self.keys:\n            # (39, 80)\n            data['mels'] = np.load(os.path.join(self.path, 'mels', \"%s.npy\" % self.fnames[index]))\n        if 'mags' in self.keys:\n            # (39, 80)\n            data['mags'] = np.load(os.path.join(self.path, 'mags', \"%s.npy\" % self.fnames[index]))\n        if 'mel_gates' in self.keys:\n            data['mel_gates'] = np.ones(data['mels'].shape[0], dtype=np.int64)  # TODO: because pre processing!\n        if 'mag_gates' in self.keys:\n            data['mag_gates'] = np.ones(data['mags'].shape[0], dtype=np.int64)  # TODO: because pre processing!\n        return data\n"
  },
  {
    "path": "datasets/mb_speech.py",
    "content": "\"\"\"Data loader for the Mongolian Bible dataset.\"\"\"\nimport os\nimport codecs\nimport numpy as np\n\nfrom torch.utils.data import Dataset\n\nvocab = \"PE абвгдеёжзийклмноөпрстуүфхцчшъыьэюя-.,!?\"  # P: Padding, E: EOS.\nchar2idx = {char: idx for idx, char in enumerate(vocab)}\nidx2char = {idx: char for idx, char in enumerate(vocab)}\n\n\ndef text_normalize(text):\n    text = text.lower()\n    # text = text.replace(\",\", \"'\")\n    # text = text.replace(\"!\", \"?\")\n    for c in \"-—:\":\n        text = text.replace(c, \"-\")\n    for c in \"()\\\"«»“”'\":\n        text = text.replace(c, \",\")\n    return text\n\n\ndef read_metadata(metadata_file):\n    fnames, text_lengths, texts = [], [], []\n    transcript = os.path.join(metadata_file)\n    lines = codecs.open(transcript, 'r', 'utf-8').readlines()\n    for line in lines:\n        fname, _, text = line.strip().split(\"|\")\n\n        fnames.append(fname)\n\n        text = text_normalize(text) + \"E\"  # E: EOS\n        text = [char2idx[char] for char in text]\n        text_lengths.append(len(text))\n        texts.append(np.array(text, np.longlong))\n\n    return fnames, text_lengths, texts\n\n\ndef get_test_data(sentences, max_n):\n    normalized_sentences = [text_normalize(line).strip() + \"E\" for line in sentences]  # text normalization, E: EOS\n    texts = np.zeros((len(normalized_sentences), max_n + 1), np.longlong)\n    for i, sent in enumerate(normalized_sentences):\n        texts[i, :len(sent)] = [char2idx[char] for char in sent]\n    return texts\n\n\nclass MBSpeech(Dataset):\n    def __init__(self, keys, dir_name='MBSpeech-1.0'):\n        self.keys = keys\n        self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_name)\n        self.fnames, self.text_lengths, self.texts = read_metadata(os.path.join(self.path, 'metadata.csv'))\n\n    def slice(self, start, end):\n        self.fnames = self.fnames[start:end]\n        self.text_lengths = 
self.text_lengths[start:end]\n        self.texts = self.texts[start:end]\n\n    def __len__(self):\n        return len(self.fnames)\n\n    def __getitem__(self, index):\n        data = {}\n        if 'texts' in self.keys:\n            data['texts'] = self.texts[index]\n        if 'mels' in self.keys:\n            # (39, 80)\n            data['mels'] = np.load(os.path.join(self.path, 'mels', \"%s.npy\" % self.fnames[index]))\n        if 'mags' in self.keys:\n            # (39, 80)\n            data['mags'] = np.load(os.path.join(self.path, 'mags', \"%s.npy\" % self.fnames[index]))\n        if 'mel_gates' in self.keys:\n            data['mel_gates'] = np.ones(data['mels'].shape[0], dtype=np.int64)  # TODO: because pre processing!\n        if 'mag_gates' in self.keys:\n            data['mag_gates'] = np.ones(data['mags'].shape[0], dtype=np.int64)  # TODO: because pre processing!\n        return data\n\n#\n# simple method to convert mongolian numbers to text, copied from somewhere\n#\n\n\ndef number2word(number):\n    digit_len = len(number)\n    digit_name = {1: '', 2: 'мянга', 3: 'сая', 4: 'тэрбум', 5: 'их наяд', 6: 'тунамал'}\n\n    if digit_len == 1:\n        return _last_digit_2_str(number)\n    if digit_len == 2:\n        return _2_digits_2_str(number)\n    if digit_len == 3:\n        return _3_digits_to_str(number)\n    if digit_len < 7:\n        return _3_digits_to_str(number[:-3], False) + ' ' + digit_name[2] + ' ' + _3_digits_to_str(number[-3:])\n\n    digitgroup = [number[0 if i - 3 < 0 else i - 3:i] for i in reversed(range(len(number), 0, -3))]\n    count = len(digitgroup)\n    i = 0\n    result = ''\n    while i < count - 1:\n        result += ' ' + (_3_digits_to_str(digitgroup[i], False) + ' ' + digit_name[count - i])\n        i += 1\n    return result.strip() + ' ' + _3_digits_to_str(digitgroup[-1])\n\n\ndef _1_digit_2_str(digit):\n    return {'0': '', '1': 'нэгэн', '2': 'хоёр', '3': 'гурван', '4': 'дөрвөн', '5': 'таван', '6': 'зургаан',\n            
'7': 'долоон', '8': 'найман', '9': 'есөн'}[digit]\n\n\ndef _last_digit_2_str(digit):\n    return {'0': 'тэг', '1': 'нэг', '2': 'хоёр', '3': 'гурав', '4': 'дөрөв', '5': 'тав', '6': 'зургаа', '7': 'долоо',\n            '8': 'найм', '9': 'ес'}[digit]\n\n\ndef _2_digits_2_str(digit, is_fina=True):\n    word2 = {'0': '', '1': 'арван', '2': 'хорин', '3': 'гучин', '4': 'дөчин', '5': 'тавин', '6': 'жаран', '7': 'далан',\n             '8': 'наян', '9': 'ерэн'}\n    word2fina = {'10': 'арав', '20': 'хорь', '30': 'гуч', '40': 'дөч', '50': 'тавь', '60': 'жар', '70': 'дал',\n                 '80': 'ная', '90': 'ер'}\n    if digit[1] == '0':\n        return word2fina[digit] if is_fina else word2[digit[0]]\n    digit1 = _last_digit_2_str(digit[1]) if is_fina else _1_digit_2_str(digit[1])\n    return (word2[digit[0]] + ' ' + digit1).strip()\n\n\ndef _3_digits_to_str(digit, is_fina=True):\n    digstr = digit.lstrip('0')\n    if len(digstr) == 0:\n        return ''\n    if len(digstr) == 1:\n        return _1_digit_2_str(digstr)\n    if len(digstr) == 2:\n        return _2_digits_2_str(digstr, is_fina)\n    if digit[-2:] == '00':\n        return _1_digit_2_str(digit[0]) + ' зуу' if is_fina else _1_digit_2_str(digit[0]) + ' зуун'\n    else:\n        return _1_digit_2_str(digit[0]) + ' зуун ' + _2_digits_2_str(digit[-2:], is_fina)\n"
  },
  {
    "path": "dl_and_preprop_dataset.py",
    "content": "#!/usr/bin/env python\n\"\"\"Download and preprocess datasets. Supported datasets are:\n  * English female: LJSpeech (https://keithito.com/LJ-Speech-Dataset/)\n  * Mongolian male: MBSpeech (Mongolian Bible)\n\"\"\"\n__author__ = 'Erdene-Ochir Tuguldur'\n\nimport os\nimport sys\nimport csv\nimport time\nimport argparse\nimport fnmatch\nimport librosa\nimport pandas as pd\n\nfrom hparams import HParams as hp\nfrom zipfile import ZipFile\nfrom audio import preprocess\nfrom utils import download_file\nfrom datasets.mb_speech import MBSpeech\nfrom datasets.lj_speech import LJSpeech\n\nparser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)\nparser.add_argument(\"--dataset\", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name')\nargs = parser.parse_args()\n\nif args.dataset == 'ljspeech':\n    dataset_file_name = 'LJSpeech-1.1.tar.bz2'\n    datasets_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets')\n    dataset_path = os.path.join(datasets_path, 'LJSpeech-1.1')\n\n    if os.path.isdir(dataset_path) and False:\n        print(\"LJSpeech dataset folder already exists\")\n        sys.exit(0)\n    else:\n        dataset_file_path = os.path.join(datasets_path, dataset_file_name)\n        if not os.path.isfile(dataset_file_path):\n            url = \"http://data.keithito.com/data/speech/%s\" % dataset_file_name\n            download_file(url, dataset_file_path)\n        else:\n            print(\"'%s' already exists\" % dataset_file_name)\n\n        print(\"extracting '%s'...\" % dataset_file_name)\n        os.system('cd %s; tar xvjf %s' % (datasets_path, dataset_file_name))\n\n        # pre process\n        print(\"pre processing...\")\n        lj_speech = LJSpeech([])\n        preprocess(dataset_path, lj_speech)\nelif args.dataset == 'mbspeech':\n    dataset_name = 'MBSpeech-1.0'\n    datasets_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 
'datasets')\n    dataset_path = os.path.join(datasets_path, dataset_name)\n\n    if os.path.isdir(dataset_path) and False:\n        print(\"MBSpeech dataset folder already exists\")\n        sys.exit(0)\n    else:\n        bible_books = ['01_Genesis', '02_Exodus', '03_Leviticus']\n        for bible_book_name in bible_books:\n            bible_book_file_name = '%s.zip' % bible_book_name\n            bible_book_file_path = os.path.join(datasets_path, bible_book_file_name)\n            if not os.path.isfile(bible_book_file_path):\n                url = \"https://s3.us-east-2.amazonaws.com/bible.davarpartners.com/Mongolian/\" + bible_book_file_name\n                download_file(url, bible_book_file_path)\n            else:\n                print(\"'%s' already exists\" % bible_book_file_name)\n\n            print(\"extracting '%s'...\" % bible_book_file_name)\n            zipfile = ZipFile(bible_book_file_path)\n            zipfile.extractall(datasets_path)\n\n    dataset_csv_file_path = os.path.join(datasets_path, '%s-csv.zip' % dataset_name)\n    dataset_csv_extracted_path = os.path.join(datasets_path, '%s-csv' % dataset_name)\n    if not os.path.isfile(dataset_csv_file_path):\n        url = \"https://www.dropbox.com/s/dafueq0w278lbz6/%s-csv.zip?dl=1\" % dataset_name\n        download_file(url, dataset_csv_file_path)\n    else:\n        print(\"'%s' already exists\" % dataset_csv_file_path)\n\n    print(\"extracting '%s'...\" % dataset_csv_file_path)\n    zipfile = ZipFile(dataset_csv_file_path)\n    zipfile.extractall(datasets_path)\n\n    sample_rate = 44100  # original sample rate\n    total_duration_s = 0\n\n    if not os.path.isdir(dataset_path):\n        os.mkdir(dataset_path)\n    wavs_path = os.path.join(dataset_path, 'wavs')\n    if not os.path.isdir(wavs_path):\n        os.mkdir(wavs_path)\n\n    metadata_csv = open(os.path.join(dataset_path, 'metadata.csv'), 'w')\n    metadata_csv_writer = csv.writer(metadata_csv, delimiter='|')\n\n\n    def 
_normalize(s):\n        \"\"\"remove leading '-'\"\"\"\n        s = s.strip()\n        if s[0] == '—' or s[0] == '-':\n            s = s[1:].strip()\n        return s\n\n\n    def _get_mp3_file(book_name, chapter):\n        book_download_path = os.path.join(datasets_path, book_name)\n        wildcard = \"*%02d - DPI.mp3\" % chapter\n        for file_name in os.listdir(book_download_path):\n            if fnmatch.fnmatch(file_name, wildcard):\n                return os.path.join(book_download_path, file_name)\n        return None\n\n\n    def _convert_mp3_to_wav(book_name, book_nr):\n        global total_duration_s\n        chapter = 1\n        while True:\n            try:\n                i = 0\n                chapter_csv_file_name = os.path.join(dataset_csv_extracted_path, \"%s_%02d.csv\" % (book_name, chapter))\n                df = pd.read_csv(chapter_csv_file_name, sep=\"|\")\n                print(\"processing %s...\" % chapter_csv_file_name)\n                mp3_file = _get_mp3_file(book_name, chapter)\n                print(\"processing %s...\" % mp3_file)\n                assert mp3_file is not None\n                samples, sr = librosa.load(mp3_file, sr=sample_rate, mono=True)\n                assert sr == sample_rate\n\n                for index, row in df.iterrows():\n                    start, end, sentence = row['start'], row['end'], row['sentence']\n                    assert end > start\n                    duration = end - start\n                    duration_s = int(duration / sample_rate)\n                    if duration_s > 10:\n                        continue  # only audios shorter than 10s\n\n                    total_duration_s += duration_s\n                    i += 1\n                    sentence = _normalize(sentence)\n                    fn = \"MB%d%02d-%04d\" % (book_nr, chapter, i)\n                    metadata_csv_writer.writerow([fn, sentence, sentence])  # same format as LJSpeech\n                    wav = samples[start:end]\n      
              wav = librosa.resample(wav, sample_rate, hp.sr)  # use same sample rate as LJSpeech\n                    librosa.output.write_wav(os.path.join(wavs_path, fn + \".wav\"), wav, hp.sr)\n\n                chapter += 1\n            except FileNotFoundError:\n                break\n\n\n    _convert_mp3_to_wav('01_Genesis', 1)\n    _convert_mp3_to_wav('02_Exodus', 2)\n    _convert_mp3_to_wav('03_Leviticus', 3)\n    metadata_csv.close()\n    print(\"total audio duration: %ss\" % (time.strftime('%H:%M:%S', time.gmtime(total_duration_s))))\n\n    # pre process\n    print(\"pre processing...\")\n    mb_speech = MBSpeech([])\n    preprocess(dataset_path, mb_speech)\n"
  },
  {
    "path": "hparams.py",
    "content": "\"\"\"Hyper parameters.\"\"\"\n__author__ = 'Erdene-Ochir Tuguldur'\n\n\nclass HParams:\n    \"\"\"Hyper parameters\"\"\"\n\n    disable_progress_bar = False  # set True if you don't want the progress bar in the console\n\n    logdir = \"logdir\"  # log dir where the checkpoints and tensorboard files are saved\n\n    # audio.py options, these values are from https://github.com/Kyubyong/dc_tts/blob/master/hyperparams.py\n    reduction_rate = 4  # melspectrogram reduction rate, don't change because SSRN is using this rate\n    n_fft = 2048 # fft points (samples)\n    n_mels = 80  # Number of Mel banks to generate\n    power = 1.5  # Exponent for amplifying the predicted magnitude\n    n_iter = 50  # Number of inversion iterations\n    preemphasis = .97\n    max_db = 100\n    ref_db = 20\n    sr = 22050  # Sampling rate\n    frame_shift = 0.0125  # seconds\n    frame_length = 0.05  # seconds\n    hop_length = int(sr * frame_shift)  # samples. =276.\n    win_length = int(sr * frame_length)  # samples. =1102.\n    max_N = 180  # Maximum number of characters.\n    max_T = 210  # Maximum number of mel frames.\n\n    e = 128  # embedding dimension\n    d = 256  # Text2Mel hidden unit dimension\n    c = 512+128  # SSRN hidden unit dimension\n\n    dropout_rate = 0.05  # dropout\n\n    # Text2Mel network options\n    text2mel_lr = 0.005  # learning rate\n    text2mel_max_iteration = 300000  # max train step\n    text2mel_weight_init = 'none'  # 'kaiming', 'xavier' or 'none'\n    text2mel_normalization = 'layer'  # 'layer', 'weight' or 'none'\n    text2mel_basic_block = 'gated_conv'  # 'highway', 'gated_conv' or 'residual'\n\n    # SSRN network options\n    ssrn_lr = 0.0005  # learning rate\n    ssrn_max_iteration = 150000  # max train step\n    ssrn_weight_init = 'kaiming'  # 'kaiming', 'xavier' or 'none'\n    ssrn_normalization = 'weight'  # 'layer', 'weight' or 'none'\n    ssrn_basic_block = 'residual'  # 'highway', 'gated_conv' or 'residual'\n"
  },
  {
    "path": "logger.py",
    "content": "\"\"\"Wrapper class for logging into the TensorBoard and comet.ml\"\"\"\n__author__ = 'Erdene-Ochir Tuguldur'\n__all__ = ['Logger']\n\nimport os\nfrom tensorboardX import SummaryWriter\n\nfrom hparams import HParams as hp\n\n\nclass Logger(object):\n\n    def __init__(self, dataset_name, model_name):\n        self.model_name = model_name\n        self.project_name = \"%s-%s\" % (dataset_name, self.model_name)\n        self.logdir = os.path.join(hp.logdir, self.project_name)\n        self.writer = SummaryWriter(log_dir=self.logdir)\n\n    def log_step(self, phase, step, loss_dict, image_dict):\n        if phase == 'train':\n            if step % 50 == 0:\n                # self.writer.add_scalar('lr', get_lr(), step)\n                # self.writer.add_scalar('%s-step/loss' % phase, loss, step)\n                for key in sorted(loss_dict):\n                    self.writer.add_scalar('%s-step/%s' % (phase, key), loss_dict[key], step)\n\n            if step % 1000 == 0:\n                for key in sorted(image_dict):\n                    self.writer.add_image('%s/%s' % (self.model_name, key), image_dict[key], step)\n\n    def log_epoch(self, phase, step, loss_dict):\n        for key in sorted(loss_dict):\n            self.writer.add_scalar('%s/%s' % (phase, key), loss_dict[key], step)\n"
  },
  {
    "path": "models/__init__.py",
    "content": "from .text2mel import Text2Mel\nfrom .ssrn import SSRN\n"
  },
  {
    "path": "models/layers.py",
    "content": "__author__ = 'Erdene-Ochir Tuguldur'\n__all__ = ['E', 'D', 'C', 'HighwayBlock', 'GatedConvBlock', 'ResidualBlock']\n\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom hparams import HParams as hp\n\n\nclass LayerNorm(nn.LayerNorm):\n    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):\n        \"\"\"Layer Norm.\"\"\"\n        super(LayerNorm, self).__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)\n\n    def forward(self, x):\n        x = x.permute(0, 2, 1)  # PyTorch LayerNorm seems to be expect (B, T, C)\n        y = super(LayerNorm, self).forward(x)\n        y = y.permute(0, 2, 1)  # reverse\n        return y\n\n\nclass D(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, dilation, weight_init='none', normalization='weight', nonlinearity='linear'):\n        \"\"\"1D Deconvolution.\"\"\"\n        super(D, self).__init__()\n        self.deconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size,\n                                         stride=2,  # paper: stride of deconvolution is always 2\n                                         dilation=dilation)\n\n        if normalization == 'weight':\n            self.deconv = nn.utils.weight_norm(self.deconv)\n        elif normalization == 'layer':\n            self.layer_norm = LayerNorm(out_channels)\n\n        self.nonlinearity = nonlinearity\n        if weight_init == 'kaiming':\n            nn.init.kaiming_normal_(self.deconv.weight, mode='fan_out', nonlinearity=nonlinearity)\n        elif weight_init == 'xavier':\n            nn.init.xavier_uniform_(self.deconv.weight, nn.init.calculate_gain(nonlinearity))\n\n    def forward(self, x, output_size=None):\n        y = self.deconv(x, output_size=output_size)\n        if hasattr(self, 'layer_norm'):\n            y = self.layer_norm(y)\n        y = F.dropout(y, p=hp.dropout_rate, training=self.training, inplace=True)\n        if self.nonlinearity == 'relu':\n    
        y = F.relu(y, inplace=True)\n        return y\n\n\nclass C(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, dilation, causal=False, weight_init='none', normalization='weight', nonlinearity='linear'):\n        \"\"\"1D convolution.\n        The argument 'causal' indicates whether the causal convolution should be used or not.\n        \"\"\"\n        super(C, self).__init__()\n        self.causal = causal\n        if causal:\n            self.padding = (kernel_size - 1) * dilation\n        else:\n            self.padding = (kernel_size - 1) * dilation // 2\n\n        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,\n                              stride=1,  # paper: 'The stride of convolution is always 1.'\n                              padding=self.padding, dilation=dilation)\n\n        if normalization == 'weight':\n            self.conv = nn.utils.weight_norm(self.conv)\n        elif normalization == 'layer':\n            self.layer_norm = LayerNorm(out_channels)\n\n        self.nonlinearity = nonlinearity\n        if weight_init == 'kaiming':\n            nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity=nonlinearity)\n        elif weight_init == 'xavier':\n            nn.init.xavier_uniform_(self.conv.weight, nn.init.calculate_gain(nonlinearity))\n\n    def forward(self, x):\n        y = self.conv(x)\n        padding = self.padding\n        if self.causal and padding > 0:\n            y = y[:, :, :-padding]\n\n        if hasattr(self, 'layer_norm'):\n            y = self.layer_norm(y)\n        y = F.dropout(y, p=hp.dropout_rate, training=self.training, inplace=True)\n        if self.nonlinearity == 'relu':\n            y = F.relu(y, inplace=True)\n        return y\n\n\nclass E(nn.Module):\n    def __init__(self, num_embeddings, embedding_dim):\n        super(E, self).__init__()\n        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)\n\n    def forward(self, x):\n     
   return self.embedding(x)\n\n\nclass HighwayBlock(nn.Module):\n    def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight'):\n        \"\"\"Highway Network like layer: https://arxiv.org/abs/1505.00387\n        The input and output shapes remain same.\n        Args:\n            d: input channel\n            k: kernel size\n            delta: dilation\n            causal: causal convolution or not\n        \"\"\"\n        super(HighwayBlock, self).__init__()\n        self.d = d\n        self.C = C(in_channels=d, out_channels=2 * d, kernel_size=k, dilation=delta, causal=causal, weight_init=weight_init, normalization=normalization)\n\n    def forward(self, x):\n        L = self.C(x)\n        H1 = L[:, :self.d, :]\n        H2 = L[:, self.d:, :]\n        sigH1 = F.sigmoid(H1)\n        return sigH1 * H2 + (1 - sigH1) * x\n\n\nclass GatedConvBlock(nn.Module):\n    def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight'):\n        \"\"\"Gated convolutional layer: https://arxiv.org/abs/1612.08083\n        The input and output shapes remain same.\n        Args:\n            d: input channel\n            k: kernel size\n            delta: dilation\n            causal: causal convolution or not\n        \"\"\"\n        super(GatedConvBlock, self).__init__()\n        self.C = C(in_channels=d, out_channels=2 * d, kernel_size=k, dilation=delta, causal=causal,\n                   weight_init=weight_init, normalization=normalization)\n        self.glu = nn.GLU(dim=1)\n\n    def forward(self, x):\n        L = self.C(x)\n        return self.glu(L) + x\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight',\n                 widening_factor=2):\n        \"\"\"Residual block: https://arxiv.org/abs/1512.03385\n        The input and output shapes remain same.\n        Args:\n            d: input channel\n            k: kernel size\n            
delta: dilation\n            causal: causal convolution or not\n        \"\"\"\n        super(ResidualBlock, self).__init__()\n        self.C1 = C(in_channels=d, out_channels=widening_factor * d, kernel_size=k, dilation=delta, causal=causal,\n                    weight_init=weight_init, normalization=normalization, nonlinearity='relu')\n        self.C2 = C(in_channels=widening_factor * d, out_channels=d, kernel_size=k, dilation=delta, causal=causal,\n                    weight_init=weight_init, normalization=normalization, nonlinearity='relu')\n\n    def forward(self, x):\n        return self.C2(self.C1(x)) + x\n"
  },
  {
    "path": "models/ssrn.py",
    "content": "\"\"\"\nHideyuki Tachibana, Katsuya Uenoyama, Shunsuke Aihara\nEfficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention\nhttps://arxiv.org/abs/1710.08969\n\nSSRN Network.\n\"\"\"\n__author__ = 'Erdene-Ochir Tuguldur'\n__all__ = ['SSRN']\n\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom hparams import HParams as hp\nfrom .layers import D, C, HighwayBlock, GatedConvBlock, ResidualBlock\n\n\ndef Conv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'):\n    return C(in_channels, out_channels, kernel_size, dilation, causal=False,\n             weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, nonlinearity=nonlinearity)\n\n\ndef DeConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'):\n    return D(in_channels, out_channels, kernel_size, dilation,\n             weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, nonlinearity=nonlinearity)\n\n\ndef BasicBlock(d, k, delta):\n    if hp.ssrn_basic_block == 'gated_conv':\n        return GatedConvBlock(d, k, delta, causal=False,\n                              weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization)\n    elif hp.ssrn_basic_block == 'highway':\n        return HighwayBlock(d, k, delta, causal=False,\n                            weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization)\n    else:\n        return ResidualBlock(d, k, delta, causal=False,\n                             weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization,\n                             widening_factor=1)\n\n\nclass SSRN(nn.Module):\n    def __init__(self, c=hp.c, f=hp.n_mels, f_prime=(1 + hp.n_fft // 2)):\n        \"\"\"Spectrogram super-resolution network.\n        Args:\n            c: SSRN dim\n            f: Number of mel bins\n            f_prime: full spectrogram dim\n        Input:\n            Y: (B, f, T) predicted 
melspectrograms\n        Outputs:\n            Z_logit: logit of Z\n            Z: (B, f_prime, 4*T) full spectrograms\n        \"\"\"\n        super(SSRN, self).__init__()\n        self.layers = nn.Sequential(\n            Conv(f, c, 1, 1),\n\n            BasicBlock(c, 3, 1), BasicBlock(c, 3, 3),\n\n            DeConv(c, c, 2, 1), BasicBlock(c, 3, 1), BasicBlock(c, 3, 3),\n            DeConv(c, c, 2, 1), BasicBlock(c, 3, 1), BasicBlock(c, 3, 3),\n\n            Conv(c, 2 * c, 1, 1),\n\n            BasicBlock(2 * c, 3, 1), BasicBlock(2 * c, 3, 1),\n\n            Conv(2 * c, f_prime, 1, 1),\n\n            # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'),\n            # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'),\n            BasicBlock(f_prime, 1, 1),\n\n            Conv(f_prime, f_prime, 1, 1)\n        )\n\n    def forward(self, x):\n        Z_logit = self.layers(x)\n        Z = F.sigmoid(Z_logit)\n        return Z_logit, Z"
  },
  {
    "path": "models/text2mel.py",
    "content": "\"\"\"\nHideyuki Tachibana, Katsuya Uenoyama, Shunsuke Aihara\nEfficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention\nhttps://arxiv.org/abs/1710.08969\n\nText2Mel Network.\n\"\"\"\n__author__ = 'Erdene-Ochir Tuguldur'\n__all__ = ['Text2Mel']\n\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom hparams import HParams as hp\nfrom .layers import E, C, HighwayBlock, GatedConvBlock, ResidualBlock\n\n\ndef Conv(in_channels, out_channels, kernel_size, dilation, causal=False, nonlinearity='linear'):\n    return C(in_channels, out_channels, kernel_size, dilation, causal=causal,\n             weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization, nonlinearity=nonlinearity)\n\n\ndef BasicBlock(d, k, delta, causal=False):\n    if hp.text2mel_basic_block == 'gated_conv':\n        return GatedConvBlock(d, k, delta, causal=causal,\n                              weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization)\n    elif hp.text2mel_basic_block == 'highway':\n        return HighwayBlock(d, k, delta, causal=causal,\n                            weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization)\n    else:\n        return ResidualBlock(d, k, delta, causal=causal,\n                             weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization,\n                             widening_factor=2)\n\n\ndef CausalConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'):\n    return Conv(in_channels, out_channels, kernel_size, dilation, causal=True, nonlinearity=nonlinearity)\n\n\ndef CausalBasicBlock(d, k, delta):\n    return BasicBlock(d, k, delta, causal=True)\n\n\nclass TextEnc(nn.Module):\n\n    def __init__(self, vocab, e=hp.e, d=hp.d):\n        \"\"\"Text encoder network.\n        Args:\n            vocab: vocabulary\n            e: embedding dim\n   
         d: Text2Mel dim\n        Input:\n            L: (B, N) text inputs\n        Outputs:\n            K: (B, d, N) keys\n            V: (B, d, N) values\n        \"\"\"\n        super(TextEnc, self).__init__()\n        self.d = d\n        self.embedding = E(len(vocab), e)\n\n        self.layers = nn.Sequential(\n            Conv(e, 2 * d, 1, 1, nonlinearity='relu'),\n            Conv(2 * d, 2 * d, 1, 1),\n\n            BasicBlock(2 * d, 3, 1), BasicBlock(2 * d, 3, 3), BasicBlock(2 * d, 3, 9), BasicBlock(2 * d, 3, 27),\n            BasicBlock(2 * d, 3, 1), BasicBlock(2 * d, 3, 3), BasicBlock(2 * d, 3, 9), BasicBlock(2 * d, 3, 27),\n\n            BasicBlock(2 * d, 3, 1), BasicBlock(2 * d, 3, 1),\n\n            BasicBlock(2 * d, 1, 1), BasicBlock(2 * d, 1, 1)\n        )\n\n    def forward(self, x):\n        out = self.embedding(x)\n        out = out.permute(0, 2, 1)  # change to (B, e, N)\n        out = self.layers(out)  # (B, 2*d, N)\n        K = out[:, :self.d, :]  # (B, d, N)\n        V = out[:, self.d:, :]  # (B, d, N)\n        return K, V\n\n\nclass AudioEnc(nn.Module):\n    def __init__(self, d=hp.d, f=hp.n_mels):\n        \"\"\"Audio encoder network.\n        Args:\n            d: Text2Mel dim\n            f: Number of mel bins\n        Input:\n            S: (B, f, T) melspectrograms\n        Output:\n            Q: (B, d, T) queries\n        \"\"\"\n        super(AudioEnc, self).__init__()\n        self.layers = nn.Sequential(\n            CausalConv(f, d, 1, 1, nonlinearity='relu'),\n            CausalConv(d, d, 1, 1, nonlinearity='relu'),\n            CausalConv(d, d, 1, 1),\n\n            CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 9), CausalBasicBlock(d, 3, 27),\n            CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 9), CausalBasicBlock(d, 3, 27),\n\n            CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 3),\n        )\n\n    def forward(self, x):\n        return self.layers(x)\n\n\nclass AudioDec(nn.Module):\n    def __init__(self, d=hp.d, f=hp.n_mels):\n        \"\"\"Audio decoder network.\n        Args:\n            d: Text2Mel dim\n            f: Number of mel bins\n        Input:\n            R_prime: (B, 2d, T) [V*Attention, Q] paper says: \"we found it beneficial in our pilot study.\"\n        Output:\n            Y: (B, f, T)\n        \"\"\"\n        super(AudioDec, self).__init__()\n        self.layers = nn.Sequential(\n            CausalConv(2 * d, d, 1, 1),\n\n            CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 9), CausalBasicBlock(d, 3, 27),\n\n            CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 1),\n\n            # CausalConv(d, d, 1, 1, nonlinearity='relu'),\n            # CausalConv(d, d, 1, 1, nonlinearity='relu'),\n            CausalBasicBlock(d, 1, 1),\n            CausalConv(d, d, 1, 1, nonlinearity='relu'),\n\n            CausalConv(d, f, 1, 1)\n        )\n\n    def forward(self, x):\n        return self.layers(x)\n\n\nclass Text2Mel(nn.Module):\n    def __init__(self, vocab, d=hp.d):\n        \"\"\"Text to melspectrogram network.\n        Args:\n            vocab: vocabulary\n            d: Text2Mel dim\n        Input:\n            L: (B, N) text inputs\n            S: (B, f, T) melspectrograms\n        Outputs:\n            Y_logit: logit of Y\n            Y: predicted melspectrograms\n            A: (B, N, T) attention matrix\n        \"\"\"\n        super(Text2Mel, self).__init__()\n        self.d = d\n        self.text_enc = TextEnc(vocab)\n        self.audio_enc = AudioEnc()\n        self.audio_dec = AudioDec()\n\n    def forward(self, L, S, monotonic_attention=False):\n        K, V = self.text_enc(L)\n        Q = self.audio_enc(S)\n        A = torch.bmm(K.permute(0, 2, 1), Q) / np.sqrt(self.d)\n\n        if monotonic_attention:\n            # TODO: vectorize instead of loops\n            B, N, T = A.size()\n            for i in range(B):\n                prva = -1  # previous attention\n                for t in range(T):\n                    _, n = torch.max(A[i, :, t], 0)\n                    if not (-1 <= n - prva <= 3):\n                        A[i, :, t] = -2 ** 20  # some small numbers\n                        A[i, min(N - 1, prva + 1), t] = 1\n                    _, prva = torch.max(A[i, :, t], 0)\n\n        A = F.softmax(A, dim=1)\n        R = torch.bmm(V, A)\n        R_prime = torch.cat((R, Q), 1)\n        Y_logit = self.audio_dec(R_prime)\n        Y = F.sigmoid(Y_logit)\n        return Y_logit, Y, A\n"
  },
  {
    "path": "requirements.txt",
    "content": "librosa>=0.5.1\ntorch>=0.4\ntensorboardX>=1.2\ntqdm>=4.15.0\nnumpy>=1.25.0\nscipy\npandas\nrequests\nscikit-image\n"
  },
  {
    "path": "synthesize.py",
    "content": "#!/usr/bin/env python\n\"\"\"Synthetize sentences into speech.\"\"\"\n__author__ = 'Erdene-Ochir Tuguldur'\n\nimport os\nimport sys\nimport argparse\nfrom tqdm import *\n\nimport numpy as np\nimport torch\n\nfrom models import Text2Mel, SSRN\nfrom hparams import HParams as hp\nfrom audio import save_to_wav\nfrom utils import get_last_checkpoint_file_name, load_checkpoint, save_to_png\n\nparser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)\nparser.add_argument(\"--dataset\", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name')\nargs = parser.parse_args()\n\nif args.dataset == 'ljspeech':\n    from datasets.lj_speech import vocab, get_test_data\n\n    SENTENCES = [\n        \"The birch canoe slid on the smooth planks.\",\n        \"Glue the sheet to the dark blue background.\",\n        \"It's easy to tell the depth of a well.\",\n        \"These days a chicken leg is a rare dish.\",\n        \"Rice is often served in round bowls.\",\n        \"The juice of lemons makes fine punch.\",\n        \"The box was thrown beside the parked truck.\",\n        \"The hogs were fed chopped corn and garbage.\",\n        \"Four hours of steady work faced us.\",\n        \"Large size in stockings is hard to sell.\",\n        \"The boy was there when the sun rose.\",\n        \"A rod is used to catch pink salmon.\",\n        \"The source of the huge river is the clear spring.\",\n        \"Kick the ball straight and follow through.\",\n        \"Help the woman get back to her feet.\",\n        \"A pot of tea helps to pass the evening.\",\n        \"Smoky fires lack flame and heat.\",\n        \"The soft cushion broke the man's fall.\",\n        \"The salt breeze came across from the sea.\",\n        \"The girl at the booth sold fifty bonds.\"\n    ]\nelse:\n    from datasets.mb_speech import vocab, get_test_data\n\n    SENTENCES = [\n        \"Нийслэлийн прокурорын газраас төрийн өндөр албан 
тушаалтнуудад холбогдох зарим эрүүгийн хэргүүдийг шүүхэд шилжүүлэв.\",\n        \"Мөнх тэнгэрийн хүчин дор Монгол Улс цэцэглэн хөгжих болтугай.\",\n        \"Унасан хүлгээ түрүү магнай, аман хүзүүнд уралдуулж, айрагдуулсан унаач хүүхдүүдэд бэлэг гардууллаа.\",\n        \"Албан ёсоор хэлэхэд “Монгол Улсын хэрэг эрхлэх газрын гэгээнтэн” гэж нэрлээд байгаа зүйл огт байхгүй.\",\n        \"Сайн чанарын бохирын хоолой зарна.\",\n        \"Хараа тэглэх мэс заслын дараа хараа дахин муудах магадлал бага.\",\n        \"Ер нь бол хараа тэглэх мэс заслыг гоо сайхны мэс засалтай адилхан гэж зүйрлэж болно.\",\n        \"Хашлага даван, зүлэг гэмтээсэн жолоочийн эрхийг хоёр жилээр хасжээ.\",\n        \"Монгол хүн бидний сэтгэлийг сорсон орон. Энэ бол миний төрсөн нутаг. Монголын сайхан орон.\",\n        \"Постройка крейсера затягивалась из-за проектных неувязок, необходимости.\"\n    ]\n\ntorch.set_grad_enabled(False)\n\ntext2mel = Text2Mel(vocab).eval()\nlast_checkpoint_file_name = get_last_checkpoint_file_name(os.path.join(hp.logdir, '%s-text2mel' % args.dataset))\n# last_checkpoint_file_name = 'logdir/%s-text2mel/step-020K.pth' % args.dataset\nif last_checkpoint_file_name:\n    print(\"loading text2mel checkpoint '%s'...\" % last_checkpoint_file_name)\n    load_checkpoint(last_checkpoint_file_name, text2mel, None)\nelse:\n    print(\"text2mel not exits\")\n    sys.exit(1)\n\nssrn = SSRN().eval()\nlast_checkpoint_file_name = get_last_checkpoint_file_name(os.path.join(hp.logdir, '%s-ssrn' % args.dataset))\n# last_checkpoint_file_name = 'logdir/%s-ssrn/step-005K.pth' % args.dataset\nif last_checkpoint_file_name:\n    print(\"loading ssrn checkpoint '%s'...\" % last_checkpoint_file_name)\n    load_checkpoint(last_checkpoint_file_name, ssrn, None)\nelse:\n    print(\"ssrn not exits\")\n    sys.exit(1)\n\n# synthetize by one by one because there is a batch processing bug!\nfor i in range(len(SENTENCES)):\n    sentences = [SENTENCES[i]]\n\n    max_N = len(SENTENCES[i])\n    L = 
torch.from_numpy(get_test_data(sentences, max_N))\n    zeros = torch.from_numpy(np.zeros((1, hp.n_mels, 1), np.float32))\n    Y = zeros\n    A = None\n\n    for t in tqdm(range(hp.max_T)):\n        _, Y_t, A = text2mel(L, Y, monotonic_attention=True)\n        Y = torch.cat((zeros, Y_t), -1)\n        _, attention = torch.max(A[0, :, -1], 0)\n        attention = attention.item()\n        if L[0, attention] == vocab.index('E'):  # EOS\n            break\n\n    _, Z = ssrn(Y)\n\n    Y = Y.cpu().detach().numpy()\n    A = A.cpu().detach().numpy()\n    Z = Z.cpu().detach().numpy()\n\n    save_to_png('samples/%d-att.png' % (i + 1), A[0, :, :])\n    save_to_png('samples/%d-mel.png' % (i + 1), Y[0, :, :])\n    save_to_png('samples/%d-mag.png' % (i + 1), Z[0, :, :])\n    save_to_wav(Z[0, :, :].T, 'samples/%d-wav.wav' % (i + 1))\n"
  },
  {
    "path": "train-ssrn.py",
    "content": "#!/usr/bin/env python\n\"\"\"Train the Text2Mel network. See: https://arxiv.org/abs/1710.08969\"\"\"\n__author__ = 'Erdene-Ochir Tuguldur'\n\nimport sys\nimport time\nimport argparse\nfrom tqdm import *\n\nimport torch\nimport torch.nn.functional as F\n\n# project imports\nfrom models import SSRN\nfrom hparams import HParams as hp\nfrom logger import Logger\nfrom utils import get_last_checkpoint_file_name, load_checkpoint, save_checkpoint\nfrom datasets.data_loader import SSRNDataLoader\n\nparser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)\nparser.add_argument(\"--dataset\", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name')\nargs = parser.parse_args()\n\nif args.dataset == 'ljspeech':\n    from datasets.lj_speech import LJSpeech as SpeechDataset\nelse:\n    from datasets.mb_speech import MBSpeech as SpeechDataset\n\nuse_gpu = torch.cuda.is_available()\nprint('use_gpu', use_gpu)\nif use_gpu:\n    torch.backends.cudnn.benchmark = True\n\ntrain_data_loader = SSRNDataLoader(ssrn_dataset=SpeechDataset(['mags', 'mels']), batch_size=24, mode='train')\nvalid_data_loader = SSRNDataLoader(ssrn_dataset=SpeechDataset(['mags', 'mels']), batch_size=24, mode='valid')\n\nssrn = SSRN().cuda()\n\noptimizer = torch.optim.Adam(ssrn.parameters(), lr=hp.ssrn_lr)\n\nstart_timestamp = int(time.time() * 1000)\nstart_epoch = 0\nglobal_step = 0\n\nlogger = Logger(args.dataset, 'ssrn')\n\n# load the last checkpoint if exists\nlast_checkpoint_file_name = get_last_checkpoint_file_name(logger.logdir)\nif last_checkpoint_file_name:\n    print(\"loading the last checkpoint: %s\" % last_checkpoint_file_name)\n    start_epoch, global_step = load_checkpoint(last_checkpoint_file_name, ssrn, optimizer)\n\n\ndef get_lr():\n    return optimizer.param_groups[0]['lr']\n\n\ndef lr_decay(step, warmup_steps=1000):\n    new_lr = hp.ssrn_lr * warmup_steps ** 0.5 * min((step + 1) * warmup_steps ** -1.5, (step + 1) ** 
-0.5)\n    optimizer.param_groups[0]['lr'] = new_lr\n\n\ndef train(train_epoch, phase='train'):\n    global global_step\n\n    lr_decay(global_step)\n    print(\"epoch %3d with lr=%.02e\" % (train_epoch, get_lr()))\n\n    ssrn.train() if phase == 'train' else ssrn.eval()\n    torch.set_grad_enabled(True) if phase == 'train' else torch.set_grad_enabled(False)\n    data_loader = train_data_loader if phase == 'train' else valid_data_loader\n\n    it = 0\n    running_loss = 0.0\n    running_l1_loss = 0.0\n\n    pbar = tqdm(data_loader, unit=\"audios\", unit_scale=data_loader.batch_size, disable=hp.disable_progress_bar)\n    for batch in pbar:\n        M, S = batch['mags'], batch['mels']\n        M = M.permute(0, 2, 1)  # TODO: because of pre processing\n        S = S.permute(0, 2, 1)  # TODO: because of pre processing\n\n        M.requires_grad = False\n        M = M.cuda()\n        S = S.cuda()\n\n        Z_logit, Z = ssrn(S)\n\n        l1_loss = F.l1_loss(Z, M)\n\n        loss = l1_loss\n\n        if phase == 'train':\n            lr_decay(global_step)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            global_step += 1\n\n        it += 1\n\n        loss = loss.item()\n        l1_loss = l1_loss.item()\n        running_loss += loss\n        running_l1_loss += l1_loss\n\n        if phase == 'train':\n            # update the progress bar\n            pbar.set_postfix({\n                'l1': \"%.05f\" % (running_l1_loss / it)\n            })\n            logger.log_step(phase, global_step, {'loss_l1': l1_loss},\n                            {'mags-true': M[:1, :, :], 'mags-pred': Z[:1, :, :], 'mels': S[:1, :, :]})\n            if global_step % 5000 == 0:\n                # checkpoint at every 5000th step\n                save_checkpoint(logger.logdir, train_epoch, global_step, ssrn, optimizer)\n\n    epoch_loss = running_loss / it\n    epoch_l1_loss = running_l1_loss / it\n\n    logger.log_epoch(phase, global_step, 
{'loss_l1': epoch_l1_loss})\n\n    return epoch_loss\n\n\nsince = time.time()\nepoch = start_epoch\nwhile True:\n    train_epoch_loss = train(epoch, phase='train')\n    time_elapsed = time.time() - since\n    time_str = 'total time elapsed: {:.0f}h {:.0f}m {:.0f}s '.format(time_elapsed // 3600, time_elapsed % 3600 // 60,\n                                                                     time_elapsed % 60)\n    print(\"train epoch loss %f, step=%d, %s\" % (train_epoch_loss, global_step, time_str))\n\n    valid_epoch_loss = train(epoch, phase='valid')\n    print(\"valid epoch loss %f\" % valid_epoch_loss)\n\n    epoch += 1\n    if global_step >= hp.ssrn_max_iteration:\n        print(\"max step %d (current step %d) reached, exiting...\" % (hp.ssrn_max_iteration, global_step))\n        sys.exit(0)\n"
  },
  {
    "path": "train-text2mel.py",
    "content": "#!/usr/bin/env python\n\"\"\"Train the Text2Mel network. See: https://arxiv.org/abs/1710.08969\"\"\"\n__author__ = 'Erdene-Ochir Tuguldur'\n\nimport sys\nimport time\nimport argparse\nfrom tqdm import *\n\nimport numpy as np\n\nimport torch\nimport torch.nn.functional as F\n\n# project imports\nfrom models import Text2Mel\nfrom hparams import HParams as hp\nfrom logger import Logger\nfrom utils import get_last_checkpoint_file_name, load_checkpoint, save_checkpoint\nfrom datasets.data_loader import Text2MelDataLoader\n\nparser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)\nparser.add_argument(\"--dataset\", required=True, choices=['ljspeech', 'mbspeech'], help='dataset name')\nargs = parser.parse_args()\n\nif args.dataset == 'ljspeech':\n    from datasets.lj_speech import vocab, LJSpeech as SpeechDataset\nelse:\n    from datasets.mb_speech import vocab, MBSpeech as SpeechDataset\n\nuse_gpu = torch.cuda.is_available()\nprint('use_gpu', use_gpu)\nif use_gpu:\n    torch.backends.cudnn.benchmark = True\n\ntrain_data_loader = Text2MelDataLoader(text2mel_dataset=SpeechDataset(['texts', 'mels', 'mel_gates']), batch_size=64,\n                                       mode='train')\nvalid_data_loader = Text2MelDataLoader(text2mel_dataset=SpeechDataset(['texts', 'mels', 'mel_gates']), batch_size=64,\n                                       mode='valid')\n\ntext2mel = Text2Mel(vocab).cuda()\n\noptimizer = torch.optim.Adam(text2mel.parameters(), lr=hp.text2mel_lr)\n\nstart_timestamp = int(time.time() * 1000)\nstart_epoch = 0\nglobal_step = 0\n\nlogger = Logger(args.dataset, 'text2mel')\n\n# load the last checkpoint if exists\nlast_checkpoint_file_name = get_last_checkpoint_file_name(logger.logdir)\nif last_checkpoint_file_name:\n    print(\"loading the last checkpoint: %s\" % last_checkpoint_file_name)\n    start_epoch, global_step = load_checkpoint(last_checkpoint_file_name, text2mel, optimizer)\n\n\ndef 
get_lr():\n    return optimizer.param_groups[0]['lr']\n\n\ndef lr_decay(step, warmup_steps=4000):\n    new_lr = hp.text2mel_lr * warmup_steps ** 0.5 * min((step + 1) * warmup_steps ** -1.5, (step + 1) ** -0.5)\n    optimizer.param_groups[0]['lr'] = new_lr\n\n\ndef train(train_epoch, phase='train'):\n    global global_step\n\n    lr_decay(global_step)\n    print(\"epoch %3d with lr=%.02e\" % (train_epoch, get_lr()))\n\n    text2mel.train() if phase == 'train' else text2mel.eval()\n    torch.set_grad_enabled(True) if phase == 'train' else torch.set_grad_enabled(False)\n    data_loader = train_data_loader if phase == 'train' else valid_data_loader\n\n    it = 0\n    running_loss = 0.0\n    running_l1_loss = 0.0\n    running_att_loss = 0.0\n\n    pbar = tqdm(data_loader, unit=\"audios\", unit_scale=data_loader.batch_size, disable=hp.disable_progress_bar)\n    for batch in pbar:\n        L, S, gates = batch['texts'], batch['mels'], batch['mel_gates']\n        S = S.permute(0, 2, 1)  # TODO: because of pre processing\n\n        B, N = L.size()  # batch size and text count\n        _, n_mels, T = S.size()  # number of melspectrogram bins and time\n\n        assert gates.size(0) == B  # TODO: later remove\n        assert gates.size(1) == T\n\n        S_shifted = torch.cat((S[:, :, 1:], torch.zeros(B, n_mels, 1)), 2)\n\n        S.requires_grad = False\n        S_shifted.requires_grad = False\n        gates.requires_grad = False\n\n        def W_nt(_, n, t, g=0.2):\n            return 1.0 - np.exp(-((n / float(N) - t / float(T)) ** 2) / (2 * g ** 2))\n\n        W = np.fromfunction(W_nt, (B, N, T), dtype=np.float32)\n        W = torch.from_numpy(W)\n\n        L = L.cuda()\n        S = S.cuda()\n        S_shifted = S_shifted.cuda()\n        W = W.cuda()\n        gates = gates.cuda()\n\n        Y_logit, Y, A = text2mel(L, S)\n\n        l1_loss = F.l1_loss(Y, S_shifted)\n        masks = gates.reshape(B, 1, T).float()\n        att_loss = (A * W * masks).mean()\n\n        loss = 
l1_loss + att_loss\n\n        if phase == 'train':\n            lr_decay(global_step)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            global_step += 1\n\n        it += 1\n\n        loss, l1_loss, att_loss = loss.item(), l1_loss.item(), att_loss.item()\n        running_loss += loss\n        running_l1_loss += l1_loss\n        running_att_loss += att_loss\n\n        if phase == 'train':\n            # update the progress bar\n            pbar.set_postfix({\n                'l1': \"%.05f\" % (running_l1_loss / it),\n                'att': \"%.05f\" % (running_att_loss / it)\n            })\n            logger.log_step(phase, global_step, {'loss_l1': l1_loss, 'loss_att': att_loss},\n                            {'mels-true': S[:1, :, :], 'mels-pred': Y[:1, :, :], 'attention': A[:1, :, :]})\n            if global_step % 5000 == 0:\n                # checkpoint at every 5000th step\n                save_checkpoint(logger.logdir, train_epoch, global_step, text2mel, optimizer)\n\n    epoch_loss = running_loss / it\n    epoch_l1_loss = running_l1_loss / it\n    epoch_att_loss = running_att_loss / it\n\n    logger.log_epoch(phase, global_step, {'loss_l1': epoch_l1_loss, 'loss_att': epoch_att_loss})\n\n    return epoch_loss\n\n\nsince = time.time()\nepoch = start_epoch\nwhile True:\n    train_epoch_loss = train(epoch, phase='train')\n    time_elapsed = time.time() - since\n    time_str = 'total time elapsed: {:.0f}h {:.0f}m {:.0f}s '.format(time_elapsed // 3600, time_elapsed % 3600 // 60,\n                                                                     time_elapsed % 60)\n    print(\"train epoch loss %f, step=%d, %s\" % (train_epoch_loss, global_step, time_str))\n\n    valid_epoch_loss = train(epoch, phase='valid')\n    print(\"valid epoch loss %f\" % valid_epoch_loss)\n\n    epoch += 1\n    if global_step >= hp.text2mel_max_iteration:\n        print(\"max step %d (current step %d) reached, exiting...\" % 
(hp.text2mel_max_iteration, global_step))\n        sys.exit(0)\n"
  },
  {
    "path": "utils.py",
    "content": "\"\"\"Utility methods.\"\"\"\n__author__ = 'Erdene-Ochir Tuguldur'\n\nimport os\nimport sys\nimport glob\nimport torch\nimport math\nimport requests\nfrom tqdm import tqdm\nfrom skimage.io import imsave\nfrom skimage import img_as_ubyte\n\n\ndef get_last_checkpoint_file_name(logdir):\n    \"\"\"Returns the last checkpoint file name in the given log dir path.\"\"\"\n    checkpoints = glob.glob(os.path.join(logdir, '*.pth'))\n    checkpoints.sort()\n    if len(checkpoints) == 0:\n        return None\n    return checkpoints[-1]\n\n\ndef load_checkpoint(checkpoint_file_name, model, optimizer):\n    \"\"\"Loads the checkpoint into the given model and optimizer.\"\"\"\n    checkpoint = torch.load(checkpoint_file_name)\n    model.load_state_dict(checkpoint['state_dict'])\n    model.float()\n    if optimizer is not None:\n        optimizer.load_state_dict(checkpoint['optimizer'])\n    start_epoch = checkpoint.get('epoch', 0)\n    global_step = checkpoint.get('global_step', 0)\n    del checkpoint\n    print(\"loaded checkpoint epoch=%d step=%d\" % (start_epoch, global_step))\n    return start_epoch, global_step\n\n\ndef save_checkpoint(logdir, epoch, global_step, model, optimizer):\n    \"\"\"Saves the training state into the given log dir path.\"\"\"\n    checkpoint_file_name = os.path.join(logdir, 'step-%03dK.pth' % (global_step // 1000))\n    print(\"saving the checkpoint file '%s'...\" % checkpoint_file_name)\n    checkpoint = {\n        'epoch': epoch + 1,\n        'global_step': global_step,\n        'state_dict': model.state_dict(),\n        'optimizer': optimizer.state_dict(),\n    }\n    torch.save(checkpoint, checkpoint_file_name)\n    del checkpoint\n\n\ndef download_file(url, file_path):\n    \"\"\"Downloads a file from the given URL.\"\"\"\n    print(\"downloading %s...\" % url)\n    r = requests.get(url, stream=True)\n    total_size = int(r.headers.get('content-length', 0))\n    block_size = 1024 * 1024\n    wrote = 0\n    with open(file_path, 
'wb') as f:\n        for data in tqdm(r.iter_content(block_size), total=math.ceil(total_size // block_size), unit='MB'):\n            wrote = wrote + len(data)\n            f.write(data)\n\n    if total_size != 0 and wrote != total_size:\n        print(\"downloading failed\")\n        sys.exit(1)\n\n\ndef save_to_png(file_name, array):\n    \"\"\"Save the given numpy array as a PNG file.\"\"\"\n    # from skimage._shared._warnings import expected_warnings\n    # with expected_warnings(['precision']):\n    imsave(file_name, img_as_ubyte(array))\n"
  }
]