[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 PlayVoice\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n<h1> Grad-SVC based on Grad-TTS from HUAWEI Noah's Ark Lab </h1>\n\n[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/maxmax20160403/grad-svc)\n<img alt=\"GitHub Repo stars\" src=\"https://img.shields.io/github/stars/PlayVoice/Grad-SVC\">\n<img alt=\"GitHub forks\" src=\"https://img.shields.io/github/forks/PlayVoice/Grad-SVC\">\n<img alt=\"GitHub issues\" src=\"https://img.shields.io/github/issues/PlayVoice/Grad-SVC\">\n<img alt=\"GitHub\" src=\"https://img.shields.io/github/license/PlayVoice/Grad-SVC\">\n\nThis project is named Grad-SVC, or GVC for short. Its core technology is diffusion, but it is quite different from other diffusion-based SVC models. The code is adapted from `Grad-TTS` and `whisper-vits-svc`, so this project uses the same features as `whisper-vits-svc`. [Diff-VC](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC), from [Diffusion-Based Any-to-Any Voice Conversion](https://arxiv.org/abs/2109.13821), is a follow-up to Grad-TTS.\n\n[Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech](https://arxiv.org/abs/2105.06337)\n\n![grad_tts](./assets/grad_tts.jpg)\n\n![grad_svc](./assets/grad_svc.jpg)\n\nThe framework of grad-svc-v1\n\n![grad_svc_v2](./assets/grad_svc_v2.jpg)\n\nThe framework of grad-svc-v2 & v3 (encoder: 768->512, diffusion: 64->96)\n\nhttps://github.com/PlayVoice/Grad-SVC/assets/16432329/f9b66af7-b5b5-4efb-b73d-adb0dc84a0ae\n\n</div>\n\n## Features\n1. Such beautiful code from Grad-TTS\n\n    `easy to read`\n\n2. Multi-speaker support based on a speaker encoder\n\n3. No speaker leakage, thanks to `Perturbation`, `Instance Normalization` and `GRL`\n\n\t[One-shot Voice Conversion by Separating Speaker and Content Representations with Instance Normalization](https://arxiv.org/abs/1904.05742)\n\n4. No electronic sound\n\n5. Integrated [DPM Solver-k](https://github.com/LuChengTHU/dpm-solver) for fewer steps\n\n6. 
Integrated [Fast Maximum Likelihood Sampling Scheme](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC) for fewer steps\n\n7. [Conditional Flow Matching](https://voicebox.metademolab.com/) (V3), first used in SVC\n\n8. [Rectified Flow Matching](https://github.com/cantabile-kwok/VoiceFlow-TTS) (TODO)\n\n## Setup Environment\n1. Install project dependencies\n\n    ```shell\n    pip install -r requirements.txt\n    ```\n\n2. Download the Timbre Encoder: [Speaker-Encoder by @mueller91](https://drive.google.com/drive/folders/15oeBYf6Qn1edONkVLXe82MzdIi3O_9m3), put `best_model.pth.tar` into `speaker_pretrain/`.\n\n3. Download the [hubert_soft model](https://github.com/bshall/hubert/releases/tag/v0.1), put `hubert-soft-0d54a1f4.pt` into `hubert_pretrain/`.\n\n4. Download the pretrained [nsf_bigvgan_pretrain_32K.pth](https://github.com/PlayVoice/NSF-BigVGAN/releases/augment), and put it into `bigvgan_pretrain/`.\n   \n\t**Performance bottleneck: the Generator and Discriminator together are 116 MB, but the Generator alone is only 22 MB**\n\n5. 
Download the pretrained model [gvc.pretrain.pth](https://github.com/PlayVoice/Grad-SVC/releases/tag/20230920), and put it into `grad_pretrain/`.\n    ```\n    python gvc_inference.py --model ./grad_pretrain/gvc.pretrain.pth --spk ./assets/singers/singer0001.npy --wave test.wav\n    ```\n    \n    For this pretrained model, `gvc_inference.py` uses `temperature=1.015` to get good results.\n   \n## Dataset preparation\nPut the dataset into the `data_raw` directory following the structure below.\n```\ndata_raw\n├───speaker0\n│   ├───000001.wav\n│   ├───...\n│   └───000xxx.wav\n└───speaker1\n    ├───000001.wav\n    ├───...\n    └───000xxx.wav\n```\n\n## Data preprocessing\nAfter preprocessing, you will get output with the following structure.\n```\ndata_gvc/\n└── waves-16k\n│    └── speaker0\n│    │      ├── 000001.wav\n│    │      └── 000xxx.wav\n│    └── speaker1\n│           ├── 000001.wav\n│           └── 000xxx.wav\n└── waves-32k\n│    └── speaker0\n│    │      ├── 000001.wav\n│    │      └── 000xxx.wav\n│    └── speaker1\n│           ├── 000001.wav\n│           └── 000xxx.wav\n└── mel\n│    └── speaker0\n│    │      ├── 000001.mel.pt\n│    │      └── 000xxx.mel.pt\n│    └── speaker1\n│           ├── 000001.mel.pt\n│           └── 000xxx.mel.pt\n└── pitch\n│    └── speaker0\n│    │      ├── 000001.pit.npy\n│    │      └── 000xxx.pit.npy\n│    └── speaker1\n│           ├── 000001.pit.npy\n│           └── 000xxx.pit.npy\n└── hubert\n│    └── speaker0\n│    │      ├── 000001.vec.npy\n│    │      └── 000xxx.vec.npy\n│    └── speaker1\n│           ├── 000001.vec.npy\n│           └── 000xxx.vec.npy\n└── speaker\n│    └── speaker0\n│    │      ├── 000001.spk.npy\n│    │      └── 000xxx.spk.npy\n│    └── speaker1\n│           ├── 000001.spk.npy\n│           └── 000xxx.spk.npy\n└── singer\n    ├── speaker0.spk.npy\n    └── speaker1.spk.npy\n```\n\n1. 
Re-sampling\n    - Generate audio with a sampling rate of 16000Hz in `./data_gvc/waves-16k`\n    ```\n    python prepare/preprocess_a.py -w ./data_raw -o ./data_gvc/waves-16k -s 16000\n    ```\n    - Generate audio with a sampling rate of 32000Hz in `./data_gvc/waves-32k`\n    ```\n    python prepare/preprocess_a.py -w ./data_raw -o ./data_gvc/waves-32k -s 32000\n    ```\n2. Use 16K audio to extract pitch\n    ```\n    python prepare/preprocess_f0.py -w data_gvc/waves-16k/ -p data_gvc/pitch\n    ```\n3. Use 32K audio to extract mel\n    ```\n    python prepare/preprocess_spec.py -w data_gvc/waves-32k/ -s data_gvc/mel\n    ```\n4. Use 16K audio to extract hubert\n    ```\n    python prepare/preprocess_hubert.py -w data_gvc/waves-16k/ -v data_gvc/hubert\n    ```\n5. Use 16K audio to extract the timbre code\n    ```\n    python prepare/preprocess_speaker.py data_gvc/waves-16k/ data_gvc/speaker\n    ```\n6. Extract the average timbre code for inference\n    ```\n    python prepare/preprocess_speaker_ave.py data_gvc/speaker/ data_gvc/singer\n    ```\n7. Use 32K audio to generate the training index\n    ```\n    python prepare/preprocess_train.py\n    ```\n8. Training file debugging\n    ```\n    python prepare/preprocess_zzz.py\n    ```\n\n## Train\n1. Start training\n   ```\n   python gvc_trainer.py\n   ```\n2. Resume training\n   ```\n   python gvc_trainer.py -p logs/grad_svc/grad_svc_***.pth\n   ```\n3. Log visualization\n   ```\n   tensorboard --logdir logs/\n   ```\n\n## Train Loss\n\n![loss_96_v2](./assets/loss_96_v2.jpg)\n\n![grad_svc_mel](./assets/grad_svc_mel.jpg)\n\n\n## Inference\n\n1. Export the inference model\n   ```\n   python gvc_export.py --checkpoint_path logs/grad_svc/grad_svc_***.pth\n   ```\n\n2. 
Inference\n    ```\n    python gvc_inference.py --model gvc.pth --spk ./data_gvc/singer/your_singer.spk.npy --wave test.wav --rature 1.015 --shift 0\n    ```\n    The temperature (`--rature`, 1.015 here) needs to be tuned to get good results; the recommended range is (1.001, 1.035).\n\n3. Inference step by step\n    - Extract the hubert content vector\n        ```\n        python hubert/inference.py -w test.wav -v test.vec.npy\n        ```\n    - Extract pitch in CSV text format\n        ```\n        python pitch/inference.py -w test.wav -p test.csv\n        ```\n    - Convert hubert & pitch to wave\n        ```\n        python gvc_inference.py --model gvc.pth --spk ./data_gvc/singer/your_singer.spk.npy --wave test.wav --vec test.vec.npy --pit test.csv --shift 0\n        ```\n\n## Data\n\n| Name | URL |\n| :--- | :--- |\n|PopCS          |https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/apply_form.md|\n|opencpop       |https://wenet.org.cn/opencpop/download/|\n|Multi-Singer   |https://github.com/Multi-Singer/Multi-Singer.github.io|\n|M4Singer       |https://github.com/M4Singer/M4Singer/blob/master/apply_form.md|\n|VCTK           |https://datashare.ed.ac.uk/handle/10283/2651|\n\n## Code sources and 
references\n\nhttps://github.com/huawei-noah/Speech-Backbones/blob/main/Grad-TTS\n\nhttps://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC\n\nhttps://github.com/facebookresearch/speech-resynthesis\n\nhttps://github.com/cantabile-kwok/VoiceFlow-TTS\n\nhttps://github.com/shivammehta25/Matcha-TTS\n\nhttps://github.com/shivammehta25/Diff-TTSG\n\nhttps://github.com/majidAdibian77/ResGrad\n\nhttps://github.com/LuChengTHU/dpm-solver\n\nhttps://github.com/gmltmd789/UnitSpeech\n\nhttps://github.com/zhenye234/CoMoSpeech\n\nhttps://github.com/seahore/PPG-GradVC\n\nhttps://github.com/thuhcsi/LightGrad\n\nhttps://github.com/lmnt-com/wavegrad\n\nhttps://github.com/naver-ai/facetts\n\nhttps://github.com/jaywalnut310/vits\n\nhttps://github.com/NVIDIA/BigVGAN\n\nhttps://github.com/bshall/soft-vc\n\nhttps://github.com/mozilla/TTS\n\nhttps://github.com/ubisoft/ubisoft-laforge-daft-exprt\n\n##\n\nhttps://github.com/yl4579/StyleTTS-VC\n\nhttps://github.com/MingjieChen/DYGANVC\n\nhttps://github.com/sony/ai-research-code/tree/master/nvcnet\n"
  },
  {
    "path": "bigvgan/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 PlayVoice\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "bigvgan/README.md",
    "content": "<div align=\"center\">\n<h1> Neural Source-Filter BigVGAN </h1>\n    Just For Fun\n</div>\n\n![nsf_bigvgan_mel](https://github.com/PlayVoice/NSF-BigVGAN/assets/16432329/eebb8dca-a8d3-4e69-b02c-632a3a1cdd6a)\n\n## Dataset preparation\n\nPut the dataset into the `data_raw` directory according to the following file structure:\n```shell\ndata_raw\n├───speaker0\n│   ├───000001.wav\n│   ├───...\n│   └───000xxx.wav\n└───speaker1\n    ├───000001.wav\n    ├───...\n    └───000xxx.wav\n```\n\n## Install dependencies\n\n- 1, install software dependencies\n  \n  > pip install -r requirements.txt\n\n- 2, download the [release](https://github.com/PlayVoice/NSF-BigVGAN/releases/tag/debug) model and test it\n  \n  > python nsf_bigvgan_inference.py --config configs/nsf_bigvgan.yaml --model nsf_bigvgan_g.pth --wave test.wav\n\n## Data preprocessing\n\n- 1, re-sampling: 32kHz\n\n    > python prepare/preprocess_a.py -w ./data_raw -o ./data_bigvgan/waves-32k\n\n- 2, extract pitch\n\n    > python prepare/preprocess_f0.py -w data_bigvgan/waves-32k/ -p data_bigvgan/pitch\n\n- 3, extract mel: [100, length]\n\n    > python prepare/preprocess_spec.py -w data_bigvgan/waves-32k/ -s data_bigvgan/mel\n\n- 4, generate the training index\n\n    > python prepare/preprocess_train.py\n\n```shell\ndata_bigvgan/\n│\n└── waves-32k\n│    └── speaker0\n│    │      ├── 000001.wav\n│    │      └── 000xxx.wav\n│    └── speaker1\n│           ├── 000001.wav\n│           └── 000xxx.wav\n└── pitch\n│    └── speaker0\n│    │      ├── 000001.pit.npy\n│    │      └── 000xxx.pit.npy\n│    └── speaker1\n│           ├── 000001.pit.npy\n│           └── 000xxx.pit.npy\n└── mel\n     └── speaker0\n     │      ├── 000001.mel.pt\n     │      └── 000xxx.mel.pt\n     └── speaker1\n            ├── 000001.mel.pt\n            └── 000xxx.mel.pt\n\n```\n\n## Train\n\n- 1, start training\n\n    > python nsf_bigvgan_trainer.py -c configs/nsf_bigvgan.yaml -n nsf_bigvgan\n\n- 2, resume training\n\n    > python nsf_bigvgan_trainer.py -c 
configs/nsf_bigvgan.yaml -n nsf_bigvgan -p chkpt/nsf_bigvgan/***.pth\n\n- 3, view log\n\n    > tensorboard --logdir logs/\n\n\n## Inference\n\n- 1, export the inference model\n\n    > python nsf_bigvgan_export.py --config configs/nsf_bigvgan.yaml --checkpoint_path chkpt/nsf_bigvgan/***.pth\n\n- 2, extract mel\n\n    > python spec/inference.py -w test.wav -m test.mel.pt\n\n- 3, extract F0\n\n    > python pitch/inference.py -w test.wav -p test.csv\n\n- 4, infer\n\n    > python nsf_bigvgan_inference.py --config configs/nsf_bigvgan.yaml --model nsf_bigvgan_g.pth --wave test.wav\n\n    or\n\n    > python nsf_bigvgan_inference.py --config configs/nsf_bigvgan.yaml --model nsf_bigvgan_g.pth --mel test.mel.pt --pit test.csv\n\n## Augmentation of mel\nBecause the output of the acoustic model is over-smoothed, we apply Gaussian blur to the mel when training the vocoder:\n```\n# gaussian blur\nmodel_b = get_gaussian_kernel(kernel_size=5, sigma=2, channels=1).to(device)\n# mel blur: [B, n_mel, T] -> [B, 1, n_mel, T] for the 2D kernel, then back\nmel_b = mel[:, None, :, :]\nmel_b = model_b(mel_b)\nmel_b = torch.squeeze(mel_b, 1)\n# mix blurred and original mel; the original gets a random weight in [0, 0.5)\nmel_r = torch.rand(1).to(device) * 0.5\nmel_b = (1 - mel_r) * mel_b + mel_r * mel\n# generator\noptim_g.zero_grad()\nfake_audio = model_g(mel_b, pit)\n```\n![mel_gaussian_blur](https://github.com/PlayVoice/NSF-BigVGAN/assets/16432329/7fa96ef7-5e3b-4ae6-bc61-9b6da3b9d0b9)\n\n## Code sources and references\n\nhttps://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf\n\nhttps://github.com/mindslab-ai/univnet [[paper]](https://arxiv.org/abs/2106.07889)\n\nhttps://github.com/NVIDIA/BigVGAN [[paper]](https://arxiv.org/abs/2206.04658)"
  },
  {
    "path": "bigvgan/configs/nsf_bigvgan.yaml",
    "content": "data:\n  train_file: 'files/train.txt'\n  val_file: 'files/valid.txt'\n#############################\ntrain:\n  num_workers: 4\n  batch_size: 8\n  optimizer: 'adam'\n  seed: 1234\n  adam:\n    lr: 0.0002\n    beta1: 0.8\n    beta2: 0.99\n  mel_lamb: 5\n  stft_lamb: 2.5\n  pretrain: ''\n  lora: False\n#############################\naudio:\n  n_mel_channels: 100\n  segment_length: 12800 # Should be multiple of 320\n  filter_length: 1024\n  hop_length: 320 # WARNING: this can't be changed.\n  win_length: 1024\n  sampling_rate: 32000\n  mel_fmin: 40.0\n  mel_fmax: 16000.0\n#############################\ngen:\n  mel_channels: 100\n  upsample_rates: [5,4,2,2,2,2]\n  upsample_kernel_sizes: [15,8,4,4,4,4]\n  upsample_initial_channel: 320\n  resblock_kernel_sizes: [3,7,11]\n  resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]\n#############################\nmpd:\n  periods: [2,3,5,7,11]\n  kernel_size: 5\n  stride: 3\n  use_spectral_norm: False\n  lReLU_slope: 0.2\n#############################\nmrd:\n  resolutions: \"[(1024, 120, 600), (2048, 240, 1200), (4096, 480, 2400), (512, 50, 240)]\" # (filter_length, hop_length, win_length)\n  use_spectral_norm: False\n  lReLU_slope: 0.2\n#############################\ndist_config:\n  dist_backend: \"nccl\"\n  dist_url: \"tcp://localhost:54321\"\n  world_size: 1\n#############################\nlog:\n  info_interval: 100\n  eval_interval: 1000\n  save_interval: 10000\n  num_audio: 6\n  pth_dir: 'chkpt'\n  log_dir: 'logs'\n"
  },
  {
    "path": "bigvgan/inference.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nimport torch\nimport argparse\n\nfrom omegaconf import OmegaConf\nfrom scipy.io.wavfile import write\nfrom bigvgan.model.generator import Generator\nfrom pitch import load_csv_pitch\n\n\ndef load_bigv_model(checkpoint_path, model):\n    assert os.path.isfile(checkpoint_path)\n    checkpoint_dict = torch.load(checkpoint_path, map_location=\"cpu\")\n    saved_state_dict = checkpoint_dict[\"model_g\"]\n    state_dict = model.state_dict()\n    new_state_dict = {}\n    for k, v in state_dict.items():\n        try:\n            new_state_dict[k] = saved_state_dict[k]\n        except KeyError:\n            # keep the freshly initialized weight when the checkpoint lacks this key\n            print(\"%s is not in the checkpoint\" % k)\n            new_state_dict[k] = v\n    model.load_state_dict(new_state_dict)\n    return model\n\n\ndef main(args):\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    hp = OmegaConf.load(args.config)\n    model = Generator(hp)\n    load_bigv_model(args.model, model)\n    model.eval()\n    model.to(device)\n\n    mel = torch.load(args.mel)\n\n    pit = load_csv_pitch(args.pit)\n    pit = torch.FloatTensor(pit)\n\n    len_pit = pit.size()[0]\n    len_mel = mel.size()[1]\n    len_min = min(len_pit, len_mel)\n    pit = pit[:len_min]\n    mel = mel[:, :len_min]\n\n    with torch.no_grad():\n        mel = mel.unsqueeze(0).to(device)\n        pit = pit.unsqueeze(0).to(device)\n        audio = model.inference(mel, pit)\n        audio = audio.cpu().detach().numpy()\n\n        pitwav = model.pitch2wav(pit)\n        pitwav = pitwav.cpu().detach().numpy()\n\n    write(\"gvc_out.wav\", hp.audio.sampling_rate, audio)\n    write(\"gvc_pitch.wav\", hp.audio.sampling_rate, pitwav)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--mel', type=str,\n                        help=\"Path of mel spectrogram (.mel.pt) file.\")\n    parser.add_argument('--pit', type=str,\n                 
       help=\"Path of pitch csv file.\")\n    args = parser.parse_args()\n\n    args.config = \"./bigvgan/configs/nsf_bigvgan.yaml\"\n    args.model = \"./bigvgan_pretrain/nsf_bigvgan_pretrain_32K.pth\"\n\n    main(args)\n"
  },
  {
    "path": "bigvgan/model/__init__.py",
    "content": "from .alias.act import SnakeAlias"
  },
  {
    "path": "bigvgan/model/alias/__init__.py",
    "content": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licenses directory.\n\nfrom .filter import *\nfrom .resample import *\nfrom .act import *"
  },
  {
    "path": "bigvgan/model/alias/act.py",
    "content": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licenses directory.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torch import sin, pow\nfrom torch.nn import Parameter\nfrom .resample import UpSample1d, DownSample1d\n\n\nclass Activation1d(nn.Module):\n    def __init__(self,\n                 activation,\n                 up_ratio: int = 2,\n                 down_ratio: int = 2,\n                 up_kernel_size: int = 12,\n                 down_kernel_size: int = 12):\n        super().__init__()\n        self.up_ratio = up_ratio\n        self.down_ratio = down_ratio\n        self.act = activation\n        self.upsample = UpSample1d(up_ratio, up_kernel_size)\n        self.downsample = DownSample1d(down_ratio, down_kernel_size)\n\n    # x: [B,C,T]\n    def forward(self, x):\n        x = self.upsample(x)\n        x = self.act(x)\n        x = self.downsample(x)\n\n        return x\n\n\nclass SnakeBeta(nn.Module):\n    '''\n    A modified Snake function which uses separate parameters for the magnitude of the periodic components.\n    Shape:\n        - Input: (B, C, T)\n        - Output: (B, C, T), same shape as the input\n    Parameters:\n        - alpha - trainable parameter that controls frequency\n        - beta - trainable parameter that controls magnitude\n    References:\n        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:\n        https://arxiv.org/abs/2006.08195\n    Examples:\n        >>> a1 = SnakeBeta(256)\n        >>> x = torch.randn(1, 256, 100)  # (B, C, T)\n        >>> x = a1(x)\n    '''\n\n    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):\n        '''\n        Initialization.\n        INPUT:\n            - in_features: number of input channels (C)\n            - alpha - trainable parameter that controls frequency\n            - beta - trainable parameter that 
controls magnitude\n            alpha is initialized to 1 by default, higher values = higher-frequency.\n            beta is initialized to 1 by default, higher values = higher-magnitude.\n            alpha will be trained along with the rest of your model.\n        '''\n        super(SnakeBeta, self).__init__()\n        self.in_features = in_features\n        # initialize alpha\n        self.alpha_logscale = alpha_logscale\n        if self.alpha_logscale:  # log scale alphas initialized to zeros\n            self.alpha = Parameter(torch.zeros(in_features) * alpha)\n            self.beta = Parameter(torch.zeros(in_features) * alpha)\n        else:  # linear scale alphas initialized to ones\n            self.alpha = Parameter(torch.ones(in_features) * alpha)\n            self.beta = Parameter(torch.ones(in_features) * alpha)\n        self.alpha.requires_grad = alpha_trainable\n        self.beta.requires_grad = alpha_trainable\n        self.no_div_by_zero = 0.000000001\n\n    def forward(self, x):\n        '''\n        Forward pass of the function.\n        Applies the function to the input elementwise.\n        SnakeBeta = x + 1/b * sin^2 (xa)\n        '''\n        alpha = self.alpha.unsqueeze(\n            0).unsqueeze(-1)  # line up with x to [B, C, T]\n        beta = self.beta.unsqueeze(0).unsqueeze(-1)\n        if self.alpha_logscale:\n            alpha = torch.exp(alpha)\n            beta = torch.exp(beta)\n        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)\n        return x\n\n\nclass Mish(nn.Module):\n    \"\"\"\n    Mish activation function is proposed in \"Mish: A Self \n    Regularized Non-Monotonic Neural Activation Function\" \n    paper, https://arxiv.org/abs/1908.08681.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        return x * torch.tanh(F.softplus(x))\n\n\nclass SnakeAlias(nn.Module):\n    def __init__(self,\n                 channels,\n                 up_ratio: int 
= 2,\n                 down_ratio: int = 2,\n                 up_kernel_size: int = 12,\n                 down_kernel_size: int = 12):\n        super().__init__()\n        self.up_ratio = up_ratio\n        self.down_ratio = down_ratio\n        self.act = SnakeBeta(channels, alpha_logscale=True)\n        self.upsample = UpSample1d(up_ratio, up_kernel_size)\n        self.downsample = DownSample1d(down_ratio, down_kernel_size)\n\n    # x: [B,C,T]\n    def forward(self, x):\n        x = self.upsample(x)\n        x = self.act(x)\n        x = self.downsample(x)\n\n        return x"
  },
  {
    "path": "bigvgan/model/alias/filter.py",
    "content": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licenses directory.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\n\nif 'sinc' in dir(torch):\n    sinc = torch.sinc\nelse:\n    # This code is adopted from adefossez's julius.core.sinc under the MIT License\n    # https://adefossez.github.io/julius/julius/core.html\n    #   LICENSE is in incl_licenses directory.\n    def sinc(x: torch.Tensor):\n        \"\"\"\n        Implementation of sinc, i.e. sin(pi * x) / (pi * x)\n        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!\n        \"\"\"\n        return torch.where(x == 0,\n                           torch.tensor(1., device=x.device, dtype=x.dtype),\n                           torch.sin(math.pi * x) / math.pi / x)\n\n\n# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License\n# https://adefossez.github.io/julius/julius/lowpass.html\n#   LICENSE is in incl_licenses directory.\ndef kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]\n    even = (kernel_size % 2 == 0)\n    half_size = kernel_size // 2\n\n    #For kaiser window\n    delta_f = 4 * half_width\n    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95\n    if A > 50.:\n        beta = 0.1102 * (A - 8.7)\n    elif A >= 21.:\n        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)\n    else:\n        beta = 0.\n    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)\n\n    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio\n    if even:\n        time = (torch.arange(-half_size, half_size) + 0.5)\n    else:\n        time = torch.arange(kernel_size) - half_size\n    if cutoff == 0:\n        filter_ = torch.zeros_like(time)\n    else:\n        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)\n        # Normalize filter to have sum = 1, otherwise we will have a small 
leakage\n        # of the constant component in the input signal.\n        filter_ /= filter_.sum()\n    # reshape outside the branch so that cutoff == 0 also returns a [1, 1, kernel_size] filter\n    filter = filter_.view(1, 1, kernel_size)\n\n    return filter\n\n\nclass LowPassFilter1d(nn.Module):\n    def __init__(self,\n                 cutoff=0.5,\n                 half_width=0.6,\n                 stride: int = 1,\n                 padding: bool = True,\n                 padding_mode: str = 'replicate',\n                 kernel_size: int = 12):\n        # kernel_size should be even number for stylegan3 setup,\n        # in this implementation, odd number is also possible.\n        super().__init__()\n        if cutoff < 0.:\n            raise ValueError(\"Minimum cutoff must be non-negative.\")\n        if cutoff > 0.5:\n            raise ValueError(\"A cutoff above 0.5 does not make sense.\")\n        self.kernel_size = kernel_size\n        self.even = (kernel_size % 2 == 0)\n        self.pad_left = kernel_size // 2 - int(self.even)\n        self.pad_right = kernel_size // 2\n        self.stride = stride\n        self.padding = padding\n        self.padding_mode = padding_mode\n        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)\n        self.register_buffer(\"filter\", filter)\n\n    #input [B, C, T]\n    def forward(self, x):\n        _, C, _ = x.shape\n\n        if self.padding:\n            x = F.pad(x, (self.pad_left, self.pad_right),\n                      mode=self.padding_mode)\n        out = F.conv1d(x, self.filter.expand(C, -1, -1),\n                       stride=self.stride, groups=C)\n\n        return out"
  },
  {
    "path": "bigvgan/model/alias/resample.py",
    "content": "# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0\n#   LICENSE is in incl_licenses directory.\n\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom .filter import LowPassFilter1d\nfrom .filter import kaiser_sinc_filter1d\n\n\nclass UpSample1d(nn.Module):\n    def __init__(self, ratio=2, kernel_size=None):\n        super().__init__()\n        self.ratio = ratio\n        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size\n        self.stride = ratio\n        self.pad = self.kernel_size // ratio - 1\n        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2\n        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2\n        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,\n                                      half_width=0.6 / ratio,\n                                      kernel_size=self.kernel_size)\n        self.register_buffer(\"filter\", filter)\n\n    # x: [B, C, T]\n    def forward(self, x):\n        _, C, _ = x.shape\n\n        x = F.pad(x, (self.pad, self.pad), mode='replicate')\n        x = self.ratio * F.conv_transpose1d(\n            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)\n        x = x[..., self.pad_left:-self.pad_right]\n\n        return x\n\n\nclass DownSample1d(nn.Module):\n    def __init__(self, ratio=2, kernel_size=None):\n        super().__init__()\n        self.ratio = ratio\n        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size\n        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,\n                                       half_width=0.6 / ratio,\n                                       stride=ratio,\n                                       kernel_size=self.kernel_size)\n\n    def forward(self, x):\n        xx = self.lowpass(x)\n\n        return xx"
  },
  {
    "path": "bigvgan/model/bigv.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom torch.nn import Conv1d\nfrom torch.nn.utils import weight_norm, remove_weight_norm\nfrom .alias.act import SnakeAlias\n\n\ndef init_weights(m, mean=0.0, std=0.01):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        m.weight.data.normal_(mean, std)\n\n\ndef get_padding(kernel_size, dilation=1):\n    return int((kernel_size*dilation - dilation)/2)\n\n\nclass AMPBlock(torch.nn.Module):\n    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):\n        super(AMPBlock, self).__init__()\n        self.convs1 = nn.ModuleList([\n            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],\n                               padding=get_padding(kernel_size, dilation[0]))),\n            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],\n                               padding=get_padding(kernel_size, dilation[1]))),\n            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],\n                               padding=get_padding(kernel_size, dilation[2])))\n        ])\n        self.convs1.apply(init_weights)\n\n        self.convs2 = nn.ModuleList([\n            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,\n                               padding=get_padding(kernel_size, 1))),\n            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,\n                               padding=get_padding(kernel_size, 1))),\n            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,\n                               padding=get_padding(kernel_size, 1)))\n        ])\n        self.convs2.apply(init_weights)\n\n        # total number of conv layers\n        self.num_layers = len(self.convs1) + len(self.convs2)\n\n        # periodic nonlinearity with snakebeta function and anti-aliasing\n        self.activations = nn.ModuleList([\n            
SnakeAlias(channels) for _ in range(self.num_layers)\n        ])\n\n    def forward(self, x):\n        acts1, acts2 = self.activations[::2], self.activations[1::2]\n        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):\n            xt = a1(x)\n            xt = c1(xt)\n            xt = a2(xt)\n            xt = c2(xt)\n            x = xt + x\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs1:\n            remove_weight_norm(l)\n        for l in self.convs2:\n            remove_weight_norm(l)"
  },
  {
    "path": "bigvgan/model/generator.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom torch.nn import Conv1d\nfrom torch.nn import ConvTranspose1d\nfrom torch.nn.utils import weight_norm\nfrom torch.nn.utils import remove_weight_norm\n\nfrom .nsf import SourceModuleHnNSF\nfrom .bigv import init_weights, AMPBlock, SnakeAlias\n\n\nclass Generator(torch.nn.Module):\n    # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.\n    def __init__(self, hp):\n        super(Generator, self).__init__()\n        self.hp = hp\n        self.num_kernels = len(hp.gen.resblock_kernel_sizes)\n        self.num_upsamples = len(hp.gen.upsample_rates)\n        # pre conv\n        self.conv_pre = nn.utils.weight_norm(\n            Conv1d(hp.gen.mel_channels, hp.gen.upsample_initial_channel, 7, 1, padding=3))\n        # nsf\n        self.f0_upsamp = torch.nn.Upsample(\n            scale_factor=np.prod(hp.gen.upsample_rates))\n        self.m_source = SourceModuleHnNSF(sampling_rate=hp.audio.sampling_rate)\n        self.noise_convs = nn.ModuleList()\n        # transposed conv-based upsamplers. 
does not apply anti-aliasing\n        self.ups = nn.ModuleList()\n        for i, (u, k) in enumerate(zip(hp.gen.upsample_rates, hp.gen.upsample_kernel_sizes)):\n            # print(f'ups: {i} {k}, {u}, {(k - u) // 2}')\n            # base\n            self.ups.append(\n                weight_norm(\n                    ConvTranspose1d(\n                        hp.gen.upsample_initial_channel // (2 ** i),\n                        hp.gen.upsample_initial_channel // (2 ** (i + 1)),\n                        k,\n                        u,\n                        padding=(k - u) // 2)\n                )\n            )\n            # nsf\n            if i + 1 < len(hp.gen.upsample_rates):\n                stride_f0 = np.prod(hp.gen.upsample_rates[i + 1:])\n                stride_f0 = int(stride_f0)\n                self.noise_convs.append(\n                    Conv1d(\n                        1,\n                        hp.gen.upsample_initial_channel // (2 ** (i + 1)),\n                        kernel_size=stride_f0 * 2,\n                        stride=stride_f0,\n                        padding=stride_f0 // 2,\n                    )\n                )\n            else:\n                self.noise_convs.append(\n                    Conv1d(1, hp.gen.upsample_initial_channel //\n                           (2 ** (i + 1)), kernel_size=1)\n                )\n\n        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)\n        self.resblocks = nn.ModuleList()\n        for i in range(len(self.ups)):\n            ch = hp.gen.upsample_initial_channel // (2 ** (i + 1))\n            for k, d in zip(hp.gen.resblock_kernel_sizes, hp.gen.resblock_dilation_sizes):\n                self.resblocks.append(AMPBlock(ch, k, d))\n\n        # post conv\n        self.activation_post = SnakeAlias(ch)\n        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)\n        # weight initialization\n        self.ups.apply(init_weights)\n\n    def forward(self, x, 
f0, train=True):\n        # nsf\n        f0 = f0[:, None]\n        f0 = self.f0_upsamp(f0).transpose(1, 2)\n        har_source = self.m_source(f0)\n        har_source = har_source.transpose(1, 2)\n        # pre conv\n        if train:\n            x = x + torch.randn_like(x) * 0.1     # Perturbation\n        x = self.conv_pre(x)\n        x = x * torch.tanh(F.softplus(x))\n\n        for i in range(self.num_upsamples):\n            # upsampling\n            x = self.ups[i](x)\n            # nsf\n            x_source = self.noise_convs[i](har_source)\n            x = x + x_source\n            # AMP blocks\n            xs = None\n            for j in range(self.num_kernels):\n                if xs is None:\n                    xs = self.resblocks[i * self.num_kernels + j](x)\n                else:\n                    xs += self.resblocks[i * self.num_kernels + j](x)\n            x = xs / self.num_kernels\n\n        # post conv\n        x = self.activation_post(x)\n        x = self.conv_post(x)\n        x = torch.tanh(x)\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.ups:\n            remove_weight_norm(l)\n        for l in self.resblocks:\n            l.remove_weight_norm()\n        remove_weight_norm(self.conv_pre)\n\n    def eval(self, inference=False):\n        super(Generator, self).eval()\n        # don't remove weight norm while validation in training loop\n        if inference:\n            self.remove_weight_norm()\n\n    def inference(self, mel, f0):\n        MAX_WAV_VALUE = 32768.0\n        audio = self.forward(mel, f0, False)\n        audio = audio.squeeze()  # collapse all dimension except time axis\n        audio = MAX_WAV_VALUE * audio\n        audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)\n        audio = audio.short()\n        return audio\n\n    def pitch2wav(self, f0):\n        MAX_WAV_VALUE = 32768.0\n        # nsf\n        f0 = f0[:, None]\n        f0 = self.f0_upsamp(f0).transpose(1, 2)\n        
har_source = self.m_source(f0)\n        audio = har_source.transpose(1, 2)\n        audio = audio.squeeze()  # collapse all dimension except time axis\n        audio = MAX_WAV_VALUE * audio\n        audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)\n        audio = audio.short()\n        return audio\n"
  },
  {
    "path": "bigvgan/model/nsf.py",
    "content": "import torch\nimport numpy as np\nimport sys\nimport torch.nn.functional as torch_nn_func\n\n\nclass PulseGen(torch.nn.Module):\n    \"\"\"Definition of Pulse train generator\n\n    There are many ways to implement pulse generator.\n    Here, PulseGen is based on SinGen. For a perfect\n    \"\"\"\n\n    def __init__(self, samp_rate, pulse_amp=0.1, noise_std=0.003, voiced_threshold=0):\n        super(PulseGen, self).__init__()\n        self.pulse_amp = pulse_amp\n        self.sampling_rate = samp_rate\n        self.voiced_threshold = voiced_threshold\n        self.noise_std = noise_std\n        self.l_sinegen = SineGen(\n            self.sampling_rate,\n            harmonic_num=0,\n            sine_amp=self.pulse_amp,\n            noise_std=0,\n            voiced_threshold=self.voiced_threshold,\n            flag_for_pulse=True,\n        )\n\n    def forward(self, f0):\n        \"\"\"Pulse train generator\n        pulse_train, uv = forward(f0)\n        input F0: tensor(batchsize=1, length, dim=1)\n                  f0 for unvoiced steps should be 0\n        output pulse_train: tensor(batchsize=1, length, dim)\n        output uv: tensor(batchsize=1, length, 1)\n\n        Note: self.l_sine doesn't make sure that the initial phase of\n        a voiced segment is np.pi, the first pulse in a voiced segment\n        may not be at the first time step within a voiced segment\n        \"\"\"\n        with torch.no_grad():\n            sine_wav, uv, noise = self.l_sinegen(f0)\n\n            # sine without additive noise\n            pure_sine = sine_wav - noise\n\n            # step t corresponds to a pulse if\n            # sine[t] > sine[t+1] & sine[t] > sine[t-1]\n            # & sine[t-1], sine[t+1], and sine[t] are voiced\n            # or\n            # sine[t] is voiced, sine[t-1] is unvoiced\n            # we use torch.roll to simulate sine[t+1] and sine[t-1]\n            sine_1 = torch.roll(pure_sine, shifts=1, dims=1)\n            uv_1 = 
torch.roll(uv, shifts=1, dims=1)\n            uv_1[:, 0, :] = 0\n            sine_2 = torch.roll(pure_sine, shifts=-1, dims=1)\n            uv_2 = torch.roll(uv, shifts=-1, dims=1)\n            uv_2[:, -1, :] = 0\n\n            loc = (pure_sine > sine_1) * (pure_sine > sine_2) \\\n                  * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \\\n                  + (uv_1 < 1) * (uv > 0)\n\n            # pulse train without noise\n            pulse_train = pure_sine * loc\n\n            # additive noise to pulse train\n            # note that noise from sinegen is zero in voiced regions\n            pulse_noise = torch.randn_like(pure_sine) * self.noise_std\n\n            # with additive noise on pulse, and unvoiced regions\n            pulse_train += pulse_noise * loc + pulse_noise * (1 - uv)\n        return pulse_train, sine_wav, uv, pulse_noise\n\n\nclass SignalsConv1d(torch.nn.Module):\n    \"\"\"Filtering input signal with time invariant filter\n    Note: FIRFilter conducted filtering given fixed FIR weight\n          SignalsConv1d convolves two signals\n    Note: this is based on torch.nn.functional.conv1d\n\n    \"\"\"\n\n    def __init__(self):\n        super(SignalsConv1d, self).__init__()\n\n    def forward(self, signal, system_ir):\n        \"\"\"output = forward(signal, system_ir)\n\n        signal:    (batchsize, length1, dim)\n        system_ir: (length2, dim)\n\n        output:    (batchsize, length1, dim)\n        \"\"\"\n        if signal.shape[-1] != system_ir.shape[-1]:\n            print(\"Error: SignalsConv1d expects shape:\")\n            print(\"signal    (batchsize, length1, dim)\")\n            print(\"system_ir (length2, dim)\")\n            print(\"But received signal: {:s}\".format(str(signal.shape)))\n            print(\" system_ir: {:s}\".format(str(system_ir.shape)))\n            sys.exit(1)\n        padding_length = system_ir.shape[0] - 1\n        groups = signal.shape[-1]\n\n        # pad signal on the left\n        signal_pad = 
torch_nn_func.pad(signal.permute(0, 2, 1), (padding_length, 0))\n        # prepare system impulse response as (dim, 1, length2)\n        # also flip the impulse response\n        ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0), dims=[2])\n        # convolute\n        output = torch_nn_func.conv1d(signal_pad, ir, groups=groups)\n        return output.permute(0, 2, 1)\n\n\nclass CyclicNoiseGen_v1(torch.nn.Module):\n    \"\"\"CyclicnoiseGen_v1\n    Cyclic noise with a single parameter of beta.\n    Pytorch v1 implementation assumes f_t is also fixed\n    \"\"\"\n\n    def __init__(self, samp_rate, noise_std=0.003, voiced_threshold=0):\n        super(CyclicNoiseGen_v1, self).__init__()\n        self.samp_rate = samp_rate\n        self.noise_std = noise_std\n        self.voiced_threshold = voiced_threshold\n\n        self.l_pulse = PulseGen(\n            samp_rate,\n            pulse_amp=1.0,\n            noise_std=noise_std,\n            voiced_threshold=voiced_threshold,\n        )\n        self.l_conv = SignalsConv1d()\n\n    def noise_decay(self, beta, f0mean):\n        \"\"\"decayed_noise = noise_decay(beta, f0mean)\n        decayed_noise =  n[t]exp(-t * f_mean / beta / samp_rate)\n\n        beta: (dim=1) or (batchsize=1, 1, dim=1)\n        f0mean (batchsize=1, 1, dim=1)\n\n        decayed_noise (batchsize=1, length, dim=1)\n        \"\"\"\n        with torch.no_grad():\n            # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T\n            # truncate the noise when decayed by -40 dB\n            length = 4.6 * self.samp_rate / f0mean\n            length = length.int()\n            time_idx = torch.arange(0, length, device=beta.device)\n            time_idx = time_idx.unsqueeze(0).unsqueeze(2)\n            time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2])\n\n        noise = torch.randn(time_idx.shape, device=beta.device)\n\n        # due to Pytorch implementation, use f0_mean as the f0 factor\n        decay = torch.exp(-time_idx * f0mean / 
beta / self.samp_rate)\n        return noise * self.noise_std * decay\n\n    def forward(self, f0s, beta):\n        \"\"\"Produce cyclic-noise\"\"\"\n        # pulse train\n        pulse_train, sine_wav, uv, noise = self.l_pulse(f0s)\n        pure_pulse = pulse_train - noise\n\n        # decayed_noise (length, dim=1)\n        if (uv < 1).all():\n            # all unvoiced\n            cyc_noise = torch.zeros_like(sine_wav)\n        else:\n            f0mean = f0s[uv > 0].mean()\n\n            decayed_noise = self.noise_decay(beta, f0mean)[0, :, :]\n            # convolute\n            cyc_noise = self.l_conv(pure_pulse, decayed_noise)\n\n        # add noise in unvoiced segments\n        cyc_noise = cyc_noise + noise * (1.0 - uv)\n        return cyc_noise, pulse_train, sine_wav, uv, noise\n\n\nclass SineGen(torch.nn.Module):\n    \"\"\"Definition of sine generator\n    SineGen(samp_rate, harmonic_num = 0,\n            sine_amp = 0.1, noise_std = 0.003,\n            voiced_threshold = 0,\n            flag_for_pulse=False)\n\n    samp_rate: sampling rate in Hz\n    harmonic_num: number of harmonic overtones (default 0)\n    sine_amp: amplitude of sine waveform (default 0.1)\n    noise_std: std of Gaussian noise (default 0.003)\n    voiced_threshold: F0 threshold for U/V classification (default 0)\n    flag_for_pulse: this SineGen is used inside PulseGen (default False)\n\n    Note: when flag_for_pulse is True, the first time step of a voiced\n        segment is always sin(np.pi) or cos(0)\n    \"\"\"\n\n    def __init__(\n        self,\n        samp_rate,\n        harmonic_num=0,\n        sine_amp=0.1,\n        noise_std=0.003,\n        voiced_threshold=0,\n        flag_for_pulse=False,\n    ):\n        super(SineGen, self).__init__()\n        self.sine_amp = sine_amp\n        self.noise_std = noise_std\n        self.harmonic_num = harmonic_num\n        self.dim = self.harmonic_num + 1\n        self.sampling_rate = samp_rate\n        self.voiced_threshold = 
voiced_threshold\n        self.flag_for_pulse = flag_for_pulse\n\n    def _f02uv(self, f0):\n        # generate uv signal\n        uv = torch.ones_like(f0)\n        uv = uv * (f0 > self.voiced_threshold)\n        return uv\n\n    def _f02sine(self, f0_values):\n        \"\"\"f0_values: (batchsize, length, dim)\n        where dim indicates fundamental tone and overtones\n        \"\"\"\n        # convert to F0 in rad. The integer part n can be ignored\n        # because 2 * np.pi * n doesn't affect phase\n        rad_values = (f0_values / self.sampling_rate) % 1\n\n        # initial phase noise (no noise for fundamental component)\n        rand_ini = torch.rand(\n            f0_values.shape[0], f0_values.shape[2], device=f0_values.device\n        )\n        rand_ini[:, 0] = 0\n        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini\n\n        # instantaneous phase sine[t] = sin(2*pi \\sum_i=1 ^{t} rad)\n        if not self.flag_for_pulse:\n            # for normal case\n\n            # To prevent torch.cumsum numerical overflow,\n            # it is necessary to add -1 whenever \\sum_k=1^n rad_value_k > 1.\n            # Buffer tmp_over_one_idx indicates the time step to add -1.\n            # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi\n            tmp_over_one = torch.cumsum(rad_values, 1) % 1\n            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0\n            cumsum_shift = torch.zeros_like(rad_values)\n            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0\n\n            sines = torch.sin(\n                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi\n            )\n        else:\n            # If necessary, make sure that the first time step of every\n            # voiced segment is sin(pi) or cos(0)\n            # This is used for pulse-train generation\n\n            # identify the last time step in unvoiced segments\n            uv = self._f02uv(f0_values)\n            uv_1 = 
torch.roll(uv, shifts=-1, dims=1)\n            uv_1[:, -1, :] = 1\n            u_loc = (uv < 1) * (uv_1 > 0)\n\n            # get the instantaneous phase\n            tmp_cumsum = torch.cumsum(rad_values, dim=1)\n            # different batch needs to be processed differently\n            for idx in range(f0_values.shape[0]):\n                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]\n                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]\n                # stores the accumulation of i.phase within\n                # each voiced segment\n                tmp_cumsum[idx, :, :] = 0\n                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum\n\n            # rad_values - tmp_cumsum: remove the accumulation of i.phase\n            # within the previous voiced segment.\n            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)\n\n            # get the sines\n            sines = torch.cos(i_phase * 2 * np.pi)\n        return sines\n\n    def forward(self, f0):\n        \"\"\"sine_tensor = forward(f0)\n        input F0: tensor(batchsize=1, length, dim=1)\n                  f0 for unvoiced steps should be 0\n        output sine_tensor: tensor(batchsize=1, length, dim)\n        \"\"\"\n        with torch.no_grad():\n            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)\n            # fundamental component\n            f0_buf[:, :, 0] = f0[:, :, 0]\n            for idx in np.arange(self.harmonic_num):\n                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic\n                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)\n\n            # generate sine waveforms\n            sine_waves = self._f02sine(f0_buf) * self.sine_amp\n\n            # generate uv signal\n            # uv = torch.ones(f0.shape)\n            # uv = uv * (f0 > self.voiced_threshold)\n            uv = self._f02uv(f0)\n\n            # noise: for unvoiced should be 
similar to sine_amp\n            #        std = self.sine_amp/3 -> max value ~ self.sine_amp\n            # .       for voiced regions is self.noise_std\n            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3\n            noise = noise_amp * torch.randn_like(sine_waves)\n\n            # first: set the unvoiced part to 0 by uv\n            # then: additive noise\n            sine_waves = sine_waves * uv + noise\n        return sine_waves\n\n\nclass SourceModuleCycNoise_v1(torch.nn.Module):\n    \"\"\"SourceModuleCycNoise_v1\n    SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0)\n    sampling_rate: sampling_rate in Hz\n\n    noise_std: std of Gaussian noise (default: 0.003)\n    voiced_threshold: threshold to set U/V given F0 (default: 0)\n\n    cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta)\n    F0_upsampled (batchsize, length, 1)\n    beta (1)\n    cyc (batchsize, length, 1)\n    noise (batchsize, length, 1)\n    uv (batchsize, length, 1)\n    \"\"\"\n\n    def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0):\n        super(SourceModuleCycNoise_v1, self).__init__()\n        self.sampling_rate = sampling_rate\n        self.noise_std = noise_std\n        self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std, voiced_threshod)\n\n    def forward(self, f0_upsamped, beta):\n        \"\"\"\n        cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta)\n        F0_upsampled (batchsize, length, 1)\n        beta (1)\n        cyc (batchsize, length, 1)\n        noise (batchsize, length, 1)\n        uv (batchsize, length, 1)\n        \"\"\"\n        # source for harmonic branch\n        cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta)\n\n        # source for noise branch, in the same shape as uv\n        noise = torch.randn_like(uv) * self.noise_std / 3\n        return cyc, noise, uv\n\n\nclass SourceModuleHnNSF(torch.nn.Module):\n    def __init__(\n        self,\n        
sampling_rate=32000,\n        sine_amp=0.1,\n        add_noise_std=0.003,\n        voiced_threshod=0,\n    ):\n        super(SourceModuleHnNSF, self).__init__()\n        harmonic_num = 10\n        self.sine_amp = sine_amp\n        self.noise_std = add_noise_std\n\n        # to produce sine waveforms\n        self.l_sin_gen = SineGen(\n            sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod\n        )\n\n        # to merge source harmonics into a single excitation\n        self.l_tanh = torch.nn.Tanh()\n        self.register_buffer('merge_w', torch.FloatTensor([[\n            0.2942, -0.2243, 0.0033, -0.0056, -0.0020, -0.0046,\n            0.0221, -0.0083, -0.0241, -0.0036, -0.0581]]))\n        self.register_buffer('merge_b', torch.FloatTensor([0.0008]))\n\n    def forward(self, x):\n        \"\"\"\n        Sine_source = SourceModuleHnNSF(F0_sampled)\n        F0_sampled (batchsize, length, 1)\n        Sine_source (batchsize, length, 1)\n        \"\"\"\n        # source for harmonic branch\n        sine_wavs = self.l_sin_gen(x)\n        sine_wavs = torch_nn_func.linear(\n            sine_wavs, self.merge_w) + self.merge_b\n        sine_merge = self.l_tanh(sine_wavs)\n        return sine_merge\n"
  },
  {
    "path": "bigvgan_pretrain/README.md",
    "content": "Path for:\n\n    nsf_bigvgan_pretrain_32K.pth\n\n    Download link: https://github.com/PlayVoice/NSF-BigVGAN/releases/tag/augment\n"
  },
  {
    "path": "configs/base.yaml",
    "content": "train:\n  seed: 37\n  train_files: \"files/train.txt\"\n  valid_files: \"files/valid.txt\"\n  log_dir: 'logs/grad_svc'\n  full_epochs: 500\n  fast_epochs: 100\n  learning_rate: 2e-4\n  batch_size: 8\n  test_size: 4\n  test_step: 5\n  save_step: 10\n  pretrain: \"grad_pretrain/gvc.pretrain.pth\"\n#############################\ndata: \n  segment_size: 16000  # WARNING: base on hop_length\n  max_wav_value: 32768.0\n  sampling_rate: 32000\n  filter_length: 1024\n  hop_length: 320\n  win_length: 1024\n  mel_channels: 100\n  mel_fmin: 40.0\n  mel_fmax: 16000.0\n#############################\ngrad:\n  n_mels: 100\n  n_vecs: 256\n  n_pits: 256\n  n_spks: 256\n  n_embs: 64\n\n  # encoder parameters\n  n_enc_channels: 192\n  filter_channels: 512\n\n  # decoder parameters\n  dec_dim: 96\n  beta_min: 0.05\n  beta_max: 20.0\n  pe_scale: 1000\n"
  },
  {
    "path": "grad/LICENSE",
    "content": "Copyright (c) 2021 Huawei Technologies Co., Ltd.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "grad/__init__.py",
    "content": ""
  },
  {
    "path": "grad/base.py",
    "content": "import numpy as np\nimport torch\n\n\nclass BaseModule(torch.nn.Module):\n    def __init__(self):\n        super(BaseModule, self).__init__()\n\n    @property\n    def nparams(self):\n        \"\"\"\n        Returns number of trainable parameters of the module.\n        \"\"\"\n        num_params = 0\n        for name, param in self.named_parameters():\n            if param.requires_grad:\n                num_params += np.prod(param.detach().cpu().numpy().shape)\n        return num_params\n\n\n    def relocate_input(self, x: list):\n        \"\"\"\n        Relocates provided tensors to the same device set for the module.\n        \"\"\"\n        device = next(self.parameters()).device\n        for i in range(len(x)):\n            if isinstance(x[i], torch.Tensor) and x[i].device != device:\n                x[i] = x[i].to(device)\n        return x\n"
  },
  {
    "path": "grad/diffusion.py",
    "content": "import math\nimport torch\nfrom einops import rearrange\nfrom grad.base import BaseModule\nfrom grad.solver import NoiseScheduleVP, MaxLikelihood, GradRaw\n\n\nclass Mish(BaseModule):\n    def forward(self, x):\n        return x * torch.tanh(torch.nn.functional.softplus(x))\n\n\nclass Upsample(BaseModule):\n    def __init__(self, dim):\n        super(Upsample, self).__init__()\n        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)\n\n    def forward(self, x):\n        return self.conv(x)\n\n\nclass Downsample(BaseModule):\n    def __init__(self, dim):\n        super(Downsample, self).__init__()\n        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)\n\n    def forward(self, x):\n        return self.conv(x)\n\n\nclass Rezero(BaseModule):\n    def __init__(self, fn):\n        super(Rezero, self).__init__()\n        self.fn = fn\n        self.g = torch.nn.Parameter(torch.zeros(1))\n\n    def forward(self, x):\n        return self.fn(x) * self.g\n\n\nclass Block(BaseModule):\n    def __init__(self, dim, dim_out, groups=8):\n        super(Block, self).__init__()\n        self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, \n                                         padding=1), torch.nn.GroupNorm(\n                                         groups, dim_out), Mish())\n\n    def forward(self, x, mask):\n        output = self.block(x * mask)\n        return output * mask\n\n\nclass ResnetBlock(BaseModule):\n    def __init__(self, dim, dim_out, time_emb_dim, groups=8):\n        super(ResnetBlock, self).__init__()\n        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, \n                                                               dim_out))\n\n        self.block1 = Block(dim, dim_out, groups=groups)\n        self.block2 = Block(dim_out, dim_out, groups=groups)\n        if dim != dim_out:\n            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)\n        else:\n            self.res_conv = torch.nn.Identity()\n\n  
  def forward(self, x, mask, time_emb):\n        h = self.block1(x, mask)\n        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)\n        h = self.block2(h, mask)\n        output = h + self.res_conv(x * mask)\n        return output\n\n\nclass LinearAttention(BaseModule):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super(LinearAttention, self).__init__()\n        self.heads = heads\n        hidden_dim = dim_head * heads\n        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)            \n\n    def forward(self, x):\n        b, c, h, w = x.shape\n        qkv = self.to_qkv(x)\n        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', \n                            heads = self.heads, qkv=3)            \n        k = k.softmax(dim=-1)\n        context = torch.einsum('bhdn,bhen->bhde', k, v)\n        out = torch.einsum('bhde,bhdn->bhen', context, q)\n        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', \n                        heads=self.heads, h=h, w=w)\n        return self.to_out(out)\n\n\nclass Residual(BaseModule):\n    def __init__(self, fn):\n        super(Residual, self).__init__()\n        self.fn = fn\n\n    def forward(self, x, *args, **kwargs):\n        output = self.fn(x, *args, **kwargs) + x\n        return output\n\n\nclass SinusoidalPosEmb(BaseModule):\n    def __init__(self, dim):\n        super(SinusoidalPosEmb, self).__init__()\n        self.dim = dim\n\n    def forward(self, x, scale=1000):\n        device = x.device\n        half_dim = self.dim // 2\n        emb = math.log(10000) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)\n        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)\n        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)\n        return emb\n\n\nclass GradLogPEstimator2d(BaseModule):\n    def __init__(self, dim, dim_mults=(1, 2, 4), emb_dim=64, 
n_mels=100,\n                 groups=8, pe_scale=1000):\n        super(GradLogPEstimator2d, self).__init__()\n        self.dim = dim\n        self.dim_mults = dim_mults\n        self.emb_dim = emb_dim\n        self.groups = groups\n        self.pe_scale = pe_scale\n\n        self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, emb_dim * 4), Mish(),\n                                           torch.nn.Linear(emb_dim * 4, n_mels))\n        self.time_pos_emb = SinusoidalPosEmb(dim)\n        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),\n                                       torch.nn.Linear(dim * 4, dim))\n\n        dims = [2 + 1, *map(lambda m: dim * m, dim_mults)]\n        in_out = list(zip(dims[:-1], dims[1:]))\n        self.downs = torch.nn.ModuleList([])\n        self.ups = torch.nn.ModuleList([])\n        num_resolutions = len(in_out)\n\n        for ind, (dim_in, dim_out) in enumerate(in_out):  # 2 downs\n            is_last = ind >= (num_resolutions - 1)\n            self.downs.append(torch.nn.ModuleList([\n                       ResnetBlock(dim_in, dim_out, time_emb_dim=dim),\n                       ResnetBlock(dim_out, dim_out, time_emb_dim=dim),\n                       Residual(Rezero(LinearAttention(dim_out))),\n                       Downsample(dim_out) if not is_last else torch.nn.Identity()]))\n\n        mid_dim = dims[-1]\n        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)\n        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))\n        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)\n\n        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):  # 2 ups\n            self.ups.append(torch.nn.ModuleList([\n                     ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),\n                     ResnetBlock(dim_in, dim_in, time_emb_dim=dim),\n                     Residual(Rezero(LinearAttention(dim_in))),\n                     Upsample(dim_in)]))\n        
self.final_block = Block(dim, dim)\n        self.final_conv = torch.nn.Conv2d(dim, 1, 1)\n\n    def forward(self, spk, x, mask, mu, t):\n        s = self.spk_mlp(spk)\n\n        t = self.time_pos_emb(t, scale=self.pe_scale)\n        t = self.mlp(t)\n\n        s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])\n        x = torch.stack([mu, x, s], 1)\n        mask = mask.unsqueeze(1)\n\n        hiddens = []\n        masks = [mask]\n        for resnet1, resnet2, attn, downsample in self.downs:\n            mask_down = masks[-1]\n            x = resnet1(x, mask_down, t)\n            x = resnet2(x, mask_down, t)\n            x = attn(x)\n            hiddens.append(x)\n            x = downsample(x * mask_down)\n            masks.append(mask_down[:, :, :, ::2])\n\n        masks = masks[:-1]\n        mask_mid = masks[-1]\n        x = self.mid_block1(x, mask_mid, t)\n        x = self.mid_attn(x)\n        x = self.mid_block2(x, mask_mid, t)\n\n        for resnet1, resnet2, attn, upsample in self.ups:\n            mask_up = masks.pop()\n            x = torch.cat((x, hiddens.pop()), dim=1)\n            x = resnet1(x, mask_up, t)\n            x = resnet2(x, mask_up, t)\n            x = attn(x)\n            x = upsample(x * mask_up)\n\n        x = self.final_block(x, mask)\n        output = self.final_conv(x * mask)\n\n        return (output * mask).squeeze(1)\n\n\ndef get_noise(t, beta_init, beta_term, cumulative=False):\n    if cumulative:\n        noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)\n    else:\n        noise = beta_init + (beta_term - beta_init)*t\n    return noise\n\n\nclass Diffusion(BaseModule):\n    def __init__(self, n_mels, dim, emb_dim=64,\n                 beta_min=0.05, beta_max=20, pe_scale=1000):\n        super(Diffusion, self).__init__()\n        self.n_mels = n_mels\n        self.beta_min = beta_min\n        self.beta_max = beta_max\n        # self.solver = NoiseScheduleVP()\n        self.solver = MaxLikelihood()\n        # self.solver = 
GradRaw()\n        self.estimator = GradLogPEstimator2d(dim,\n                                             n_mels=n_mels,\n                                             emb_dim=emb_dim,\n                                             pe_scale=pe_scale)\n\n    def forward_diffusion(self, mel, mask, mu, t):\n        time = t.unsqueeze(-1).unsqueeze(-1)\n        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)\n        mean = mel*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))\n        variance = 1.0 - torch.exp(-cum_noise)\n        z = torch.randn(mel.shape, dtype=mel.dtype, device=mel.device, \n                        requires_grad=False)\n        xt = mean + z * torch.sqrt(variance)\n        return xt * mask, z * mask\n\n    def forward(self, spk, z, mask, mu, n_timesteps, stoc=False):\n        return self.solver.reverse_diffusion(self.estimator, spk, z, mask, mu, n_timesteps, stoc)\n\n    def loss_t(self, spk, mel, mask, mu, t):\n        xt, z = self.forward_diffusion(mel, mask, mu, t)\n        time = t.unsqueeze(-1).unsqueeze(-1)\n        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)\n        noise_estimation = self.estimator(spk, xt, mask, mu, t)\n        noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))\n        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_mels)\n        return loss, xt\n\n    def compute_loss(self, spk, mel, mask, mu, offset=1e-5):\n        t = torch.rand(mel.shape[0], dtype=mel.dtype, device=mel.device, requires_grad=False)\n        t = torch.clamp(t, offset, 1.0 - offset)\n        return self.loss_t(spk, mel, mask, mu, t)\n"
  },
  {
    "path": "grad/encoder.py",
    "content": "import math\nimport torch\n\nfrom grad.base import BaseModule\nfrom grad.reversal import SpeakerClassifier\nfrom grad.utils import sequence_mask, convert_pad_shape\n\n\nclass LayerNorm(BaseModule):\n    def __init__(self, channels, eps=1e-4):\n        super(LayerNorm, self).__init__()\n        self.channels = channels\n        self.eps = eps\n\n        self.gamma = torch.nn.Parameter(torch.ones(channels))\n        self.beta = torch.nn.Parameter(torch.zeros(channels))\n\n    def forward(self, x):\n        n_dims = len(x.shape)\n        mean = torch.mean(x, 1, keepdim=True)\n        variance = torch.mean((x - mean)**2, 1, keepdim=True)\n\n        x = (x - mean) * torch.rsqrt(variance + self.eps)\n\n        shape = [1, -1] + [1] * (n_dims - 2)\n        x = x * self.gamma.view(*shape) + self.beta.view(*shape)\n        return x\n\n\nclass ConvReluNorm(BaseModule):\n    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, \n                 n_layers, p_dropout, eps=1e-5):\n        super(ConvReluNorm, self).__init__()\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.out_channels = out_channels\n        self.kernel_size = kernel_size\n        self.n_layers = n_layers\n        self.p_dropout = p_dropout\n        self.eps = eps\n\n        self.conv_layers = torch.nn.ModuleList()\n        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, \n                                                kernel_size, padding=kernel_size//2))\n        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))\n        for _ in range(n_layers - 1):\n            self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, \n                                                    kernel_size, padding=kernel_size//2))\n        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)\n        self.proj.weight.data.zero_()\n        
self.proj.bias.data.zero_()\n\n    def forward(self, x, x_mask):\n        for i in range(self.n_layers):\n            x = self.conv_layers[i](x * x_mask)\n            x = self.instance_norm(x, x_mask)\n            x = self.relu_drop(x)\n        x = self.proj(x)\n        return x * x_mask\n\n    def instance_norm(self, x, mask, return_mean_std=False):\n        mean, std = self.calc_mean_std(x, mask)\n        x = (x - mean) / std\n        if return_mean_std:\n            return x, mean, std\n        else:\n            return x\n\n    def calc_mean_std(self, x, mask=None):\n        x = x * mask\n        B, C = x.shape[:2]\n        mn = x.view(B, C, -1).mean(-1)\n        sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt()\n        mn = mn.view(B, C, *((len(x.shape) - 2) * [1]))\n        sd = sd.view(B, C, *((len(x.shape) - 2) * [1]))\n        return mn, sd\n\n\nclass MultiHeadAttention(BaseModule):\n    def __init__(self, channels, out_channels, n_heads, window_size=None, \n                 heads_share=True, p_dropout=0.0, proximal_bias=False, \n                 proximal_init=False):\n        super(MultiHeadAttention, self).__init__()\n        assert channels % n_heads == 0\n\n        self.channels = channels\n        self.out_channels = out_channels\n        self.n_heads = n_heads\n        self.window_size = window_size\n        self.heads_share = heads_share\n        self.proximal_bias = proximal_bias\n        self.p_dropout = p_dropout\n        self.attn = None\n\n        self.k_channels = channels // n_heads\n        self.conv_q = torch.nn.Conv1d(channels, channels, 1)\n        self.conv_k = torch.nn.Conv1d(channels, channels, 1)\n        self.conv_v = torch.nn.Conv1d(channels, channels, 1)\n        if window_size is not None:\n            n_heads_rel = 1 if heads_share else n_heads\n            rel_stddev = self.k_channels**-0.5\n            self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, \n                             window_size * 2 + 1, 
self.k_channels) * rel_stddev)\n            self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, \n                             window_size * 2 + 1, self.k_channels) * rel_stddev)\n        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)\n        self.drop = torch.nn.Dropout(p_dropout)\n\n        torch.nn.init.xavier_uniform_(self.conv_q.weight)\n        torch.nn.init.xavier_uniform_(self.conv_k.weight)\n        if proximal_init:\n            self.conv_k.weight.data.copy_(self.conv_q.weight.data)\n            self.conv_k.bias.data.copy_(self.conv_q.bias.data)\n        torch.nn.init.xavier_uniform_(self.conv_v.weight)\n        \n    def forward(self, x, c, attn_mask=None):\n        q = self.conv_q(x)\n        k = self.conv_k(c)\n        v = self.conv_v(c)\n        \n        x, self.attn = self.attention(q, k, v, mask=attn_mask)\n\n        x = self.conv_o(x)\n        return x\n\n    def attention(self, query, key, value, mask=None):\n        b, d, t_s, t_t = (*key.size(), query.size(2))\n        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)\n        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)\n        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)\n\n        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)\n        if self.window_size is not None:\n            assert t_s == t_t, \"Relative attention is only available for self-attention.\"\n            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)\n            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)\n            rel_logits = self._relative_position_to_absolute_position(rel_logits)\n            scores_local = rel_logits / math.sqrt(self.k_channels)\n            scores = scores + scores_local\n        if self.proximal_bias:\n            assert t_s == t_t, \"Proximal bias is only available for self-attention.\"\n            
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, \n                                                                    dtype=scores.dtype)\n        if mask is not None:\n            scores = scores.masked_fill(mask == 0, -1e4)\n        p_attn = torch.nn.functional.softmax(scores, dim=-1)\n        p_attn = self.drop(p_attn)\n        output = torch.matmul(p_attn, value)\n        if self.window_size is not None:\n            relative_weights = self._absolute_position_to_relative_position(p_attn)\n            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)\n            output = output + self._matmul_with_relative_values(relative_weights, \n                                                                value_relative_embeddings)\n        output = output.transpose(2, 3).contiguous().view(b, d, t_t)\n        return output, p_attn\n\n    def _matmul_with_relative_values(self, x, y):\n        ret = torch.matmul(x, y.unsqueeze(0))\n        return ret\n\n    def _matmul_with_relative_keys(self, x, y):\n        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))\n        return ret\n\n    def _get_relative_embeddings(self, relative_embeddings, length):\n        pad_length = max(length - (self.window_size + 1), 0)\n        slice_start_position = max((self.window_size + 1) - length, 0)\n        slice_end_position = slice_start_position + 2 * length - 1\n        if pad_length > 0:\n            padded_relative_embeddings = torch.nn.functional.pad(\n                            relative_embeddings, convert_pad_shape([[0, 0], \n                            [pad_length, pad_length], [0, 0]]))\n        else:\n            padded_relative_embeddings = relative_embeddings\n        used_relative_embeddings = padded_relative_embeddings[:,\n                                   slice_start_position:slice_end_position]\n        return used_relative_embeddings\n\n    def _relative_position_to_absolute_position(self, x):\n        
batch, heads, length, _ = x.size()\n        x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))\n        x_flat = x.view([batch, heads, length * 2 * length])\n        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))\n        x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]\n        return x_final\n\n    def _absolute_position_to_relative_position(self, x):\n        batch, heads, length, _ = x.size()\n        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))\n        x_flat = x.view([batch, heads, length**2 + length*(length - 1)])\n        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))\n        x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]\n        return x_final\n\n    def _attention_bias_proximal(self, length):\n        r = torch.arange(length, dtype=torch.float32)\n        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)\n        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)\n\n\nclass FFN(BaseModule):\n    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, \n                 p_dropout=0.0):\n        super(FFN, self).__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.filter_channels = filter_channels\n        self.kernel_size = kernel_size\n        self.p_dropout = p_dropout\n\n        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, \n                                      padding=kernel_size//2)\n        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, \n                                      padding=kernel_size//2)\n        self.drop = torch.nn.Dropout(p_dropout)\n\n    def forward(self, x, x_mask):\n        x = self.conv_1(x * x_mask)\n        x = torch.relu(x)\n        x = self.drop(x)\n        x 
= self.conv_2(x * x_mask)\n        return x * x_mask\n\n\nclass Encoder(BaseModule):\n    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, \n                 kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):\n        super(Encoder, self).__init__()\n        self.hidden_channels = hidden_channels\n        self.filter_channels = filter_channels\n        self.n_heads = n_heads\n        self.n_layers = n_layers\n        self.kernel_size = kernel_size\n        self.p_dropout = p_dropout\n        self.window_size = window_size\n\n        self.drop = torch.nn.Dropout(p_dropout)\n        self.attn_layers = torch.nn.ModuleList()\n        self.norm_layers_1 = torch.nn.ModuleList()\n        self.ffn_layers = torch.nn.ModuleList()\n        self.norm_layers_2 = torch.nn.ModuleList()\n        for _ in range(self.n_layers):\n            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,\n                                    n_heads, window_size=window_size, p_dropout=p_dropout))\n            self.norm_layers_1.append(LayerNorm(hidden_channels))\n            self.ffn_layers.append(FFN(hidden_channels, hidden_channels,\n                                       filter_channels, kernel_size, p_dropout=p_dropout))\n            self.norm_layers_2.append(LayerNorm(hidden_channels))\n\n    def forward(self, x, x_mask):\n        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)\n        for i in range(self.n_layers):\n            x = x * x_mask\n            y = self.attn_layers[i](x, x, attn_mask)\n            y = self.drop(y)\n            x = self.norm_layers_1[i](x + y)\n            y = self.ffn_layers[i](x, x_mask)\n            y = self.drop(y)\n            x = self.norm_layers_2[i](x + y)\n        x = x * x_mask\n        return x\n\n\nclass TextEncoder(BaseModule):\n    def __init__(self, n_vecs, n_mels, n_embs,\n                 n_channels,\n                 filter_channels,\n                 n_heads=2,\n                 
n_layers=6,\n                 kernel_size=3,\n                 p_dropout=0.1,\n                 window_size=4):\n        super(TextEncoder, self).__init__()\n        self.n_vecs = n_vecs\n        self.n_mels = n_mels\n        self.n_embs = n_embs\n        self.n_channels = n_channels\n        self.filter_channels = filter_channels\n        self.n_heads = n_heads\n        self.n_layers = n_layers\n        self.kernel_size = kernel_size\n        self.p_dropout = p_dropout\n        self.window_size = window_size\n\n        self.prenet = ConvReluNorm(n_vecs,\n                                   n_channels,\n                                   n_channels,\n                                   kernel_size=5,\n                                   n_layers=5,\n                                   p_dropout=0.5)\n\n        self.speaker = SpeakerClassifier(\n            n_channels,\n            256,  # n_spks: 256\n        )\n\n        self.encoder = Encoder(n_channels + n_embs + n_embs,\n                               filter_channels,\n                               n_heads,\n                               n_layers,\n                               kernel_size,\n                               p_dropout,\n                               window_size=window_size)\n\n        self.proj_m = torch.nn.Conv1d(n_channels + n_embs + n_embs, n_mels, 1)\n\n    def forward(self, x_lengths, x, pit, spk, training=False):\n        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)\n        # IN\n        x = self.prenet(x, x_mask)\n        if training:\n            r = self.speaker(x)\n        else:\n            r = None\n        # pitch + speaker\n        spk = spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])\n        x = torch.cat([x, pit], dim=1)\n        x = torch.cat([x, spk], dim=1)\n        x = self.encoder(x, x_mask)\n        mu = self.proj_m(x) * x_mask\n        return mu, x_mask, r\n\n    def fine_tune(self):\n        for p in self.prenet.parameters():\n            
p.requires_grad = False\n        for p in self.speaker.parameters():\n            p.requires_grad = False\n"
  },
  {
    "path": "grad/model.py",
    "content": "import math\nimport torch\n\nfrom grad.ssim import SSIM\nfrom grad.base import BaseModule\nfrom grad.encoder import TextEncoder\nfrom grad.diffusion import Diffusion\nfrom grad.utils import f0_to_coarse, rand_ids_segments, slice_segments\n\nSpeakerLoss = torch.nn.CosineEmbeddingLoss()\nSsimLoss = SSIM()\n\nclass GradTTS(BaseModule):\n    def __init__(self, n_mels, n_vecs, n_pits, n_spks, n_embs, \n                 n_enc_channels, filter_channels, \n                 dec_dim, beta_min, beta_max, pe_scale):\n        super(GradTTS, self).__init__()\n        # common\n        self.n_mels = n_mels\n        self.n_vecs = n_vecs\n        self.n_spks = n_spks\n        self.n_embs = n_embs\n        # encoder\n        self.n_enc_channels = n_enc_channels\n        self.filter_channels = filter_channels\n        # decoder\n        self.dec_dim = dec_dim\n        self.beta_min = beta_min\n        self.beta_max = beta_max\n        self.pe_scale = pe_scale\n\n        self.pit_emb = torch.nn.Embedding(n_pits, n_embs)\n        self.spk_emb = torch.nn.Linear(n_spks, n_embs)\n        self.encoder = TextEncoder(n_vecs,\n                                   n_mels,\n                                   n_embs,\n                                   n_enc_channels,\n                                   filter_channels)\n        self.decoder = Diffusion(n_mels, dec_dim, n_embs, beta_min, beta_max, pe_scale)\n\n    def fine_tune(self):\n        for p in self.pit_emb.parameters():\n            p.requires_grad = False\n        for p in self.spk_emb.parameters():\n            p.requires_grad = False\n        self.encoder.fine_tune()\n\n    @torch.no_grad()\n    def forward(self, lengths, vec, pit, spk, n_timesteps, temperature=1.0, stoc=False):\n        \"\"\"\n        Generates mel-spectrogram from vec. Returns:\n            1. encoder outputs\n            2. 
decoder outputs\n\n        Args:\n            lengths (torch.Tensor): lengths of feature sequences in batch.\n            vec (torch.Tensor): batch of speech vec\n            pit (torch.Tensor): batch of speech pit\n            spk (torch.Tensor): batch of speaker vectors\n            \n            n_timesteps (int): number of steps to use for reverse diffusion in decoder.\n            temperature (float, optional): controls variance of terminal distribution.\n            stoc (bool, optional): flag that adds stochastic term to the decoder sampler.\n                Usually does not improve synthesis quality.\n        \"\"\"\n        lengths, vec, pit, spk = self.relocate_input([lengths, vec, pit, spk])\n\n        # Get pitch embedding\n        pit = self.pit_emb(f0_to_coarse(pit))\n\n        # Get speaker embedding\n        spk = self.spk_emb(spk)\n\n        # Transpose\n        vec = torch.transpose(vec, 1, -1)\n        pit = torch.transpose(pit, 1, -1)\n\n        # Get encoder_outputs `mu_x`\n        mu_x, mask_x, _ = self.encoder(lengths, vec, pit, spk)\n        encoder_outputs = mu_x\n\n        # Sample latent representation from terminal distribution N(mu_x, I)\n        z = mu_x + torch.randn_like(mu_x, device=mu_x.device) / temperature\n        # Generate sample by performing reverse dynamics\n        decoder_outputs = self.decoder(spk, z, mask_x, mu_x, n_timesteps, stoc)\n        encoder_outputs = encoder_outputs + torch.randn_like(encoder_outputs)\n        return encoder_outputs, decoder_outputs\n\n    def compute_loss(self, lengths, vec, pit, spk, mel, out_size, skip_diff=False):\n        \"\"\"\n        Computes 4 losses:\n            1. prior loss: loss between mel-spectrogram and encoder outputs.\n            2. 
diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.\n            3. mel loss: SSIM loss between encoder outputs and mel-spectrogram.\n            4. speaker loss: cosine embedding loss of the GRL speaker classifier.\n            \n        Args:\n            lengths (torch.Tensor): lengths of feature sequences in batch.\n            vec (torch.Tensor): batch of speech vec\n            pit (torch.Tensor): batch of speech pit\n            spk (torch.Tensor): batch of speaker vectors\n            mel (torch.Tensor): batch of corresponding mel-spectrogram\n\n            out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.\n                Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.\n            skip_diff (bool, optional): if True, skip the decoder and return a zero diffusion loss.\n        \"\"\"\n        lengths, vec, pit, spk, mel = self.relocate_input([lengths, vec, pit, spk, mel])\n\n        # Get pitch embedding\n        pit = self.pit_emb(f0_to_coarse(pit))\n\n        # Get speaker embedding\n        spk_64 = self.spk_emb(spk)\n\n        # Transpose\n        vec = torch.transpose(vec, 1, -1)\n        pit = torch.transpose(pit, 1, -1)\n\n        # Get encoder_outputs `mu_x`\n        mu_x, mask_x, spk_preds = self.encoder(lengths, vec, pit, spk_64, training=True)\n\n        # Compute loss between aligned encoder outputs and mel-spectrogram\n        prior_loss = torch.sum(0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * mask_x)\n        prior_loss = prior_loss / (torch.sum(mask_x) * self.n_mels)\n\n        # Mel ssim\n        mel_loss = SsimLoss(mu_x, mel, mask_x)\n\n        # Compute loss of speaker for GRL\n        spk_loss = SpeakerLoss(spk, spk_preds, torch.Tensor(spk_preds.size(0))\n                               .to(spk.device).fill_(1.0))\n\n        # Compute loss of score-based decoder\n        if skip_diff:\n            diff_loss = prior_loss.clone()\n            diff_loss.fill_(0)\n        else:\n            # Cut a small segment of mel-spectrogram in order to increase batch size\n            if not isinstance(out_size, type(None)):\n                ids = rand_ids_segments(lengths, out_size)\n  
              mel = slice_segments(mel, ids, out_size)\n\n                mask_y = slice_segments(mask_x, ids, out_size)\n                mu_y = slice_segments(mu_x, ids, out_size)\n            else:\n                # No cropping: train the decoder on the full-length features\n                mask_y = mask_x\n                mu_y = mu_x\n            mu_y = mu_y + torch.randn_like(mu_y)\n\n            diff_loss, xt = self.decoder.compute_loss(\n                spk_64, mel, mask_y, mu_y)\n\n        return prior_loss, diff_loss, mel_loss, spk_loss\n"
  },
  {
    "path": "grad/reversal.py",
    "content": "# Adapted from https://github.com/ubisoft/ubisoft-laforge-daft-exprt Apache License Version 2.0\n# Unsupervised Domain Adaptation by Backpropagation\n\nimport torch\nimport torch.nn as nn\n\nfrom torch.autograd import Function\nfrom torch.nn.utils import weight_norm\n\n\nclass GradientReversalFunction(Function):\n    @staticmethod\n    def forward(ctx, x, lambda_):\n        ctx.lambda_ = lambda_\n        return x.clone()\n\n    @staticmethod\n    def backward(ctx, grads):\n        lambda_ = ctx.lambda_\n        lambda_ = grads.new_tensor(lambda_)\n        dx = -lambda_ * grads\n        return dx, None\n\n\nclass GradientReversal(torch.nn.Module):\n    ''' Gradient Reversal Layer\n            Y. Ganin, V. Lempitsky,\n            \"Unsupervised Domain Adaptation by Backpropagation\",\n            in ICML, 2015.\n        Forward pass is the identity function\n        In the backward pass, upstream gradients are multiplied by -lambda (i.e. gradients are reversed)\n    '''\n\n    def __init__(self, lambda_reversal=1):\n        super(GradientReversal, self).__init__()\n        self.lambda_ = lambda_reversal\n\n    def forward(self, x):\n        return GradientReversalFunction.apply(x, self.lambda_)\n\n\nclass SpeakerClassifier(nn.Module):\n\n    def __init__(self, idim, odim):\n        super(SpeakerClassifier, self).__init__()\n        self.classifier = nn.Sequential(\n            GradientReversal(lambda_reversal=1),\n            weight_norm(nn.Conv1d(idim, 1024, kernel_size=5, padding=2)),\n            nn.ReLU(),\n            weight_norm(nn.Conv1d(1024, 1024, kernel_size=5, padding=2)),\n            nn.ReLU(),\n            weight_norm(nn.Conv1d(1024, odim, kernel_size=5, padding=2))\n        )\n\n    def forward(self, x):\n        ''' Forward function of Speaker Classifier:\n            x = (B, idim, len)\n        '''\n        # pass through classifier\n        outputs = self.classifier(x)  # (B, odim, len)\n        outputs = torch.mean(outputs, 
dim=-1)  # average over time -> (B, odim)\n        return outputs\n"
  },
  {
    "path": "grad/solver.py",
    "content": "import torch\r\n\r\n\r\nclass NoiseScheduleVP:\r\n\r\n    def __init__(self, beta_min=0.05, beta_max=20):\r\n        self.beta_min = beta_min\r\n        self.beta_max = beta_max\r\n        self.T = 1.\r\n    \r\n    def get_noise(self, t, beta_init, beta_term, cumulative=False):\r\n        if cumulative:\r\n            noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)\r\n        else:\r\n            noise = beta_init + (beta_term - beta_init)*t\r\n        return noise\r\n\r\n    def marginal_log_mean_coeff(self, t):\r\n        return -0.25 * t**2 * (self.beta_max -\r\n                               self.beta_min) - 0.5 * t * self.beta_min\r\n\r\n    def marginal_std(self, t):\r\n        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))\r\n\r\n    def marginal_lambda(self, t):\r\n        log_mean_coeff = self.marginal_log_mean_coeff(t)\r\n        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))\r\n        return log_mean_coeff - log_std\r\n\r\n    def inverse_lambda(self, lamb):\r\n        tmp = 2. * (self.beta_max - self.beta_min) * torch.logaddexp(\r\n            -2. 
* lamb,\r\n            torch.zeros((1, )).to(lamb))\r\n        Delta = self.beta_min**2 + tmp\r\n        return tmp / (torch.sqrt(Delta) + self.beta_min) / (self.beta_max -\r\n                                                            self.beta_min)\r\n\r\n    def get_time_steps(self, t_T, t_0, N):\r\n        lambda_T = self.marginal_lambda(torch.tensor(t_T))\r\n        lambda_0 = self.marginal_lambda(torch.tensor(t_0))\r\n        logSNR_steps = torch.linspace(lambda_T, lambda_0, N + 1)\r\n        return self.inverse_lambda(logSNR_steps)\r\n    \r\n    @torch.no_grad()\r\n    def reverse_diffusion(self, estimator, spk, z, mask, mu, n_timesteps, stoc):\r\n        print(\"use dpm-solver reverse\")\r\n        xt = z * mask\r\n        yt = xt - mu\r\n        T = 1\r\n        eps = 1e-3\r\n        time = self.get_time_steps(T, eps, n_timesteps)\r\n        for i in range(n_timesteps):\r\n            s = torch.ones((xt.shape[0], )).to(xt.device) * time[i]\r\n            t = torch.ones((xt.shape[0], )).to(xt.device) * time[i + 1]\r\n\r\n            lambda_s = self.marginal_lambda(s)\r\n            lambda_t = self.marginal_lambda(t)\r\n            h = lambda_t - lambda_s\r\n\r\n            log_alpha_s = self.marginal_log_mean_coeff(s)\r\n            log_alpha_t = self.marginal_log_mean_coeff(t)\r\n\r\n            sigma_t = self.marginal_std(t)\r\n            phi_1 = torch.expm1(h)\r\n\r\n            noise_s = estimator(spk, yt + mu, mask, mu, s)\r\n            lt = 1 - torch.exp(-self.get_noise(s, self.beta_min, self.beta_max, cumulative=True))\r\n            a = torch.exp(log_alpha_t - log_alpha_s)\r\n            b = sigma_t * phi_1 * torch.sqrt(lt)\r\n            yt = a * yt + (b * noise_s)\r\n        xt = yt + mu\r\n        return xt\r\n\r\n\r\nclass MaxLikelihood:\r\n\r\n    def __init__(self, beta_min=0.05, beta_max=20):\r\n        self.beta_min = beta_min\r\n        self.beta_max = beta_max\r\n   \r\n    def get_noise(self, t, beta_init, beta_term, 
cumulative=False):\r\n        if cumulative:\r\n            noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)\r\n        else:\r\n            noise = beta_init + (beta_term - beta_init)*t\r\n        return noise\r\n    \r\n    def get_gamma(self, s, t, beta_init, beta_term):\r\n        gamma = beta_init*(t-s) + 0.5*(beta_term-beta_init)*(t**2-s**2)\r\n        gamma = torch.exp(-0.5*gamma)\r\n        return gamma\r\n\r\n    def get_mu(self, s, t):\r\n        gamma_0_s = self.get_gamma(0, s, self.beta_min, self.beta_max)\r\n        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)\r\n        gamma_s_t = self.get_gamma(s, t, self.beta_min, self.beta_max)\r\n        mu = gamma_s_t * ((1-gamma_0_s**2) / (1-gamma_0_t**2))\r\n        return mu        \r\n\r\n    def get_nu(self, s, t):\r\n        gamma_0_s = self.get_gamma(0, s, self.beta_min, self.beta_max)\r\n        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)\r\n        gamma_s_t = self.get_gamma(s, t, self.beta_min, self.beta_max)\r\n        nu = gamma_0_s * ((1-gamma_s_t**2) / (1-gamma_0_t**2))\r\n        return nu\r\n\r\n    def get_sigma(self, s, t):\r\n        gamma_0_s = self.get_gamma(0, s, self.beta_min, self.beta_max)\r\n        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)\r\n        gamma_s_t = self.get_gamma(s, t, self.beta_min, self.beta_max)\r\n        sigma = torch.sqrt(((1 - gamma_0_s**2) * (1 - gamma_s_t**2)) / (1 - gamma_0_t**2))\r\n        return sigma        \r\n\r\n    def get_kappa(self, t, h, noise):\r\n        nu = self.get_nu(t-h, t)\r\n        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)\r\n        kappa = (nu*(1-gamma_0_t**2)/(gamma_0_t*noise*h) - 1)\r\n        return kappa\r\n\r\n    def get_omega(self, t, h, noise):\r\n        mu = self.get_mu(t-h, t)\r\n        kappa = self.get_kappa(t, h, noise)\r\n        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)\r\n        omega = (mu-1)/(noise*h) + 
(1+kappa)/(1-gamma_0_t**2) - 0.5\r\n        return omega \r\n\r\n    @torch.no_grad()\r\n    def reverse_diffusion(self, estimator, spk, z, mask, mu, n_timesteps, stoc=False):\r\n        print(\"use MaxLikelihood reverse\")\r\n        h = 1.0 / n_timesteps\r\n        xt = z * mask\r\n        for i in range(n_timesteps):\r\n            t = (1.0 - i*h) * torch.ones(z.shape[0], dtype=z.dtype,\r\n                                                 device=z.device)            \r\n            time = t.unsqueeze(-1).unsqueeze(-1)\r\n            noise_t = self.get_noise(time, self.beta_min, self.beta_max,\r\n                                cumulative=False)\r\n\r\n            kappa_t_h = self.get_kappa(t, h, noise_t) \r\n            omega_t_h = self.get_omega(t, h, noise_t)\r\n            sigma_t_h = self.get_sigma(t-h, t)\r\n \r\n            es = estimator(spk, xt, mask, mu, t)\r\n\r\n            dxt = ((0.5+omega_t_h)*(xt - mu) + (1+kappa_t_h) * es)\r\n            dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,\r\n                                   requires_grad=False)\r\n            dxt_stoc = dxt_stoc * sigma_t_h\r\n\r\n            dxt = dxt * noise_t * h + dxt_stoc\r\n            xt = (xt + dxt) * mask\r\n        return xt\r\n\r\n\r\nclass GradRaw:\r\n\r\n    def __init__(self, beta_min=0.05, beta_max=20):\r\n        self.beta_min = beta_min\r\n        self.beta_max = beta_max\r\n\r\n    def get_noise(self, t, beta_init, beta_term, cumulative=False):\r\n        if cumulative:\r\n            noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)\r\n        else:\r\n            noise = beta_init + (beta_term - beta_init)*t\r\n        return noise\r\n    \r\n    @torch.no_grad()\r\n    def reverse_diffusion(self, estimator, spk, z, mask, mu, n_timesteps, stoc=False):\r\n        print(\"use grad-raw reverse\")\r\n        h = 1.0 / n_timesteps\r\n        xt = z * mask\r\n        for i in range(n_timesteps):\r\n            t = (1.0 - (i + 0.5)*h) * \\\r\n     
           torch.ones(z.shape[0], dtype=z.dtype, device=z.device)\r\n            time = t.unsqueeze(-1).unsqueeze(-1)\r\n            noise_t = self.get_noise(time, self.beta_min, self.beta_max,\r\n                                cumulative=False)\r\n            if stoc:  # adds stochastic term\r\n                dxt_det = 0.5 * (mu - xt) - estimator(spk, xt, mask, mu, t)\r\n                dxt_det = dxt_det * noise_t * h\r\n                dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,\r\n                                       requires_grad=False)\r\n                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)\r\n                dxt = dxt_det + dxt_stoc\r\n            else:\r\n                dxt = 0.5 * (mu - xt - estimator(spk, xt, mask, mu, t))\r\n                dxt = dxt * noise_t * h\r\n            xt = (xt - dxt) * mask\r\n        return xt\r\n"
  },
  {
    "path": "grad/ssim.py",
    "content": "\"\"\"\nAdapted from https://github.com/Po-Hsun-Su/pytorch-ssim\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom math import exp\n\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])\n    return gauss / gauss.sum()\n\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())\n    return window\n\n\ndef _ssim(img1, img2, window, window_size, channel, size_average=True):\n    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2\n\n    C1 = 0.01 ** 2\n    C2 = 0.03 ** 2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1)\n\n\nclass SSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True):\n        super(SSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.channel = 1\n        self.window = create_window(window_size, self.channel)\n\n    def forward(self, fake, real, mask, bias=6.0):\n        fake = fake[:, None, :, :] + bias  # [B, 1, T, 80]\n        real = real[:, None, :, :] + bias 
 # [B, 1, T, 80]\n        self.window = self.window.to(dtype=fake.dtype, device=fake.device)\n        loss = 1 - _ssim(fake, real, self.window, self.window_size, self.channel, self.size_average)\n        loss = (loss * mask).sum() / mask.sum()\n        return loss\n"
  },
  {
    "path": "grad/utils.py",
    "content": "import torch\nimport numpy as np\nimport inspect\n\n\ndef sequence_mask(length, max_length=None):\n    if max_length is None:\n        max_length = length.max()\n    x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)\n    return x.unsqueeze(0) < length.unsqueeze(1)\n\n\ndef fix_len_compatibility(length, num_downsamplings_in_unet=2):\n    while True:\n        if length % (2**num_downsamplings_in_unet) == 0:\n            return length\n        length += 1\n\n\ndef convert_pad_shape(pad_shape):\n    l = pad_shape[::-1]\n    pad_shape = [item for sublist in l for item in sublist]\n    return pad_shape\n\n\ndef generate_path(duration, mask):\n    device = duration.device\n\n    b, t_x, t_y = mask.shape\n    cum_duration = torch.cumsum(duration, 1)\n    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)\n\n    cum_duration_flat = cum_duration.view(b * t_x)\n    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)\n    path = path.view(b, t_x, t_y)\n    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], \n                                          [1, 0], [0, 0]]))[:, :-1]\n    path = path * mask\n    return path\n\n\ndef duration_loss(logw, logw_, lengths):\n    loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)\n    return loss\n\n\nf0_bin = 256\nf0_max = 1100.0\nf0_min = 50.0\nf0_mel_min = 1127 * np.log(1 + f0_min / 700)\nf0_mel_max = 1127 * np.log(1 + f0_max / 700)\n\n\ndef f0_to_coarse(f0):\n    is_torch = isinstance(f0, torch.Tensor)\n    f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * \\\n        np.log(1 + f0 / 700)\n    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * \\\n        (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1\n\n    f0_mel[f0_mel <= 1] = 1\n    f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1\n    f0_coarse = (\n        f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int64)\n    assert f0_coarse.max() <= 255 and f0_coarse.min(\n    ) 
>= 1, (f0_coarse.max(), f0_coarse.min())\n    return f0_coarse\n\n\ndef rand_ids_segments(lengths, segment_size=200):\n    b = lengths.shape[0]\n    ids_str_max = lengths - segment_size\n    ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(dtype=torch.long)\n    return ids_str\n\n\ndef slice_segments(x, ids_str, segment_size=200):\n    ret = torch.zeros_like(x[:, :, :segment_size])\n    for i in range(x.size(0)):\n        idx_str = ids_str[i]\n        idx_end = idx_str + segment_size\n        ret[i] = x[i, :, idx_str:idx_end]\n    return ret\n\n\ndef retrieve_name(var):\n    for fi in reversed(inspect.stack()):\n        names = [var_name for var_name,\n                 var_val in fi.frame.f_locals.items() if var_val is var]\n        if len(names) > 0:\n            return names[0]\n\n\nDebug_Enable = True\n\n\ndef debug_shapes(var):\n    if Debug_Enable:\n        print(retrieve_name(var), var.shape)\n"
  },
  {
    "path": "grad_extend/data.py",
    "content": "import os\nimport random\nimport numpy as np\n\nimport torch\n\nfrom grad.utils import fix_len_compatibility\nfrom grad_extend.utils import parse_filelist\n\n\nclass TextMelSpeakerDataset(torch.utils.data.Dataset):\n    def __init__(self, filelist_path):\n        super().__init__()\n        self.filelist = parse_filelist(filelist_path, split_char='|')\n        self._filter()\n        print(f'----------{len(self.filelist)}----------')\n\n    def _filter(self):\n        items_new = []\n        # segment = 200\n        items_min = 250  # 10ms * 250 = 2.5 S\n        items_max = 500  # 10ms * 500 = 5.0 S\n        for mel, vec, pit, spk in self.filelist:\n            if not os.path.isfile(mel):\n                continue\n            if not os.path.isfile(vec):\n                continue\n            if not os.path.isfile(pit):\n                continue\n            if not os.path.isfile(spk):\n                continue\n            temp = np.load(pit)\n            usel = int(temp.shape[0] - 1)  # useful length\n            if (usel < items_min):\n                continue\n            if (usel >= items_max):\n                usel = items_max\n            items_new.append([mel, vec, pit, spk, usel])\n        self.filelist = items_new\n\n    def get_triplet(self, item):\n        # print(item)\n        mel = item[0]\n        vec = item[1]\n        pit = item[2]\n        spk = item[3]\n        use = item[4]\n\n        mel = torch.load(mel)\n        vec = np.load(vec)\n        vec = np.repeat(vec, 2, 0)  # 320 VEC -> 160 * 2\n        pit = np.load(pit)\n        spk = np.load(spk)\n\n        vec = torch.FloatTensor(vec)\n        pit = torch.FloatTensor(pit)\n        spk = torch.FloatTensor(spk)\n\n        vec = vec + torch.randn_like(vec)  # Perturbation\n\n        len_vec = vec.size()[0] - 2  # safety margin\n        len_pit = pit.size()[0]\n        len_min = min(len_pit, len_vec)\n\n        mel = mel[:, :len_min]\n        vec = vec[:len_min, :]\n        pit = 
pit[:len_min]\n\n        if len_min > use:\n            max_frame_start = vec.size(0) - use - 1\n            frame_start = random.randint(0, max_frame_start)\n            frame_end = frame_start + use\n\n            mel = mel[:, frame_start:frame_end]\n            vec = vec[frame_start:frame_end, :]\n            pit = pit[frame_start:frame_end]\n        # print(mel.shape)\n        # print(vec.shape)\n        # print(pit.shape)\n        # print(spk.shape)\n        return (mel, vec, pit, spk)\n\n    def __getitem__(self, index):\n        mel, vec, pit, spk = self.get_triplet(self.filelist[index])\n        item = {'mel': mel, 'vec': vec, 'pit': pit, 'spk': spk}\n        return item\n\n    def __len__(self):\n        return len(self.filelist)\n\n    def sample_test_batch(self, size):\n        idx = np.random.choice(range(len(self)), size=size, replace=False)\n        test_batch = []\n        for index in idx:\n            test_batch.append(self.__getitem__(index))\n        return test_batch\n\n\nclass TextMelSpeakerBatchCollate(object):\n    # mel: [freq, length]\n    # vec: [len, 256]\n    # pit: [len]\n    # spk: [256]\n    def __call__(self, batch):\n        B = len(batch)\n        mel_max_length = max([item['mel'].shape[-1] for item in batch])\n        max_length = fix_len_compatibility(mel_max_length)\n\n        d_mel = batch[0]['mel'].shape[0]\n        d_vec = batch[0]['vec'].shape[1]\n        d_spk = batch[0]['spk'].shape[0]\n        # print(\"d_mel\", d_mel)\n        # print(\"d_vec\", d_vec)\n        # print(\"d_spk\", d_spk)\n        mel = torch.zeros((B, d_mel, max_length), dtype=torch.float32)\n        vec = torch.zeros((B, max_length, d_vec), dtype=torch.float32)\n        pit = torch.zeros((B, max_length), dtype=torch.float32)\n        spk = torch.zeros((B, d_spk), dtype=torch.float32)\n        lengths = torch.LongTensor(B)\n\n        for i, item in enumerate(batch):\n            y_, x_, p_, s_ = item['mel'], item['vec'], item['pit'], item['spk']\n\n       
     mel[i, :, :y_.shape[1]] = y_\n            vec[i, :x_.shape[0], :] = x_\n            pit[i, :p_.shape[0]] = p_\n            spk[i] = s_\n\n            lengths[i] = y_.shape[1]\n        # print(\"lengths\", lengths.shape)\n        # print(\"vec\", vec.shape)\n        # print(\"pit\", pit.shape)\n        # print(\"spk\", spk.shape)\n        # print(\"mel\", mel.shape)\n        return {'lengths': lengths, 'vec': vec, 'pit': pit, 'spk': spk, 'mel': mel}\n"
  },
  {
    "path": "grad_extend/train.py",
    "content": "import os\nimport torch\nimport numpy as np\n\nfrom torch.utils.data import DataLoader\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom tqdm import tqdm\nfrom grad_extend.data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate\nfrom grad_extend.utils import plot_tensor, save_plot, load_model, print_error\nfrom grad.utils import fix_len_compatibility\nfrom grad.model import GradTTS\n\n\n# 200 frames\nout_size = fix_len_compatibility(200)\n\n\ndef train(hps, chkpt_path=None):\n\n    print('Initializing logger...')\n    logger = SummaryWriter(log_dir=hps.train.log_dir)\n\n    print('Initializing data loaders...')\n    train_dataset = TextMelSpeakerDataset(hps.train.train_files)\n    batch_collate = TextMelSpeakerBatchCollate()\n    loader = DataLoader(dataset=train_dataset,\n                        batch_size=hps.train.batch_size,\n                        collate_fn=batch_collate,\n                        drop_last=True,\n                        num_workers=8,\n                        shuffle=True)\n    test_dataset = TextMelSpeakerDataset(hps.train.valid_files)\n\n    print('Initializing model...')\n    model = GradTTS(hps.grad.n_mels, hps.grad.n_vecs, hps.grad.n_pits, hps.grad.n_spks, hps.grad.n_embs,\n                    hps.grad.n_enc_channels, hps.grad.filter_channels,\n                    hps.grad.dec_dim, hps.grad.beta_min, hps.grad.beta_max, hps.grad.pe_scale).cuda()\n    print('Number of encoder parameters = %.2fm' % (model.encoder.nparams/1e6))\n    print('Number of decoder parameters = %.2fm' % (model.decoder.nparams/1e6))\n\n    # Load Pretrain\n    if os.path.isfile(hps.train.pretrain):\n        print(\"Start from Grad_SVC pretrain model: %s\" % hps.train.pretrain)\n        checkpoint = torch.load(hps.train.pretrain, map_location='cpu')\n        load_model(model, checkpoint['model'])\n        hps.train.learning_rate = 2e-5\n        # fine_tune\n        model.fine_tune()\n    else:\n        print_error(10 * '~' + \"No 
Pretrain Model\" + 10 * '~')\n\n    print('Initializing optimizer...')\n    optim = torch.optim.Adam(params=model.parameters(), lr=hps.train.learning_rate)\n\n    initepoch = 1\n    iteration = 0\n\n    # Load Continue\n    if chkpt_path is not None:\n        print(\"Resuming from checkpoint: %s\" % chkpt_path)\n        checkpoint = torch.load(chkpt_path, map_location='cpu')\n        model.load_state_dict(checkpoint['model'])\n        optim.load_state_dict(checkpoint['optim'])\n        initepoch = checkpoint['epoch']\n        iteration = checkpoint['steps']\n\n    print('Logging test batch...')\n    test_batch = test_dataset.sample_test_batch(size=hps.train.test_size)\n    for i, item in enumerate(test_batch):\n        mel = item['mel']\n        logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel.squeeze()),\n                         global_step=0, dataformats='HWC')\n        save_plot(mel.squeeze(), f'{hps.train.log_dir}/original_{i}.png')\n\n    print('Start training...')\n    skip_diff_train = True\n    if initepoch >= hps.train.fast_epochs:\n        skip_diff_train = False\n    for epoch in range(initepoch, hps.train.full_epochs + 1):\n\n        if epoch % hps.train.test_step == 0:\n            model.eval()\n            print('Synthesis...')\n\n            with torch.no_grad():\n                for i, item in enumerate(test_batch):\n                    l_vec = item['vec'].shape[0]\n                    d_vec = item['vec'].shape[1]\n\n                    lengths_fix = fix_len_compatibility(l_vec)\n                    lengths = torch.LongTensor([l_vec]).cuda()\n\n                    vec = torch.zeros((1, lengths_fix, d_vec), dtype=torch.float32).cuda()\n                    pit = torch.zeros((1, lengths_fix), dtype=torch.float32).cuda()\n                    spk = item['spk'].to(torch.float32).unsqueeze(0).cuda()\n                    vec[0, :l_vec, :] = item['vec']\n                    pit[0, :l_vec] = item['pit']\n\n                    y_enc, y_dec = 
model(lengths, vec, pit, spk, n_timesteps=50)\n\n                    logger.add_image(f'image_{i}/generated_enc',\n                                     plot_tensor(y_enc.squeeze().cpu()),\n                                     global_step=iteration, dataformats='HWC')\n                    logger.add_image(f'image_{i}/generated_dec',\n                                     plot_tensor(y_dec.squeeze().cpu()),\n                                     global_step=iteration, dataformats='HWC')\n                    save_plot(y_enc.squeeze().cpu(), \n                              f'{hps.train.log_dir}/generated_enc_{i}.png')\n                    save_plot(y_dec.squeeze().cpu(), \n                              f'{hps.train.log_dir}/generated_dec_{i}.png')\n\n        model.train()\n\n        prior_losses = []\n        diff_losses = []\n        mel_losses = []\n        spk_losses = []\n        with tqdm(loader, total=len(train_dataset)//hps.train.batch_size) as progress_bar:\n            for batch in progress_bar:\n                model.zero_grad()\n\n                lengths = batch['lengths'].cuda()\n                vec = batch['vec'].cuda()\n                pit = batch['pit'].cuda()\n                spk = batch['spk'].cuda()\n                mel = batch['mel'].cuda()\n\n                prior_loss, diff_loss, mel_loss, spk_loss = model.compute_loss(\n                    lengths, vec, pit, spk,\n                    mel, out_size=out_size,\n                    skip_diff=skip_diff_train)\n                loss = sum([prior_loss, diff_loss, mel_loss, spk_loss])\n                loss.backward()\n\n                enc_grad_norm = torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), \n                                                            max_norm=1)\n                dec_grad_norm = torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), \n                                                            max_norm=1)\n                optim.step()\n\n                
logger.add_scalar('training/mel_loss', mel_loss,\n                                global_step=iteration)\n                logger.add_scalar('training/prior_loss', prior_loss,\n                                global_step=iteration)\n                logger.add_scalar('training/diffusion_loss', diff_loss,\n                                global_step=iteration)\n                logger.add_scalar('training/encoder_grad_norm', enc_grad_norm,\n                                global_step=iteration)\n                logger.add_scalar('training/decoder_grad_norm', dec_grad_norm,\n                                global_step=iteration)\n\n                msg = f'Epoch: {epoch}, iteration: {iteration} | ' \n                msg = msg + f'prior_loss: {prior_loss.item():.3f}, '\n                msg = msg + f'diff_loss: {diff_loss.item():.3f}, '\n                msg = msg + f'mel_loss: {mel_loss.item():.3f}, '\n                msg = msg + f'spk_loss: {spk_loss.item():.3f}, '\n                progress_bar.set_description(msg)\n\n                prior_losses.append(prior_loss.item())\n                diff_losses.append(diff_loss.item())\n                mel_losses.append(mel_loss.item())\n                spk_losses.append(spk_loss.item())\n                iteration += 1\n\n        msg = 'Epoch %d: ' % (epoch)\n        msg += '| spk loss = %.3f ' % np.mean(spk_losses)\n        msg += '| mel loss = %.3f ' % np.mean(mel_losses)\n        msg += '| prior loss = %.3f ' % np.mean(prior_losses)\n        msg += '| diffusion loss = %.3f\\n' % np.mean(diff_losses)\n        with open(f'{hps.train.log_dir}/train.log', 'a') as f:\n            f.write(msg)\n        # if (np.mean(prior_losses) < 1.05):\n        #     skip_diff_train = False\n        if epoch > hps.train.fast_epochs:\n            skip_diff_train = False\n        if epoch % hps.train.save_step > 0:\n            continue\n\n        save_path = f\"{hps.train.log_dir}/grad_svc_{epoch}.pt\"\n        torch.save({\n            'model': 
model.state_dict(),\n            'optim': optim.state_dict(),\n            'epoch': epoch,\n            'steps': iteration,\n\n        }, save_path)\n        print(\"Saved checkpoint to: %s\" % save_path)\n"
  },
  {
    "path": "grad_extend/utils.py",
    "content": "import os\nimport glob\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport torch\n\n\ndef parse_filelist(filelist_path, split_char=\"|\"):\n    with open(filelist_path, encoding='utf-8') as f:\n        filepaths_and_text = [line.strip().split(split_char) for line in f]\n    return filepaths_and_text\n\n\ndef load_model(model, saved_state_dict):\n    state_dict = model.state_dict()\n    new_state_dict = {}\n    for k, v in state_dict.items():\n        try:\n            new_state_dict[k] = saved_state_dict[k]\n        except KeyError:\n            print(\"%s is not in the checkpoint\" % k)\n            new_state_dict[k] = v\n    model.load_state_dict(new_state_dict)\n    return model\n\n\ndef latest_checkpoint_path(dir_path, regex=\"grad_svc_*.pt\"):\n    f_list = glob.glob(os.path.join(dir_path, regex))\n    f_list.sort(key=lambda f: int(\"\".join(filter(str.isdigit, f))))\n    x = f_list[-1]\n    return x\n\n\ndef load_checkpoint(logdir, model, num=None):\n    if num is None:\n        model_path = latest_checkpoint_path(logdir, regex=\"grad_svc_*.pt\")\n    else:\n        model_path = os.path.join(logdir, f\"grad_svc_{num}.pt\")\n    print(f'Loading checkpoint {model_path}...')\n    model_dict = torch.load(model_path, map_location=lambda storage, loc: storage)\n    # checkpoints are saved as {'model': ..., 'optim': ..., 'epoch': ..., 'steps': ...}\n    model.load_state_dict(model_dict['model'], strict=False)\n    return model\n\n\ndef save_figure_to_numpy(fig):\n    # RGBA buffer of the Agg canvas; drop alpha to get a [H, W, 3] uint8 array\n    data = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()\n    return data\n\n\ndef plot_tensor(tensor):\n    plt.style.use('default')\n    fig, ax = plt.subplots(figsize=(12, 3))\n    im = ax.imshow(tensor, aspect=\"auto\", origin=\"lower\", interpolation='none')\n    plt.colorbar(im, ax=ax)\n    plt.tight_layout()\n    fig.canvas.draw()\n    data = save_figure_to_numpy(fig)\n    plt.close()\n    return data\n\n\ndef save_plot(tensor, savepath):\n    plt.style.use('default')\n    fig, ax = 
plt.subplots(figsize=(12, 3))\n    im = ax.imshow(tensor, aspect=\"auto\", origin=\"lower\", interpolation='none')\n    plt.colorbar(im, ax=ax)\n    plt.tight_layout()\n    fig.canvas.draw()\n    plt.savefig(savepath)\n    plt.close()\n    return\n\n\ndef print_error(info):\n    print(f\"\\033[31m {info} \\033[0m\")\n"
  },
  {
    "path": "grad_pretrain/README.md",
    "content": "Path for:\n\n    gvc.pretrain.pth"
  },
  {
    "path": "gvc_export.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport torch\nimport argparse\nfrom omegaconf import OmegaConf\nfrom grad.model import GradTTS\n\n\ndef load_model(checkpoint_path, model):\n    assert os.path.isfile(checkpoint_path)\n    checkpoint_dict = torch.load(checkpoint_path, map_location=\"cpu\")\n    saved_state_dict = checkpoint_dict[\"model\"]\n\n    state_dict = model.state_dict()\n    new_state_dict = {}\n    for k, v in state_dict.items():\n        try:\n            new_state_dict[k] = saved_state_dict[k]\n        except:\n            print(\"%s is not in the checkpoint\" % k)\n            new_state_dict[k] = v\n    model.load_state_dict(new_state_dict)\n\n\ndef main(args):\n    hps = OmegaConf.load(args.config)\n\n    print('Initializing Grad-TTS...')\n    model = GradTTS(hps.grad.n_mels, hps.grad.n_vecs, hps.grad.n_pits, hps.grad.n_spks, hps.grad.n_embs,\n                    hps.grad.n_enc_channels, hps.grad.filter_channels,\n                    hps.grad.dec_dim, hps.grad.beta_min, hps.grad.beta_max, hps.grad.pe_scale)\n    print('Number of encoder parameters = %.2fm' % (model.encoder.nparams/1e6))\n    print('Number of decoder parameters = %.2fm' % (model.decoder.nparams/1e6))\n\n    load_model(args.checkpoint_path, model)\n    torch.save({'model': model.state_dict()}, \"gvc.pth\")\n    torch.save({'model': model.state_dict()}, \"gvc.pretrain.pth\")\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-c', '--config', type=str, default='./configs/base.yaml',\n                        help=\"yaml file for config.\")\n    parser.add_argument('-p', '--checkpoint_path', type=str, required=True,\n                        help=\"path of checkpoint pt file for evaluation\")\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "gvc_inference.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport torch\nimport argparse\nimport numpy as np\n\nfrom omegaconf import OmegaConf\nfrom pitch import load_csv_pitch\nfrom spec.inference import print_mel\n\nfrom grad_extend.utils import print_error\nfrom grad.utils import fix_len_compatibility\nfrom grad.model import GradTTS\nfrom bigvgan.model.generator import Generator\nfrom scipy.io.wavfile import write\n\n\ndef load_gvc_model(checkpoint_path, model):\n    assert os.path.isfile(checkpoint_path)\n    checkpoint_dict = torch.load(checkpoint_path, map_location=\"cpu\")\n    saved_state_dict = checkpoint_dict[\"model\"]\n    state_dict = model.state_dict()\n    new_state_dict = {}\n    for k, v in state_dict.items():\n        try:\n            new_state_dict[k] = saved_state_dict[k]\n        except:\n            print(\"%s is not in the checkpoint\" % k)\n            new_state_dict[k] = v\n    model.load_state_dict(new_state_dict)\n    return model\n\n\ndef load_bigv_model(checkpoint_path, model):\n    assert os.path.isfile(checkpoint_path)\n    checkpoint_dict = torch.load(checkpoint_path, map_location=\"cpu\")\n    saved_state_dict = checkpoint_dict[\"model_g\"]\n    state_dict = model.state_dict()\n    new_state_dict = {}\n    for k, v in state_dict.items():\n        try:\n            new_state_dict[k] = saved_state_dict[k]\n        except:\n            print(\"%s is not in the checkpoint\" % k)\n            new_state_dict[k] = v\n    model.load_state_dict(new_state_dict)\n    return model\n\n\n@torch.no_grad()\ndef gvc_main(device, model, _vec, _pit, spk, rature=1.015):\n    l_vec = _vec.shape[0]\n    d_vec = _vec.shape[1]\n    lengths_fix = fix_len_compatibility(l_vec)\n    lengths = torch.LongTensor([l_vec]).to(device)\n    vec = torch.zeros((1, lengths_fix, d_vec), dtype=torch.float32).to(device)\n    pit = torch.zeros((1, lengths_fix), dtype=torch.float32).to(device)\n    vec[0, :l_vec, :] = _vec\n    pit[0, 
:l_vec] = _pit\n    y_enc, y_dec = model(lengths, vec, pit, spk, n_timesteps=20, temperature=rature)\n    y_dec = y_dec.squeeze(0)\n    y_dec = y_dec[:, :l_vec]\n    return y_dec\n\n\ndef main(args):\n\n    if args.vec is None:\n        args.vec = \"gvc_tmp.vec.npy\"\n        print(\n            f\"Auto run : python hubert/inference.py -w {args.wave} -v {args.vec}\")\n        os.system(f\"python hubert/inference.py -w {args.wave} -v {args.vec}\")\n\n    if args.pit is None:\n        args.pit = \"gvc_tmp.pit.csv\"\n        print(\n            f\"Auto run : python pitch/inference.py -w {args.wave} -p {args.pit}\")\n        os.system(f\"python pitch/inference.py -w {args.wave} -p {args.pit}\")\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    hps = OmegaConf.load(args.config)\n\n    print('Initializing Grad-TTS...')\n    model = GradTTS(hps.grad.n_mels, hps.grad.n_vecs, hps.grad.n_pits, hps.grad.n_spks, hps.grad.n_embs,\n                    hps.grad.n_enc_channels, hps.grad.filter_channels,\n                    hps.grad.dec_dim, hps.grad.beta_min, hps.grad.beta_max, hps.grad.pe_scale)\n    print('Number of encoder parameters = %.2fm' % (model.encoder.nparams/1e6))\n    print('Number of decoder parameters = %.2fm' % (model.decoder.nparams/1e6))\n    print_error(f'Temperature: {args.rature}')\n\n    load_gvc_model(args.model, model)\n    model.eval()\n    model.to(device)\n\n    spk = np.load(args.spk)\n    spk = torch.FloatTensor(spk)\n\n    vec = np.load(args.vec)\n    vec = np.repeat(vec, 2, 0)\n    vec = torch.FloatTensor(vec)\n\n    pit = load_csv_pitch(args.pit)\n    pit = np.array(pit)\n    pit = pit * 2 ** (args.shift / 12)\n    pit = torch.FloatTensor(pit)\n\n    len_pit = pit.size()[0]\n    len_vec = vec.size()[0]\n    len_min = min(len_pit, len_vec)\n    pit = pit[:len_min]\n    vec = vec[:len_min, :]\n\n    with torch.no_grad():\n        spk = spk.unsqueeze(0).to(device)\n\n        all_frame = len_min\n        
hop_frame = 8\n        out_chunk = 2400  # 2400 frames * 10ms = 24 S\n        out_index = 0\n        mel = None\n\n        while (out_index < all_frame):\n            if (out_index == 0):  # start frame\n                cut_s = 0\n                cut_s_out = 0\n            else:\n                cut_s = out_index - hop_frame\n                cut_s_out = hop_frame\n\n            if (out_index + out_chunk + hop_frame > all_frame):  # end frame\n                cut_e = all_frame\n                cut_e_out = -1\n            else:\n                cut_e = out_index + out_chunk + hop_frame\n                cut_e_out = -1 * hop_frame\n\n            sub_vec = vec[cut_s:cut_e, :].to(device)\n            sub_pit = pit[cut_s:cut_e].to(device)\n\n            sub_out = gvc_main(device, model, sub_vec, sub_pit, spk, args.rature)\n            sub_out = sub_out[:, cut_s_out:cut_e_out]\n\n            out_index = out_index + out_chunk\n            if mel is None:\n                mel = sub_out\n            else:\n                mel = torch.cat((mel, sub_out), -1)\n            if cut_e == all_frame:\n                break\n\n    print_error(10 * '~' + \"mel has been generated\" + 10 * '~')\n    print_mel(mel, \"gvc_out.mel.png\")\n    del model\n    del hps\n    del spk\n    del vec\n    del sub_vec\n    del sub_pit\n    del sub_out\n\n    hps = OmegaConf.load(args.config_bigv)\n    model = Generator(hps)\n    load_bigv_model(args.model_bigv, model)\n    model.eval()\n    model.to(device)\n\n    len_pit = pit.size()[0]\n    len_mel = mel.size()[1]\n    len_min = min(len_pit, len_mel)\n    pit = pit[:len_min]\n    mel = mel[:, :len_min]\n\n    with torch.no_grad():\n        mel = mel.unsqueeze(0).to(device)\n        pit = pit.unsqueeze(0).to(device)\n        audio = model.inference(mel, pit)\n        audio = audio.cpu().detach().numpy()\n\n        pitwav = model.pitch2wav(pit)\n        pitwav = pitwav.cpu().detach().numpy()\n\n    print_error(10 * '~' + \"wav has been generated\" + 10 * '~')\n    
write(\"gvc_out.wav\", hps.audio.sampling_rate, audio)\n    write(\"gvc_pitch.wav\", hps.audio.sampling_rate, pitwav)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--config', type=str, default='./configs/base.yaml',\n                        help=\"yaml file for config.\")\n    parser.add_argument('--model', type=str, required=True,\n                        help=\"path of model for evaluation\")\n    parser.add_argument('--wave', type=str, required=True,\n                        help=\"Path of raw audio.\")\n    parser.add_argument('--spk', type=str, required=True,\n                        help=\"Path of speaker embedding (.npy).\")\n    parser.add_argument('--vec', type=str,\n                        help=\"Path of hubert vector (.npy).\")\n    parser.add_argument('--pit', type=str,\n                        help=\"Path of pitch csv file.\")\n    parser.add_argument('--shift', type=int, default=0,\n                        help=\"Pitch shift in semitones.\")\n    parser.add_argument('--rature', type=float, default=1.015,\n                        help=\"Sampling temperature for the diffusion decoder.\")\n\n    args = parser.parse_args()\n\n    args.config_bigv = \"./bigvgan/configs/nsf_bigvgan.yaml\"\n    args.model_bigv = \"./bigvgan_pretrain/nsf_bigvgan_pretrain_32K.pth\"\n\n    assert os.path.isfile(args.config)\n    assert os.path.isfile(args.model)\n\n    assert os.path.isfile(args.config_bigv)\n    assert os.path.isfile(args.model_bigv)\n\n    main(args)\n"
  },
  {
    "path": "gvc_trainer.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))\nimport argparse\nimport torch\nimport numpy as np\n\nfrom omegaconf import OmegaConf\nfrom grad_extend.train import train\n\ntorch.backends.cudnn.benchmark = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-c', '--config', type=str, default='./configs/base.yaml',\n                        help=\"yaml file for configuration\")\n    parser.add_argument('-p', '--checkpoint_path', type=str, default=None,\n                        help=\"path of checkpoint pt file to resume training\")\n    args = parser.parse_args()\n\n    assert torch.cuda.is_available()\n    print('Numbers of GPU :', torch.cuda.device_count())\n\n    hps = OmegaConf.load(args.config)\n\n    np.random.seed(hps.train.seed)\n    torch.manual_seed(hps.train.seed)\n    torch.cuda.manual_seed(hps.train.seed)\n\n    train(hps, args.checkpoint_path)\n"
  },
  {
    "path": "hubert/__init__.py",
    "content": ""
  },
  {
    "path": "hubert/hubert_model.py",
    "content": "import copy\nimport random\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as t_func\n\n\nclass Hubert(nn.Module):\n    def __init__(self, num_label_embeddings: int = 100, mask: bool = True):\n        super().__init__()\n        self._mask = mask\n        self.feature_extractor = FeatureExtractor()\n        self.feature_projection = FeatureProjection()\n        self.positional_embedding = PositionalConvEmbedding()\n        self.norm = nn.LayerNorm(768)\n        self.dropout = nn.Dropout(0.1)\n        self.encoder = TransformerEncoder(\n            nn.TransformerEncoderLayer(\n                768, 12, 3072, activation=\"gelu\", batch_first=True\n            ),\n            12,\n        )\n        self.proj = nn.Linear(768, 256)\n\n        self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())\n        self.label_embedding = nn.Embedding(num_label_embeddings, 256)\n\n    def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        mask = None\n        if self.training and self._mask:\n            mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)\n            x[mask] = self.masked_spec_embed.to(x.dtype)\n        return x, mask\n\n    def encode(\n            self, x: torch.Tensor, layer: Optional[int] = None\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        x = self.feature_extractor(x)\n        x = self.feature_projection(x.transpose(1, 2))\n        x, mask = self.mask(x)\n        x = x + self.positional_embedding(x)\n        x = self.dropout(self.norm(x))\n        x = self.encoder(x, output_layer=layer)\n        return x, mask\n\n    def logits(self, x: torch.Tensor) -> torch.Tensor:\n        logits = torch.cosine_similarity(\n            x.unsqueeze(2),\n            self.label_embedding.weight.unsqueeze(0).unsqueeze(0),\n            dim=-1,\n        )\n        return logits / 0.1\n\n    def forward(self, x: torch.Tensor) -> 
Tuple[torch.Tensor, torch.Tensor]:\n        x, mask = self.encode(x)\n        x = self.proj(x)\n        logits = self.logits(x)\n        return logits, mask\n\n\nclass HubertSoft(Hubert):\n    def __init__(self):\n        super().__init__()\n\n    @torch.inference_mode()\n    def units(self, wav: torch.Tensor) -> torch.Tensor:\n        wav = t_func.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))\n        x, _ = self.encode(wav)\n        return self.proj(x)\n\n\nclass FeatureExtractor(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)\n        self.norm0 = nn.GroupNorm(512, 512)\n        self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)\n        self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)\n        self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)\n        self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)\n        self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)\n        self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = t_func.gelu(self.norm0(self.conv0(x)))\n        x = t_func.gelu(self.conv1(x))\n        x = t_func.gelu(self.conv2(x))\n        x = t_func.gelu(self.conv3(x))\n        x = t_func.gelu(self.conv4(x))\n        x = t_func.gelu(self.conv5(x))\n        x = t_func.gelu(self.conv6(x))\n        return x\n\n\nclass FeatureProjection(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.norm = nn.LayerNorm(512)\n        self.projection = nn.Linear(512, 768)\n        self.dropout = nn.Dropout(0.1)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.norm(x)\n        x = self.projection(x)\n        x = self.dropout(x)\n        return x\n\n\nclass PositionalConvEmbedding(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv = nn.Conv1d(\n            768,\n            768,\n            kernel_size=128,\n            padding=128 // 
2,\n            groups=16,\n        )\n        self.conv = nn.utils.weight_norm(self.conv, name=\"weight\", dim=2)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.conv(x.transpose(1, 2))\n        x = t_func.gelu(x[:, :, :-1])\n        return x.transpose(1, 2)\n\n\nclass TransformerEncoder(nn.Module):\n    def __init__(\n            self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int\n    ) -> None:\n        super(TransformerEncoder, self).__init__()\n        self.layers = nn.ModuleList(\n            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]\n        )\n        self.num_layers = num_layers\n\n    def forward(\n            self,\n            src: torch.Tensor,\n            mask: torch.Tensor = None,\n            src_key_padding_mask: torch.Tensor = None,\n            output_layer: Optional[int] = None,\n    ) -> torch.Tensor:\n        output = src\n        for layer in self.layers[:output_layer]:\n            output = layer(\n                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask\n            )\n        return output\n\n\ndef _compute_mask(\n        shape: Tuple[int, int],\n        mask_prob: float,\n        mask_length: int,\n        device: torch.device,\n        min_masks: int = 0,\n) -> torch.Tensor:\n    batch_size, sequence_length = shape\n\n    if mask_length < 1:\n        raise ValueError(\"`mask_length` has to be bigger than 0.\")\n\n    if mask_length > sequence_length:\n        raise ValueError(\n            f\"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`\"\n        )\n\n    # compute number of masked spans in batch\n    num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())\n    num_masked_spans = max(num_masked_spans, min_masks)\n\n    # make sure num masked indices <= sequence_length\n    if num_masked_spans * mask_length > sequence_length:\n        
num_masked_spans = sequence_length // mask_length\n\n    # SpecAugment mask to fill\n    mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)\n\n    # uniform distribution to sample from, make sure that offset samples are < sequence_length\n    uniform_dist = torch.ones(\n        (batch_size, sequence_length - (mask_length - 1)), device=device\n    )\n\n    # get random indices to mask\n    mask_indices = torch.multinomial(uniform_dist, num_masked_spans)\n\n    # expand masked indices to masked spans\n    mask_indices = (\n        mask_indices.unsqueeze(dim=-1)\n        .expand((batch_size, num_masked_spans, mask_length))\n        .reshape(batch_size, num_masked_spans * mask_length)\n    )\n    offsets = (\n        torch.arange(mask_length, device=device)[None, None, :]\n        .expand((batch_size, num_masked_spans, mask_length))\n        .reshape(batch_size, num_masked_spans * mask_length)\n    )\n    mask_idxs = mask_indices + offsets\n\n    # scatter indices to mask\n    mask = mask.scatter(1, mask_idxs, True)\n\n    return mask\n\n\ndef consume_prefix(state_dict, prefix: str) -> None:\n    keys = sorted(state_dict.keys())\n    for key in keys:\n        if key.startswith(prefix):\n            newkey = key[len(prefix):]\n            state_dict[newkey] = state_dict.pop(key)\n\n\ndef hubert_soft(\n        path: str,\n) -> HubertSoft:\n    r\"\"\"HuBERT-Soft from `\"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion\"`.\n    Args:\n        path (str): path of a pretrained model\n    \"\"\"\n    hubert = HubertSoft()\n    checkpoint = torch.load(path)\n    consume_prefix(checkpoint, \"module.\")\n    hubert.load_state_dict(checkpoint)\n    hubert.eval()\n    return hubert\n"
  },
  {
    "path": "hubert/inference.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nimport numpy as np\nimport argparse\nimport torch\nimport librosa\n\nfrom hubert import hubert_model\n\n\ndef load_audio(file: str, sr: int = 16000):\n    x, sr = librosa.load(file, sr=sr)\n    return x\n\n\ndef load_model(path, device):\n    model = hubert_model.hubert_soft(path)\n    model.eval()\n    if not (device == \"cpu\"):\n        model.half()\n    model.to(device)\n    return model\n\n\ndef pred_vec(model, wavPath, vecPath, device):\n    audio = load_audio(wavPath)\n    audln = audio.shape[0]\n    vec_a = []\n    idx_s = 0\n    while (idx_s + 20 * 16000 < audln):\n        feats = audio[idx_s:idx_s + 20 * 16000]\n        feats = torch.from_numpy(feats).to(device)\n        feats = feats[None, None, :]\n        if not (device == \"cpu\"):\n            feats = feats.half()\n        with torch.no_grad():\n            vec = model.units(feats).squeeze().data.cpu().float().numpy()\n            vec_a.extend(vec)\n        idx_s = idx_s + 20 * 16000\n    if (idx_s < audln):\n        feats = audio[idx_s:audln]\n        feats = torch.from_numpy(feats).to(device)\n        feats = feats[None, None, :]\n        if not (device == \"cpu\"):\n            feats = feats.half()\n        with torch.no_grad():\n            vec = model.units(feats).squeeze().data.cpu().float().numpy()\n            # print(vec.shape)   # [length, dim=256] hop=320\n            vec_a.extend(vec)\n    np.save(vecPath, vec_a, allow_pickle=False)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-w\", \"--wav\", help=\"wav\", dest=\"wav\")\n    parser.add_argument(\"-v\", \"--vec\", help=\"vec\", dest=\"vec\")\n    args = parser.parse_args()\n    print(args.wav)\n    print(args.vec)\n\n    wavPath = args.wav\n    vecPath = args.vec\n\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    hubert = load_model(os.path.join(\n        
\"hubert_pretrain\", \"hubert-soft-0d54a1f4.pt\"), device)\n    pred_vec(hubert, wavPath, vecPath, device)\n"
  },
  {
    "path": "hubert_pretrain/README.md",
    "content": "Path for:\n\n    hubert-soft-0d54a1f4.pt"
  },
  {
    "path": "pitch/__init__.py",
    "content": "from .inference import load_csv_pitch"
  },
  {
    "path": "pitch/inference.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nimport librosa\nimport argparse\nimport numpy as np\nimport parselmouth\n# pip install praat-parselmouth\n\ndef compute_f0_mouth(path):\n    x, sr = librosa.load(path, sr=16000)\n    assert sr == 16000\n    lpad = 1024 // 160\n    rpad = lpad\n    f0 = parselmouth.Sound(x, sr).to_pitch_ac(\n        time_step=160 / sr,\n        voicing_threshold=0.5,\n        pitch_floor=30,\n        pitch_ceiling=1000).selected_array['frequency']\n    f0 = np.pad(f0, [[lpad, rpad]], mode='constant')\n    return f0\n\n\ndef compute_f0_crepe(filename):\n    import torch\n    import torchcrepe\n    \n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    audio, sr = librosa.load(filename, sr=16000)\n    assert sr == 16000\n    audio = torch.tensor(np.copy(audio))[None]\n    audio = audio + torch.randn_like(audio) * 0.001\n    # Here we'll use a 20 millisecond hop length\n    hop_length = 320\n    fmin = 50\n    fmax = 1000\n    model = \"full\"\n    batch_size = 512\n    pitch = torchcrepe.predict(\n        audio,\n        sr,\n        hop_length,\n        fmin,\n        fmax,\n        model,\n        batch_size=batch_size,\n        device=device,\n        return_periodicity=False,\n    )\n    # repeat each 20 ms frame twice to get a 10 ms (160-sample) hop; keep a torch tensor for torchcrepe.filter\n    pitch = torch.repeat_interleave(pitch, 2, dim=-1)  # 320 -> 160 * 2\n    pitch = torchcrepe.filter.mean(pitch, 5)\n    pitch = pitch.squeeze(0)\n    return pitch\n\n\ndef save_csv_pitch(pitch, path):\n    with open(path, \"w\", encoding='utf-8') as pitch_file:\n        for i in range(len(pitch)):\n            t = i * 10\n            minute = t // 60000\n            seconds = (t - minute * 60000) // 1000\n            millisecond = t % 1000\n            print(\n                f\"{minute}m {seconds}s {millisecond:3d},{int(pitch[i])}\", file=pitch_file)\n\n\ndef load_csv_pitch(path):\n    pitch = []\n    with open(path, \"r\", encoding='utf-8') as pitch_file:\n        for line in pitch_file.readlines():\n            pit = line.strip().split(\",\")[-1]\n            pitch.append(int(pit))\n    return pitch\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-w\", \"--wav\", help=\"wav\", dest=\"wav\")\n    parser.add_argument(\"-p\", \"--pit\", help=\"pit\", dest=\"pit\")  # csv for excel\n    args = parser.parse_args()\n    print(args.wav)\n    print(args.pit)\n\n    pitch = compute_f0_mouth(args.wav)\n    save_csv_pitch(pitch, args.pit)\n    #tmp = load_csv_pitch(args.pit)\n    #save_csv_pitch(tmp, \"tmp.csv\")\n"
  },
  {
    "path": "prepare/preprocess_a.py",
    "content": "import os\nimport librosa\nimport argparse\nimport numpy as np\nfrom tqdm import tqdm\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\nfrom scipy.io import wavfile\n\n\ndef resample_wave(wav_in, wav_out, sample_rate):\n    wav, _ = librosa.load(wav_in, sr=sample_rate)\n    # peak-normalize to 60% of int16 full scale; the 0.01 floor avoids dividing by zero on silent input\n    wav = wav / max(0.01, np.max(np.abs(wav))) * 32767 * 0.6\n    wavfile.write(wav_out, sample_rate, wav.astype(np.int16))\n\n\ndef process_file(file, wavPath, spks, outPath, sr):\n    if file.endswith(\".wav\"):\n        file = file[:-4]\n        resample_wave(f\"{wavPath}/{spks}/{file}.wav\", f\"{outPath}/{spks}/{file}.wav\", sr)\n\n\ndef process_files_with_thread_pool(wavPath, spks, outPath, sr, thread_num=None):\n    files = [f for f in os.listdir(f\"./{wavPath}/{spks}\") if f.endswith(\".wav\")]\n\n    with ThreadPoolExecutor(max_workers=thread_num) as executor:\n        futures = {executor.submit(process_file, file, wavPath, spks, outPath, sr): file for file in files}\n\n        for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing {sr} {spks}'):\n            future.result()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-w\", \"--wav\", help=\"wav\", dest=\"wav\", required=True)\n    parser.add_argument(\"-o\", \"--out\", help=\"out\", dest=\"out\", required=True)\n    parser.add_argument(\"-s\", \"--sr\", help=\"sample rate\", dest=\"sr\", type=int, required=True)\n    parser.add_argument(\"-t\", \"--thread_count\", help=\"thread count to process, set 0 to use all cpu cores\", dest=\"thread_count\", type=int, default=1)\n\n    args = parser.parse_args()\n    print(args.wav)\n    print(args.out)\n    print(args.sr)\n\n    os.makedirs(args.out, exist_ok=True)\n    wavPath = args.wav\n    outPath = args.out\n\n    assert args.sr == 16000 or args.sr == 32000\n\n    for spks in os.listdir(wavPath):\n        if os.path.isdir(f\"./{wavPath}/{spks}\"):\n            os.makedirs(f\"./{outPath}/{spks}\", exist_ok=True)\n            if args.thread_count == 0:\n                process_num = os.cpu_count() // 2 + 1\n            else:\n                process_num = args.thread_count\n            process_files_with_thread_pool(wavPath, spks, outPath, args.sr, process_num)\n"
  },
  {
    "path": "prepare/preprocess_f0.py",
    "content": "import os\nimport numpy as np\nimport librosa\nimport argparse\nimport parselmouth\n# pip install praat-parselmouth\nfrom tqdm import tqdm\nfrom concurrent.futures import ProcessPoolExecutor, as_completed\n\n\ndef compute_f0(path, save):\n    x, sr = librosa.load(path, sr=16000)\n    assert sr == 16000\n    lpad = 1024 // 160\n    rpad = lpad\n    f0 = parselmouth.Sound(x, sr).to_pitch_ac(\n        time_step=160 / sr,\n        voicing_threshold=0.5,\n        pitch_floor=30,\n        pitch_ceiling=1000).selected_array['frequency']\n    f0 = np.pad(f0, [[lpad, rpad]], mode='constant')\n    for index, pitch in enumerate(f0):\n        f0[index] = round(pitch, 1)\n    np.save(save, f0, allow_pickle=False)\n\n\ndef process_file(file, wavPath, spks, pitPath):\n    if file.endswith(\".wav\"):\n        file = file[:-4]\n        compute_f0(f\"{wavPath}/{spks}/{file}.wav\", f\"{pitPath}/{spks}/{file}.pit\")\n\n\ndef process_files_with_process_pool(wavPath, spks, pitPath, process_num=None):\n    files = [f for f in os.listdir(f\"./{wavPath}/{spks}\") if f.endswith(\".wav\")]\n\n    with ProcessPoolExecutor(max_workers=process_num) as executor:\n        futures = {executor.submit(process_file, file, wavPath, spks, pitPath): file for file in files}\n\n        for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing f0 {spks}'):\n            future.result()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-w\", \"--wav\", help=\"wav\", dest=\"wav\", required=True)\n    parser.add_argument(\"-p\", \"--pit\", help=\"pit\", dest=\"pit\", required=True)\n    parser.add_argument(\"-t\", \"--thread_count\", help=\"thread count to process, set 0 to use all cpu cores\", dest=\"thread_count\", type=int, default=1)\n    \n    args = parser.parse_args()\n    print(args.wav)\n    print(args.pit)\n\n    os.makedirs(args.pit, exist_ok=True)\n    wavPath = args.wav\n    pitPath = args.pit\n\n    for spks 
in os.listdir(wavPath):\n        if os.path.isdir(f\"./{wavPath}/{spks}\"):\n            os.makedirs(f\"./{pitPath}/{spks}\", exist_ok=True)\n            if args.thread_count == 0:\n                process_num = os.cpu_count() // 2 + 1\n            else:\n                process_num = args.thread_count\n            process_files_with_process_pool(wavPath, spks, pitPath, process_num)\n"
  },
  {
    "path": "prepare/preprocess_hubert.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nimport numpy as np\nimport argparse\nimport torch\nimport librosa\n\nfrom tqdm import tqdm\nfrom hubert import hubert_model\n\n\ndef load_audio(file: str, sr: int = 16000):\n    x, sr = librosa.load(file, sr=sr)\n    return x\n\n\ndef load_model(path, device):\n    model = hubert_model.hubert_soft(path)\n    model.eval()\n    # fp16 only on GPU: half-precision convolutions may be unsupported on CPU\n    if not (device == \"cpu\"):\n        model.half()\n    model.to(device)\n    return model\n\n\ndef pred_vec(model, wavPath, vecPath, device):\n    feats = load_audio(wavPath)\n    feats = torch.from_numpy(feats).to(device)\n    feats = feats[None, None, :]\n    if not (device == \"cpu\"):\n        feats = feats.half()\n    with torch.no_grad():\n        vec = model.units(feats).squeeze().data.cpu().float().numpy()\n        # print(vec.shape)   # [length, dim=256] hop=320\n        np.save(vecPath, vec, allow_pickle=False)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-w\", \"--wav\", help=\"wav\", dest=\"wav\", required=True)\n    parser.add_argument(\"-v\", \"--vec\", help=\"vec\", dest=\"vec\", required=True)\n    \n    args = parser.parse_args()\n    print(args.wav)\n    print(args.vec)\n    os.makedirs(args.vec, exist_ok=True)\n\n    wavPath = args.wav\n    vecPath = args.vec\n\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    hubert = load_model(os.path.join(\"hubert_pretrain\", \"hubert-soft-0d54a1f4.pt\"), device)\n\n    for spks in os.listdir(wavPath):\n        if os.path.isdir(f\"./{wavPath}/{spks}\"):\n            os.makedirs(f\"./{vecPath}/{spks}\", exist_ok=True)\n\n            files = [f for f in os.listdir(f\"./{wavPath}/{spks}\") if f.endswith(\".wav\")]\n            for file in tqdm(files, desc=f'Processing vec {spks}'):\n                file = file[:-4]\n                pred_vec(hubert, f\"{wavPath}/{spks}/{file}.wav\", f\"{vecPath}/{spks}/{file}.vec\", device)\n"
  },
  {
    "path": "prepare/preprocess_speaker.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nimport torch\nimport numpy as np\nimport argparse\n\nfrom tqdm import tqdm\nfrom functools import partial\nfrom argparse import RawTextHelpFormatter\nfrom multiprocessing.pool import ThreadPool\n\nfrom speaker.models.lstm import LSTMSpeakerEncoder\nfrom speaker.config import SpeakerEncoderConfig\nfrom speaker.utils.audio import AudioProcessor\nfrom speaker.infer import read_json\n\n\ndef get_spk_wavs(dataset_path, output_path):\n    wav_files = []\n    os.makedirs(f\"./{output_path}\", exist_ok=True)\n    for spks in os.listdir(dataset_path):\n        if os.path.isdir(f\"./{dataset_path}/{spks}\"):\n            os.makedirs(f\"./{output_path}/{spks}\", exist_ok=True)\n            for file in os.listdir(f\"./{dataset_path}/{spks}\"):\n                if file.endswith(\".wav\"):\n                    wav_files.append(f\"./{dataset_path}/{spks}/{file}\")\n        elif spks.endswith(\".wav\"):\n            wav_files.append(f\"./{dataset_path}/{spks}\")\n    return wav_files\n\n\ndef process_wav(wav_file, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder):\n    waveform = speaker_encoder_ap.load_wav(\n        wav_file, sr=speaker_encoder_ap.sample_rate\n    )\n    spec = speaker_encoder_ap.melspectrogram(waveform)\n    spec = torch.from_numpy(spec.T)\n    if args.use_cuda:\n        spec = spec.cuda()\n    spec = spec.unsqueeze(0)\n    embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()\n    embed = embed.squeeze()\n    embed_path = wav_file.replace(dataset_path, output_path)\n    embed_path = embed_path.replace(\".wav\", \".spk\")\n    np.save(embed_path, embed, allow_pickle=False)\n\n\ndef extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, concurrency):\n    bound_process_wav = partial(process_wav, dataset_path=dataset_path, output_path=output_path, args=args, 
speaker_encoder_ap=speaker_encoder_ap, speaker_encoder=speaker_encoder)\n\n    with ThreadPool(concurrency) as pool:\n        list(tqdm(pool.imap(bound_process_wav, wav_files), total=len(wav_files)))\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\n        description=\"\"\"Compute embedding vectors for each wav file in a dataset.\"\"\",\n        formatter_class=RawTextHelpFormatter,\n    )\n    parser.add_argument(\"dataset_path\", type=str, help=\"Path to dataset waves.\")\n    parser.add_argument(\n        \"output_path\", type=str, help=\"Path for the output speaker embeddings (one .spk.npy per wav).\"\n    )\n    # type=bool would turn any non-empty string (including \"False\") into True, so parse the text explicitly\n    parser.add_argument(\"--use_cuda\", type=lambda s: s.lower() in (\"true\", \"1\", \"yes\"), help=\"use cuda (true/false); defaults to true when cuda is available.\", default=torch.cuda.is_available())\n    parser.add_argument(\"-t\", \"--thread_count\", help=\"thread count to process, set 0 to use all cpu cores\", dest=\"thread_count\", type=int, default=1)\n    args = parser.parse_args()\n    dataset_path = args.dataset_path\n    output_path = args.output_path\n    thread_count = args.thread_count\n    # model\n    args.model_path = os.path.join(\"speaker_pretrain\", \"best_model.pth.tar\")\n    args.config_path = os.path.join(\"speaker_pretrain\", \"config.json\")\n    # config\n    config_dict = read_json(args.config_path)\n\n    # model\n    config = SpeakerEncoderConfig(config_dict)\n    config.from_dict(config_dict)\n\n    speaker_encoder = LSTMSpeakerEncoder(\n        config.model_params[\"input_dim\"],\n        config.model_params[\"proj_dim\"],\n        config.model_params[\"lstm_dim\"],\n        config.model_params[\"num_lstm_layers\"],\n    )\n\n    speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)\n\n    # preprocess\n    speaker_encoder_ap = AudioProcessor(**config.audio)\n    # normalize the input audio level and trim silences\n    speaker_encoder_ap.do_sound_norm = True\n    speaker_encoder_ap.do_trim_silence = True\n\n    wav_files = get_spk_wavs(dataset_path, output_path)\n\n    if thread_count == 0:\n        process_num = os.cpu_count()\n    else:\n        process_num = thread_count\n\n    extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, process_num)"
  },
  {
    "path": "prepare/preprocess_speaker_ave.py",
    "content": "import os\nimport torch\nimport argparse\nimport numpy as np\nfrom tqdm import tqdm\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"dataset_speaker\", type=str)\n    parser.add_argument(\"dataset_singer\", type=str)\n\n    data_speaker = parser.parse_args().dataset_speaker\n    data_singer = parser.parse_args().dataset_singer\n\n    os.makedirs(data_singer, exist_ok=True)\n\n    for speaker in os.listdir(data_speaker):\n        subfile_num = 0\n        speaker_ave = 0\n\n        for file in tqdm(os.listdir(os.path.join(data_speaker, speaker)), desc=f\"average {speaker}\"):\n            if not file.endswith(\".npy\"):\n                continue\n            source_embed = np.load(os.path.join(data_speaker, speaker, file))\n            source_embed = source_embed.astype(np.float32)\n            speaker_ave = speaker_ave + source_embed\n            subfile_num = subfile_num + 1\n        if subfile_num == 0:\n            continue\n        speaker_ave = speaker_ave / subfile_num\n\n        np.save(os.path.join(data_singer, f\"{speaker}.spk.npy\"),\n                speaker_ave, allow_pickle=False)\n\n        # rewrite timbre code by average, if similarity is larger than cmp_val\n        rewrite_timbre_code = True\n        if not rewrite_timbre_code:\n            continue\n        cmp_src = torch.FloatTensor(speaker_ave)\n        cmp_num = 0\n        cmp_val = 0.85\n        for file in tqdm(os.listdir(os.path.join(data_speaker, speaker)), desc=f\"rewrite {speaker}\"):\n            if not file.endswith(\".npy\"):\n                continue\n            cmp_tmp = np.load(os.path.join(data_speaker, speaker, file))\n            cmp_tmp = cmp_tmp.astype(np.float32)\n            cmp_tmp = torch.FloatTensor(cmp_tmp)\n            cmp_cos = torch.cosine_similarity(cmp_src, cmp_tmp, dim=0)\n            if (cmp_cos > cmp_val):\n                cmp_num += 1\n                np.save(os.path.join(data_speaker, speaker, 
file),\n                        speaker_ave, allow_pickle=False)\n        print(f\"rewrite timbre for {speaker} with :\", cmp_num)\n"
  },
  {
    "path": "prepare/preprocess_spec.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nimport torch\nimport argparse\nfrom concurrent.futures import ThreadPoolExecutor\nfrom spec.inference import mel_spectrogram_file\nfrom tqdm import tqdm\nfrom omegaconf import OmegaConf\n\n\ndef compute_spec(hps, filename, specname):\n    spec = mel_spectrogram_file(filename, hps)\n    spec = torch.squeeze(spec, 0)\n    # print(spec.shape)\n    torch.save(spec, specname)\n\n\ndef process_file(file):\n    if file.endswith(\".wav\"):\n        file = file[:-4]\n        compute_spec(hps, f\"{wavPath}/{spks}/{file}.wav\", f\"{spePath}/{spks}/{file}.mel.pt\")\n\n\ndef process_files_with_thread_pool(wavPath, spks, thread_num):\n    files = os.listdir(f\"./{wavPath}/{spks}\")\n    with ThreadPoolExecutor(max_workers=thread_num) as executor:\n        list(tqdm(executor.map(process_file, files), total=len(files), desc=f'Processing spec {spks}'))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-w\", \"--wav\", help=\"wav\", dest=\"wav\", required=True)\n    parser.add_argument(\"-s\", \"--spe\", help=\"spe\", dest=\"spe\", required=True)\n    parser.add_argument(\"-t\", \"--thread_count\", help=\"thread count to process, set 0 to use all cpu cores\", dest=\"thread_count\", type=int, default=1)\n\n    args = parser.parse_args()\n    print(args.wav)\n    print(args.spe)\n\n    os.makedirs(args.spe, exist_ok=True)\n    wavPath = args.wav\n    spePath = args.spe\n    hps = OmegaConf.load(\"./configs/base.yaml\")\n\n    for spks in os.listdir(wavPath):\n        if os.path.isdir(f\"./{wavPath}/{spks}\"):\n            os.makedirs(f\"./{spePath}/{spks}\", exist_ok=True)\n            if args.thread_count == 0:\n                process_num = os.cpu_count() // 2 + 1\n            else:\n                process_num = args.thread_count\n            process_files_with_thread_pool(wavPath, spks, process_num)\n"
  },
  {
    "path": "prepare/preprocess_train.py",
    "content": "import os\nimport random\n\n\ndef print_error(info):\n    print(f\"\\033[31m File does not exist: {info}\\033[0m\")\n\n\nif __name__ == \"__main__\":\n    os.makedirs(\"./files/\", exist_ok=True)\n\n    rootPath = \"./data_gvc/waves-32k/\"\n    all_items = []\n    for spks in os.listdir(f\"./{rootPath}\"):\n        if not os.path.isdir(f\"./{rootPath}/{spks}\"):\n            continue\n        print(f\"./{rootPath}/{spks}\")\n        for file in os.listdir(f\"./{rootPath}/{spks}\"):\n            if file.endswith(\".wav\"):\n                file = file[:-4]\n\n                path_mel = f\"./data_gvc/mel/{spks}/{file}.mel.pt\"\n                path_vec = f\"./data_gvc/hubert/{spks}/{file}.vec.npy\"\n                path_pit = f\"./data_gvc/pitch/{spks}/{file}.pit.npy\"\n                path_spk = f\"./data_gvc/speaker/{spks}/{file}.spk.npy\"\n\n                has_error = 0\n                if not os.path.isfile(path_mel):\n                    print_error(path_mel)\n                    has_error = 1\n                if not os.path.isfile(path_vec):\n                    print_error(path_vec)\n                    has_error = 1\n                if not os.path.isfile(path_pit):\n                    print_error(path_pit)\n                    has_error = 1\n                if not os.path.isfile(path_spk):\n                    print_error(path_spk)\n                    has_error = 1\n                if has_error == 0:\n                    all_items.append(\n                        f\"{path_mel}|{path_vec}|{path_pit}|{path_spk}\")\n\n    random.shuffle(all_items)\n    valids = all_items[:10]\n    valids.sort()\n    trains = all_items[10:]\n    # trains.sort()\n    with open(\"./files/valid.txt\", \"w\", encoding=\"utf-8\") as fw:\n        for strs in valids:\n            print(strs, file=fw)\n    with open(\"./files/train.txt\", \"w\", encoding=\"utf-8\") as fw:\n        for strs in trains:\n            print(strs, file=fw)\n"
  },
  {
    "path": "prepare/preprocess_zzz.py",
    "content": "import sys,os\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader\nfrom grad_extend.data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate\n\n\nif __name__ == \"__main__\":\n    filelist_path = \"files/valid.txt\"\n    \n    dataset = TextMelSpeakerDataset(filelist_path)\n    collate = TextMelSpeakerBatchCollate()\n    loader = DataLoader(dataset=dataset, \n                        batch_size=2,\n                        collate_fn=collate, \n                        drop_last=True,\n                        num_workers=1, \n                        shuffle=True)\n    \n    for batch in tqdm(loader):\n        lengths = batch['lengths'].cuda()\n        vec = batch['vec'].cuda()\n        pit = batch['pit'].cuda()\n        spk = batch['spk'].cuda()\n        mel = batch['mel'].cuda()\n    \n        print('len', lengths.shape)\n        print('vec', vec.shape)\n        print('pit', pit.shape)\n        print('spk', spk.shape)\n        print('mel', mel.shape)\n"
  },
  {
    "path": "requirements.txt",
    "content": "librosa\nsoundfile\nmatplotlib\ntensorboard\ntransformers\ntqdm\neinops\nfsspec\nomegaconf\npyworld\npraat-parselmouth\n"
  },
  {
    "path": "speaker/__init__.py",
    "content": ""
  },
  {
    "path": "speaker/config.py",
    "content": "from dataclasses import asdict, dataclass, field\nfrom typing import Dict, List\n\nfrom .utils.coqpit import MISSING\nfrom .utils.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig\n\n\n@dataclass\nclass SpeakerEncoderConfig(BaseTrainingConfig):\n    \"\"\"Defines parameters for Speaker Encoder model.\"\"\"\n\n    model: str = \"speaker_encoder\"\n    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)\n    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])\n    # model params\n    model_params: Dict = field(\n        default_factory=lambda: {\n            \"model_name\": \"lstm\",\n            \"input_dim\": 80,\n            \"proj_dim\": 256,\n            \"lstm_dim\": 768,\n            \"num_lstm_layers\": 3,\n            \"use_lstm_with_projection\": True,\n        }\n    )\n\n    audio_augmentation: Dict = field(default_factory=lambda: {})\n\n    storage: Dict = field(\n        default_factory=lambda: {\n            \"sample_from_storage_p\": 0.66,  # the probability with which we'll sample from the DataSet in-memory storage\n            \"storage_size\": 15,  # the size of the in-memory storage with respect to a single batch\n        }\n    )\n\n    # training params\n    max_train_step: int = 1000000  # end training when number of training steps reaches this value.\n    loss: str = \"angleproto\"\n    grad_clip: float = 3.0\n    lr: float = 0.0001\n    lr_decay: bool = False\n    warmup_steps: int = 4000\n    wd: float = 1e-6\n\n    # logging params\n    tb_model_param_stats: bool = False\n    steps_plot_stats: int = 10\n    checkpoint: bool = True\n    save_step: int = 1000\n    print_step: int = 20\n\n    # data loader\n    num_speakers_in_batch: int = MISSING\n    num_utters_per_speaker: int = MISSING\n    num_loader_workers: int = MISSING\n    skip_speakers: bool = False\n    voice_len: float = 1.6\n\n    def check_values(self):\n        super().check_values()\n   
     c = asdict(self)\n        assert (\n            c[\"model_params\"][\"input_dim\"] == self.audio.num_mels\n        ), \" [!] model input dimension must be equal to melspectrogram dimension.\"\n"
  },
  {
    "path": "speaker/infer.py",
    "content": "import re\nimport json\nimport fsspec\nimport torch\nimport numpy as np\nimport argparse\n\nfrom argparse import RawTextHelpFormatter\nfrom .models.lstm import LSTMSpeakerEncoder\nfrom .config import SpeakerEncoderConfig\nfrom .utils.audio import AudioProcessor\n\n\ndef read_json(json_path):\n    config_dict = {}\n    try:\n        with fsspec.open(json_path, \"r\", encoding=\"utf-8\") as f:\n            data = json.load(f)\n    except json.decoder.JSONDecodeError:\n        # backwards compat.\n        data = read_json_with_comments(json_path)\n    config_dict.update(data)\n    return config_dict\n\n\ndef read_json_with_comments(json_path):\n    \"\"\"for backward compat.\"\"\"\n    # fallback to json\n    with fsspec.open(json_path, \"r\", encoding=\"utf-8\") as f:\n        input_str = f.read()\n    # handle comments\n    input_str = re.sub(r\"\\\\\\n\", \"\", input_str)\n    input_str = re.sub(r\"//.*\\n\", \"\\n\", input_str)\n    data = json.loads(input_str)\n    return data\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\n        description=\"\"\"Compute the speaker embedding vector for a single wav file.\"\"\",\n        formatter_class=RawTextHelpFormatter,\n    )\n    parser.add_argument(\"model_path\", type=str, help=\"Path to model checkpoint file.\")\n    parser.add_argument(\n        \"config_path\",\n        type=str,\n        help=\"Path to model config file.\",\n    )\n\n    parser.add_argument(\"-s\", \"--source\", help=\"input wave\", dest=\"source\")\n    parser.add_argument(\n        \"-t\", \"--target\", help=\"output 256d speaker embedding\", dest=\"target\"\n    )\n\n    # type=bool would turn any non-empty string (including \"False\") into True, so parse the text explicitly\n    parser.add_argument(\"--use_cuda\", type=lambda s: s.lower() in (\"true\", \"1\", \"yes\"), help=\"use cuda (true/false); defaults to true when cuda is available.\", default=torch.cuda.is_available())\n    parser.add_argument(\"--eval\", type=bool, help=\"compute eval.\", default=True)\n\n    args = parser.parse_args()\n    source_file = args.source\n    target_file = args.target\n\n    # config\n    config_dict = read_json(args.config_path)\n    # print(config_dict)\n\n    # model\n    config = SpeakerEncoderConfig(config_dict)\n    config.from_dict(config_dict)\n\n    speaker_encoder = LSTMSpeakerEncoder(\n        config.model_params[\"input_dim\"],\n        config.model_params[\"proj_dim\"],\n        config.model_params[\"lstm_dim\"],\n        config.model_params[\"num_lstm_layers\"],\n    )\n\n    speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)\n\n    # preprocess\n    speaker_encoder_ap = AudioProcessor(**config.audio)\n    # normalize the input audio level and trim silences\n    speaker_encoder_ap.do_sound_norm = True\n    speaker_encoder_ap.do_trim_silence = True\n\n    # compute speaker embeddings\n\n    # extract the embedding\n    waveform = speaker_encoder_ap.load_wav(\n        source_file, sr=speaker_encoder_ap.sample_rate\n    )\n    spec = speaker_encoder_ap.melspectrogram(waveform)\n    spec = torch.from_numpy(spec.T)\n    if args.use_cuda:\n        spec = spec.cuda()\n    spec = spec.unsqueeze(0)\n    embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()\n    embed = embed.squeeze()\n    # print(embed)\n    # print(embed.size)\n    np.save(target_file, embed, allow_pickle=False)\n\n    # export a compact checkpoint holding only the model weights (unwrap DataParallel if present)\n    if hasattr(speaker_encoder, 'module'):\n        state_dict = speaker_encoder.module.state_dict()\n    else:\n        state_dict = speaker_encoder.state_dict()\n    torch.save({'model': state_dict}, \"model_small.pth\")\n"
  },
  {
    "path": "speaker/models/__init__.py",
    "content": ""
  },
  {
    "path": "speaker/models/lstm.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\n\nfrom ..utils.io import load_fsspec\n\n\nclass LSTMWithProjection(nn.Module):\n    def __init__(self, input_size, hidden_size, proj_size):\n        super().__init__()\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.proj_size = proj_size\n        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)\n        self.linear = nn.Linear(hidden_size, proj_size, bias=False)\n\n    def forward(self, x):\n        self.lstm.flatten_parameters()\n        o, (_, _) = self.lstm(x)\n        return self.linear(o)\n\n\nclass LSTMWithoutProjection(nn.Module):\n    def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):\n        super().__init__()\n        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)\n        self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        _, (hidden, _) = self.lstm(x)\n        return self.relu(self.linear(hidden[-1]))\n\n\nclass LSTMSpeakerEncoder(nn.Module):\n    def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):\n        super().__init__()\n        self.use_lstm_with_projection = use_lstm_with_projection\n        layers = []\n        # choose the LSTM layer type\n        if use_lstm_with_projection:\n            layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))\n            for _ in range(num_lstm_layers - 1):\n                layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))\n            self.layers = nn.Sequential(*layers)\n        else:\n            self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)\n\n        self._init_layers()\n\n    def _init_layers(self):\n        for name, param in self.layers.named_parameters():\n            if \"bias\" in name:\n                nn.init.constant_(param, 0.0)\n            elif \"weight\" in name:\n                nn.init.xavier_normal_(param)\n\n    def forward(self, x):\n        # TODO: implement state passing for lstms\n        d = self.layers(x)\n        if self.use_lstm_with_projection:\n            d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)\n        else:\n            d = torch.nn.functional.normalize(d, p=2, dim=1)\n        return d\n\n    @torch.no_grad()\n    def inference(self, x):\n        d = self.layers.forward(x)\n        if self.use_lstm_with_projection:\n            d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)\n        else:\n            d = torch.nn.functional.normalize(d, p=2, dim=1)\n        return d\n\n    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):\n        \"\"\"\n        Generate embeddings for a single utterance from num_eval evenly spaced windows\n        x: 1xTxD\n        \"\"\"\n        max_len = x.shape[1]\n\n        if max_len < num_frames:\n            num_frames = max_len\n\n        offsets = np.linspace(0, max_len - num_frames, num=num_eval)\n\n        frames_batch = []\n        for offset in offsets:\n            offset = int(offset)\n            end_offset = int(offset + num_frames)\n            frames = x[:, offset:end_offset]\n            frames_batch.append(frames)\n\n        frames_batch = torch.cat(frames_batch, dim=0)\n        embeddings = self.inference(frames_batch)\n\n        if return_mean:\n            embeddings = torch.mean(embeddings, dim=0, keepdim=True)\n\n        return embeddings\n\n    def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):\n        \"\"\"\n        Generate embeddings for a batch of utterances\n        x: BxTxD\n        \"\"\"\n        num_overlap = int(num_frames * overlap)  # range() below needs an integer step\n        max_len = x.shape[1]\n        embed = None\n        num_iters = seq_lens / (num_frames - num_overlap)\n        cur_iter = 0\n        for offset in range(0, max_len, num_frames - num_overlap):\n            cur_iter += 1\n            end_offset = min(x.shape[1], offset + num_frames)\n            frames = x[:, offset:end_offset]\n            if embed is None:\n                embed = self.inference(frames)\n            else:\n                embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])\n        return embed / num_iters\n\n    # pylint: disable=unused-argument, redefined-builtin\n    def load_checkpoint(self, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):\n        state = load_fsspec(checkpoint_path, map_location=torch.device(\"cpu\"))\n        self.load_state_dict(state[\"model\"])\n        if use_cuda:\n            self.cuda()\n        if eval:\n            self.eval()\n            assert not self.training\n"
  },
  {
    "path": "speaker/models/resnet.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\n\nfrom ..utils.io import load_fsspec\n\n\nclass SELayer(nn.Module):\n    def __init__(self, channel, reduction=8):\n        super(SELayer, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel // reduction, channel),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        return x * y\n\n\nclass SEBasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):\n        super(SEBasicBlock, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.se = SELayer(planes, reduction)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.relu(out)\n        out = self.bn1(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.se(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n        return out\n\n\nclass ResNetSpeakerEncoder(nn.Module):\n    \"\"\"Implementation of the model H/ASP without batch normalization in speaker embedding. 
This model was proposed in: https://arxiv.org/abs/2009.14153\n    Adapted from: https://github.com/clovaai/voxceleb_trainer\n    \"\"\"\n\n    # pylint: disable=W0102\n    def __init__(\n        self,\n        input_dim=64,\n        proj_dim=512,\n        layers=[3, 4, 6, 3],\n        num_filters=[32, 64, 128, 256],\n        encoder_type=\"ASP\",\n        log_input=False,\n    ):\n        super(ResNetSpeakerEncoder, self).__init__()\n\n        self.encoder_type = encoder_type\n        self.input_dim = input_dim\n        self.log_input = log_input\n        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n        self.bn1 = nn.BatchNorm2d(num_filters[0])\n\n        self.inplanes = num_filters[0]\n        self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])\n        self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))\n        self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))\n        self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))\n\n        self.instancenorm = nn.InstanceNorm1d(input_dim)\n\n        outmap_size = int(self.input_dim / 8)\n\n        self.attention = nn.Sequential(\n            nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),\n            nn.ReLU(),\n            nn.BatchNorm1d(128),\n            nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),\n            nn.Softmax(dim=2),\n        )\n\n        if self.encoder_type == \"SAP\":\n            out_dim = num_filters[3] * outmap_size\n        elif self.encoder_type == \"ASP\":\n            out_dim = num_filters[3] * outmap_size * 2\n        else:\n            raise ValueError(\"Undefined encoder\")\n\n        self.fc = nn.Linear(out_dim, proj_dim)\n\n        self._init_layers()\n\n    def _init_layers(self):\n        for m in self.modules():\n            if isinstance(m, 
nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    def create_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    # pylint: disable=R0201\n    def new_parameter(self, *size):\n        out = nn.Parameter(torch.FloatTensor(*size))\n        nn.init.xavier_normal_(out)\n        return out\n\n    def forward(self, x, l2_norm=False):\n        x = x.transpose(1, 2)\n        with torch.no_grad():\n            with torch.cuda.amp.autocast(enabled=False):\n                if self.log_input:\n                    x = (x + 1e-6).log()\n                x = self.instancenorm(x).unsqueeze(1)\n\n        x = self.conv1(x)\n        x = self.relu(x)\n        x = self.bn1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = x.reshape(x.size()[0], -1, x.size()[-1])\n\n        w = self.attention(x)\n\n        if self.encoder_type == \"SAP\":\n            x = torch.sum(x * w, dim=2)\n        elif self.encoder_type == \"ASP\":\n            mu = torch.sum(x * w, dim=2)\n            sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))\n            x = torch.cat((mu, sg), 1)\n\n        x = 
x.view(x.size()[0], -1)\n        x = self.fc(x)\n\n        if l2_norm:\n            x = torch.nn.functional.normalize(x, p=2, dim=1)\n        return x\n\n    @torch.no_grad()\n    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):\n        \"\"\"\n        Generate embeddings for a batch of utterances\n        x: 1xTxD\n        \"\"\"\n        max_len = x.shape[1]\n\n        if max_len < num_frames:\n            num_frames = max_len\n\n        offsets = np.linspace(0, max_len - num_frames, num=num_eval)\n\n        frames_batch = []\n        for offset in offsets:\n            offset = int(offset)\n            end_offset = int(offset + num_frames)\n            frames = x[:, offset:end_offset]\n            frames_batch.append(frames)\n\n        frames_batch = torch.cat(frames_batch, dim=0)\n        embeddings = self.forward(frames_batch, l2_norm=True)\n\n        if return_mean:\n            embeddings = torch.mean(embeddings, dim=0, keepdim=True)\n\n        return embeddings\n\n    def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):\n        state = load_fsspec(checkpoint_path, map_location=torch.device(\"cpu\"))\n        self.load_state_dict(state[\"model\"])\n        if use_cuda:\n            self.cuda()\n        if eval:\n            self.eval()\n            assert not self.training\n"
  },
  {
    "path": "speaker/utils/__init__.py",
    "content": ""
  },
  {
    "path": "speaker/utils/audio.py",
    "content": "from typing import Dict, Tuple\n\nimport librosa\nimport numpy as np\nimport pyworld as pw\nimport scipy.io.wavfile\nimport scipy.signal\nimport soundfile as sf\nimport torch\nfrom torch import nn\n\nclass StandardScaler:\n    \"\"\"StandardScaler for mean-scale normalization with the given mean and scale values.\"\"\"\n\n    def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None:\n        self.mean_ = mean\n        self.scale_ = scale\n\n    def set_stats(self, mean, scale):\n        self.mean_ = mean\n        self.scale_ = scale\n\n    def reset_stats(self):\n        delattr(self, \"mean_\")\n        delattr(self, \"scale_\")\n\n    def transform(self, X):\n        X = np.asarray(X)\n        X -= self.mean_\n        X /= self.scale_\n        return X\n\n    def inverse_transform(self, X):\n        X = np.asarray(X)\n        X *= self.scale_\n        X += self.mean_\n        return X\n\nclass TorchSTFT(nn.Module):  # pylint: disable=abstract-method\n    \"\"\"Some of the audio processing functions using Torch for faster batch processing.\n\n    TODO: Merge this with audio.py\n    \"\"\"\n\n    def __init__(\n        self,\n        n_fft,\n        hop_length,\n        win_length,\n        pad_wav=False,\n        window=\"hann_window\",\n        sample_rate=None,\n        mel_fmin=0,\n        mel_fmax=None,\n        n_mels=80,\n        use_mel=False,\n        do_amp_to_db=False,\n        spec_gain=1.0,\n    ):\n        super().__init__()\n        self.n_fft = n_fft\n        self.hop_length = hop_length\n        self.win_length = win_length\n        self.pad_wav = pad_wav\n        self.sample_rate = sample_rate\n        self.mel_fmin = mel_fmin\n        self.mel_fmax = mel_fmax\n        self.n_mels = n_mels\n        self.use_mel = use_mel\n        self.do_amp_to_db = do_amp_to_db\n        self.spec_gain = spec_gain\n        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)\n        self.mel_basis 
= None\n        if use_mel:\n            self._build_mel_basis()\n\n    def __call__(self, x):\n        \"\"\"Compute spectrogram frames by torch based stft.\n\n        Args:\n            x (Tensor): input waveform\n\n        Returns:\n            Tensor: spectrogram frames.\n\n        Shapes:\n            x: [B x T] or [:math:`[B, 1, T]`]\n        \"\"\"\n        if x.ndim == 2:\n            x = x.unsqueeze(1)\n        if self.pad_wav:\n            padding = int((self.n_fft - self.hop_length) / 2)\n            x = torch.nn.functional.pad(x, (padding, padding), mode=\"reflect\")\n        # B x D x T x 2\n        o = torch.stft(\n            x.squeeze(1),\n            self.n_fft,\n            self.hop_length,\n            self.win_length,\n            self.window,\n            center=True,\n            pad_mode=\"reflect\",  # compatible with audio.py\n            normalized=False,\n            onesided=True,\n            return_complex=False,\n        )\n        M = o[:, :, :, 0]\n        P = o[:, :, :, 1]\n        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))\n        if self.use_mel:\n            S = torch.matmul(self.mel_basis.to(x), S)\n        if self.do_amp_to_db:\n            S = self._amp_to_db(S, spec_gain=self.spec_gain)\n        return S\n\n    def _build_mel_basis(self):\n        mel_basis = librosa.filters.mel(\n            sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax\n        )\n        self.mel_basis = torch.from_numpy(mel_basis).float()\n\n    @staticmethod\n    def _amp_to_db(x, spec_gain=1.0):\n        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)\n\n    @staticmethod\n    def _db_to_amp(x, spec_gain=1.0):\n        return torch.exp(x) / spec_gain\n\n\n# pylint: disable=too-many-public-methods\nclass AudioProcessor(object):\n    \"\"\"Audio Processor for TTS used by all the data pipelines.\n\n    Note:\n        All the class arguments are set to default values to enable a 
flexible initialization\n        of the class with the model config. The defaults are not meaningful for every argument.\n\n    Args:\n        sample_rate (int, optional):\n            target audio sampling rate. Defaults to None.\n\n        resample (bool, optional):\n            enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.\n\n        num_mels (int, optional):\n            number of melspectrogram dimensions. Defaults to None.\n\n        log_func (str, optional):\n            log function used for converting spectrogram amplitude to dB, either \"np.log\" or \"np.log10\". Defaults to \"np.log10\".\n\n        min_level_db (int, optional):\n            minimum dB threshold for the computed melspectrograms. Defaults to None.\n\n        frame_shift_ms (int, optional):\n            time shift in milliseconds between STFT columns. Defaults to None.\n\n        frame_length_ms (int, optional):\n            STFT window length in milliseconds. Defaults to None.\n\n        hop_length (int, optional):\n            number of samples between STFT columns. If None, it is computed from ```frame_shift_ms```. Defaults to None.\n\n        win_length (int, optional):\n            STFT window length in samples. If ```hop_length``` is None, it is computed from ```frame_length_ms``` instead. Defaults to None.\n\n        ref_level_db (int, optional):\n            reference dB level subtracted before normalization to discard background noise. In general, levels below 20 dB correspond to air noise. Defaults to None.\n\n        fft_size (int, optional):\n            FFT window size for STFT. Defaults to 1024.\n\n        power (int, optional):\n            Exponent value applied to the spectrogram before Griffin-Lim. Defaults to None.\n\n        preemphasis (float, optional):\n            Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.\n\n        signal_norm (bool, optional):\n            enable/disable signal normalization. Defaults to None.\n\n        symmetric_norm (bool, optional):\n            enable/disable symmetric normalization. 
If True, normalization is performed in the range [-k, k], otherwise in [0, k]. Defaults to None.\n\n        max_norm (float, optional):\n            ```k``` defining the normalization range. Defaults to None (treated as 1.0).\n\n        mel_fmin (int, optional):\n            minimum filter frequency for computing melspectrograms. Defaults to None (treated as 0).\n\n        mel_fmax (int, optional):\n            maximum filter frequency for computing melspectrograms. Defaults to None.\n\n        spec_gain (int, optional):\n            gain applied when converting amplitude to dB. Defaults to 20.\n\n        stft_pad_mode (str, optional):\n            Padding mode for STFT. Defaults to 'reflect'.\n\n        clip_norm (bool, optional):\n            enable/disable clipping of out-of-range values in the normalized spectrogram. Defaults to True.\n\n        griffin_lim_iters (int, optional):\n            Number of Griffin-Lim iterations. Defaults to None.\n\n        do_trim_silence (bool, optional):\n            enable/disable silence trimming when loading the audio signal. Defaults to False.\n\n        trim_db (int, optional):\n            dB threshold used for silence trimming. Defaults to 60.\n\n        do_sound_norm (bool, optional):\n            enable/disable normalization of the input waveform level. Defaults to False.\n\n        do_amp_to_db_linear (bool, optional):\n            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.\n\n        do_amp_to_db_mel (bool, optional):\n            enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.\n\n        stats_path (str, optional):\n            Path to the computed stats file. Defaults to None.\n\n        verbose (bool, optional):\n            enable/disable logging. 
Defaults to True.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        sample_rate=None,\n        resample=False,\n        num_mels=None,\n        log_func=\"np.log10\",\n        min_level_db=None,\n        frame_shift_ms=None,\n        frame_length_ms=None,\n        hop_length=None,\n        win_length=None,\n        ref_level_db=None,\n        fft_size=1024,\n        power=None,\n        preemphasis=0.0,\n        signal_norm=None,\n        symmetric_norm=None,\n        max_norm=None,\n        mel_fmin=None,\n        mel_fmax=None,\n        spec_gain=20,\n        stft_pad_mode=\"reflect\",\n        clip_norm=True,\n        griffin_lim_iters=None,\n        do_trim_silence=False,\n        trim_db=60,\n        do_sound_norm=False,\n        do_amp_to_db_linear=True,\n        do_amp_to_db_mel=True,\n        stats_path=None,\n        verbose=True,\n        **_,\n    ):\n\n        # set up class attributes\n        self.sample_rate = sample_rate\n        self.resample = resample\n        self.num_mels = num_mels\n        self.log_func = log_func\n        self.min_level_db = min_level_db or 0\n        self.frame_shift_ms = frame_shift_ms\n        self.frame_length_ms = frame_length_ms\n        self.ref_level_db = ref_level_db\n        self.fft_size = fft_size\n        self.power = power\n        self.preemphasis = preemphasis\n        self.griffin_lim_iters = griffin_lim_iters\n        self.signal_norm = signal_norm\n        self.symmetric_norm = symmetric_norm\n        self.mel_fmin = mel_fmin or 0\n        self.mel_fmax = mel_fmax\n        self.spec_gain = float(spec_gain)\n        self.stft_pad_mode = stft_pad_mode\n        self.max_norm = 1.0 if max_norm is None else float(max_norm)\n        self.clip_norm = clip_norm\n        self.do_trim_silence = do_trim_silence\n        self.trim_db = trim_db\n        self.do_sound_norm = do_sound_norm\n        self.do_amp_to_db_linear = do_amp_to_db_linear\n        self.do_amp_to_db_mel = do_amp_to_db_mel\n        self.stats_path = stats_path\n        # set the log base used for dB <-> amplitude conversion\n        if log_func == \"np.log\":\n            self.base = np.e\n        elif log_func == \"np.log10\":\n            self.base = 10\n        else:\n            raise ValueError(\" [!] unknown `log_func` value.\")\n        # setup stft parameters\n        if hop_length is None:\n            # compute stft parameters from given time values\n            self.hop_length, self.win_length = self._stft_parameters()\n        else:\n            # use stft parameters from config file\n            self.hop_length = hop_length\n            self.win_length = win_length\n        assert min_level_db != 0.0, \" [!] min_level_db is 0\"\n        assert self.win_length <= self.fft_size, \" [!] win_length cannot be larger than fft_size\"\n        members = vars(self)\n        if verbose:\n            print(\" > Setting up Audio Processor...\")\n            for key, value in members.items():\n                print(\" | > {}:{}\".format(key, value))\n        # create spectrogram utils\n        self.mel_basis = self._build_mel_basis()\n        self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis())\n        # setup scaler\n        if stats_path and signal_norm:\n            mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)\n            self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)\n            self.signal_norm = True\n            self.max_norm = None\n            self.clip_norm = None\n            self.symmetric_norm = None\n\n    ### setting up the parameters ###\n    def _build_mel_basis(\n        self,\n    ) -> np.ndarray:\n        \"\"\"Build melspectrogram basis.\n\n        Returns:\n            np.ndarray: melspectrogram basis.\n        \"\"\"\n        if self.mel_fmax is not None:\n            assert self.mel_fmax <= self.sample_rate // 2\n        return librosa.filters.mel(\n            sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, 
fmin=self.mel_fmin, fmax=self.mel_fmax\n        )\n\n    def _stft_parameters(\n        self,\n    ) -> Tuple[int, int]:\n        \"\"\"Compute the real STFT parameters from the time values.\n\n        Returns:\n            Tuple[int, int]: hop length and window length for STFT.\n        \"\"\"\n        factor = self.frame_length_ms / self.frame_shift_ms\n        assert (factor).is_integer(), \" [!] frame_shift_ms should divide frame_length_ms\"\n        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)\n        win_length = int(hop_length * factor)\n        return hop_length, win_length\n\n    ### normalization ###\n    def normalize(self, S: np.ndarray) -> np.ndarray:\n        \"\"\"Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`\n\n        Args:\n            S (np.ndarray): Spectrogram to normalize.\n\n        Raises:\n            RuntimeError: Mean and variance is computed from incompatible parameters.\n\n        Returns:\n            np.ndarray: Normalized spectrogram.\n        \"\"\"\n        # pylint: disable=no-else-return\n        S = S.copy()\n        if self.signal_norm:\n            # mean-var scaling\n            if hasattr(self, \"mel_scaler\"):\n                if S.shape[0] == self.num_mels:\n                    return self.mel_scaler.transform(S.T).T\n                elif S.shape[0] == self.fft_size / 2:\n                    return self.linear_scaler.transform(S.T).T\n                else:\n                    raise RuntimeError(\" [!] 
Mean-Var stats does not match the given feature dimensions.\")\n            # range normalization\n            S -= self.ref_level_db  # discard certain range of DB assuming it is air noise\n            S_norm = (S - self.min_level_db) / (-self.min_level_db)\n            if self.symmetric_norm:\n                S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm\n                if self.clip_norm:\n                    S_norm = np.clip(\n                        S_norm, -self.max_norm, self.max_norm  # pylint: disable=invalid-unary-operand-type\n                    )\n                return S_norm\n            else:\n                S_norm = self.max_norm * S_norm\n                if self.clip_norm:\n                    S_norm = np.clip(S_norm, 0, self.max_norm)\n                return S_norm\n        else:\n            return S\n\n    def denormalize(self, S: np.ndarray) -> np.ndarray:\n        \"\"\"Denormalize spectrogram values.\n\n        Args:\n            S (np.ndarray): Spectrogram to denormalize.\n\n        Raises:\n            RuntimeError: Mean and variance are incompatible.\n\n        Returns:\n            np.ndarray: Denormalized spectrogram.\n        \"\"\"\n        # pylint: disable=no-else-return\n        S_denorm = S.copy()\n        if self.signal_norm:\n            # mean-var scaling\n            if hasattr(self, \"mel_scaler\"):\n                if S_denorm.shape[0] == self.num_mels:\n                    return self.mel_scaler.inverse_transform(S_denorm.T).T\n                elif S_denorm.shape[0] == self.fft_size / 2:\n                    return self.linear_scaler.inverse_transform(S_denorm.T).T\n                else:\n                    raise RuntimeError(\" [!] 
Mean-Var stats does not match the given feature dimensions.\")\n            if self.symmetric_norm:\n                if self.clip_norm:\n                    S_denorm = np.clip(\n                        S_denorm, -self.max_norm, self.max_norm  # pylint: disable=invalid-unary-operand-type\n                    )\n                S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db\n                return S_denorm + self.ref_level_db\n            else:\n                if self.clip_norm:\n                    S_denorm = np.clip(S_denorm, 0, self.max_norm)\n                S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db\n                return S_denorm + self.ref_level_db\n        else:\n            return S_denorm\n\n    ### Mean-STD scaling ###\n    def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]:\n        \"\"\"Load mean and variance statistics from a `npy` file.\n\n        Args:\n            stats_path (str): Path to the `npy` file containing the statistics.\n\n        Returns:\n            Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to\n                compute them.\n        \"\"\"\n        stats = np.load(stats_path, allow_pickle=True).item()  # pylint: disable=unexpected-keyword-arg\n        mel_mean = stats[\"mel_mean\"]\n        mel_std = stats[\"mel_std\"]\n        linear_mean = stats[\"linear_mean\"]\n        linear_std = stats[\"linear_std\"]\n        stats_config = stats[\"audio_config\"]\n        # check all audio parameters used for computing stats\n        skip_parameters = [\"griffin_lim_iters\", \"stats_path\", \"do_trim_silence\", \"ref_level_db\", \"power\"]\n        for key in stats_config.keys():\n            if key in skip_parameters:\n                continue\n            if key not in [\"sample_rate\", \"trim_db\"]:\n                assert (\n                    stats_config[key] == 
self.__dict__[key]\n                ), f\" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}\"\n        return mel_mean, mel_std, linear_mean, linear_std, stats_config\n\n    # pylint: disable=attribute-defined-outside-init\n    def setup_scaler(\n        self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray\n    ) -> None:\n        \"\"\"Initialize scaler objects used in mean-std normalization.\n\n        Args:\n            mel_mean (np.ndarray): Mean for melspectrograms.\n            mel_std (np.ndarray): STD for melspectrograms.\n            linear_mean (np.ndarray): Mean for full scale spectrograms.\n            linear_std (np.ndarray): STD for full scale spectrograms.\n        \"\"\"\n        self.mel_scaler = StandardScaler()\n        self.mel_scaler.set_stats(mel_mean, mel_std)\n        self.linear_scaler = StandardScaler()\n        self.linear_scaler.set_stats(linear_mean, linear_std)\n\n    ### DB and AMP conversion ###\n    # pylint: disable=no-self-use\n    def _amp_to_db(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"Convert amplitude values to decibels.\n\n        Args:\n            x (np.ndarray): Amplitude spectrogram.\n\n        Returns:\n            np.ndarray: Decibels spectrogram.\n        \"\"\"\n        return self.spec_gain * _log(np.maximum(1e-5, x), self.base)\n\n    # pylint: disable=no-self-use\n    def _db_to_amp(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"Convert decibels spectrogram to amplitude spectrogram.\n\n        Args:\n            x (np.ndarray): Decibels spectrogram.\n\n        Returns:\n            np.ndarray: Amplitude spectrogram.\n        \"\"\"\n        return _exp(x / self.spec_gain, self.base)\n\n    ### Preemphasis ###\n    def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"Apply pre-emphasis to the audio signal. 
Useful to reduce the correlation between neighbouring signal values.\n\n        Args:\n            x (np.ndarray): Audio signal.\n\n        Raises:\n            RuntimeError: Preemphasis coeff is set to 0.\n\n        Returns:\n            np.ndarray: Decorrelated audio signal.\n        \"\"\"\n        if self.preemphasis == 0:\n            raise RuntimeError(\" [!] Preemphasis is set 0.0.\")\n        return scipy.signal.lfilter([1, -self.preemphasis], [1], x)\n\n    def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"Reverse pre-emphasis.\"\"\"\n        if self.preemphasis == 0:\n            raise RuntimeError(\" [!] Preemphasis is set 0.0.\")\n        return scipy.signal.lfilter([1], [1, -self.preemphasis], x)\n\n    ### SPECTROGRAMs ###\n    def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray:\n        \"\"\"Project a full scale spectrogram to a melspectrogram.\n\n        Args:\n            spectrogram (np.ndarray): Full scale spectrogram.\n\n        Returns:\n            np.ndarray: Melspectrogram\n        \"\"\"\n        return np.dot(self.mel_basis, spectrogram)\n\n    def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:\n        \"\"\"Convert a melspectrogram to full scale spectrogram.\"\"\"\n        return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))\n\n    def spectrogram(self, y: np.ndarray) -> np.ndarray:\n        \"\"\"Compute a spectrogram from a waveform.\n\n        Args:\n            y (np.ndarray): Waveform.\n\n        Returns:\n            np.ndarray: Spectrogram.\n        \"\"\"\n        if self.preemphasis != 0:\n            D = self._stft(self.apply_preemphasis(y))\n        else:\n            D = self._stft(y)\n        if self.do_amp_to_db_linear:\n            S = self._amp_to_db(np.abs(D))\n        else:\n            S = np.abs(D)\n        return self.normalize(S).astype(np.float32)\n\n    def melspectrogram(self, y: np.ndarray) -> np.ndarray:\n        \"\"\"Compute a melspectrogram from a 
waveform.\"\"\"\n        if self.preemphasis != 0:\n            D = self._stft(self.apply_preemphasis(y))\n        else:\n            D = self._stft(y)\n        if self.do_amp_to_db_mel:\n            S = self._amp_to_db(self._linear_to_mel(np.abs(D)))\n        else:\n            S = self._linear_to_mel(np.abs(D))\n        return self.normalize(S).astype(np.float32)\n\n    def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:\n        \"\"\"Convert a spectrogram to a waveform using the Griffin-Lim vocoder.\"\"\"\n        S = self.denormalize(spectrogram)\n        S = self._db_to_amp(S)\n        # Reconstruct phase\n        if self.preemphasis != 0:\n            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))\n        return self._griffin_lim(S ** self.power)\n\n    def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:\n        \"\"\"Convert a melspectrogram to a waveform using the Griffin-Lim vocoder.\"\"\"\n        D = self.denormalize(mel_spectrogram)\n        S = self._db_to_amp(D)\n        S = self._mel_to_linear(S)  # Convert back to linear\n        if self.preemphasis != 0:\n            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))\n        return self._griffin_lim(S ** self.power)\n\n    def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:\n        \"\"\"Convert a full-scale linear spectrogram output of a network to a melspectrogram.\n\n        Args:\n            linear_spec (np.ndarray): Normalized full-scale linear spectrogram.\n\n        Returns:\n            np.ndarray: Normalized melspectrogram.\n        \"\"\"\n        S = self.denormalize(linear_spec)\n        S = self._db_to_amp(S)\n        S = self._linear_to_mel(np.abs(S))\n        S = self._amp_to_db(S)\n        mel = self.normalize(S)\n        return mel\n\n    ### STFT and ISTFT ###\n    def _stft(self, y: np.ndarray) -> np.ndarray:\n        \"\"\"Librosa STFT wrapper.\n\n        Args:\n            y (np.ndarray): 
Audio signal.\n\n        Returns:\n            np.ndarray: Complex number array.\n        \"\"\"\n        return librosa.stft(\n            y=y,\n            n_fft=self.fft_size,\n            hop_length=self.hop_length,\n            win_length=self.win_length,\n            pad_mode=self.stft_pad_mode,\n            window=\"hann\",\n            center=True,\n        )\n\n    def _istft(self, y: np.ndarray) -> np.ndarray:\n        \"\"\"Librosa iSTFT wrapper.\"\"\"\n        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)\n\n    def _griffin_lim(self, S):\n        \"\"\"Reconstruct a waveform from a magnitude spectrogram using the Griffin-Lim algorithm.\"\"\"\n        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))\n        # `np.complex` was removed in NumPy 1.24; use the explicit complex128 dtype.\n        S_complex = np.abs(S).astype(np.complex128)\n        y = self._istft(S_complex * angles)\n        if not np.isfinite(y).all():\n            print(\" [!] Waveform is not finite everywhere. Skipping the GL.\")\n            return np.array([0.0])\n        for _ in range(self.griffin_lim_iters):\n            angles = np.exp(1j * np.angle(self._stft(y)))\n            y = self._istft(S_complex * angles)\n        return y\n\n    def compute_stft_paddings(self, x, pad_sides=1):\n        \"\"\"Compute paddings used by Librosa's STFT. 
Compute right padding (final frame) or both sides padding\n        (first and final frames)\"\"\"\n        assert pad_sides in (1, 2)\n        pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]\n        if pad_sides == 1:\n            return 0, pad\n        return pad // 2, pad // 2 + pad % 2\n\n    def compute_f0(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.\n\n        Args:\n            x (np.ndarray): Waveform.\n\n        Returns:\n            np.ndarray: Pitch.\n\n        Examples:\n            >>> WAV_FILE = filename = librosa.util.example_audio_file()\n            >>> from TTS.config import BaseAudioConfig\n            >>> from TTS.utils.audio import AudioProcessor\n            >>> conf = BaseAudioConfig(mel_fmax=8000)\n            >>> ap = AudioProcessor(**conf)\n            >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]\n            >>> pitch = ap.compute_f0(wav)\n        \"\"\"\n        f0, t = pw.dio(\n            x.astype(np.double),\n            fs=self.sample_rate,\n            f0_ceil=self.mel_fmax,\n            frame_period=1000 * self.hop_length / self.sample_rate,\n        )\n        f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)\n        # pad = int((self.win_length / self.hop_length) / 2)\n        # f0 = [0.0] * pad + f0 + [0.0] * pad\n        # f0 = np.pad(f0, (pad, pad), mode=\"constant\", constant_values=0)\n        # f0 = np.array(f0, dtype=np.float32)\n\n        # f01, _, _ = librosa.pyin(\n        #     x,\n        #     fmin=65 if self.mel_fmin == 0 else self.mel_fmin,\n        #     fmax=self.mel_fmax,\n        #     frame_length=self.win_length,\n        #     sr=self.sample_rate,\n        #     fill_na=0.0,\n        # )\n\n        # spec = self.melspectrogram(x)\n        return f0\n\n    ### Audio Processing ###\n    def find_endpoint(self, wav: np.ndarray, threshold_db=-40, 
min_silence_sec=0.8) -> int:\n        \"\"\"Find the last point without silence at the end of an audio signal.\n\n        Args:\n            wav (np.ndarray): Audio signal.\n            threshold_db (int, optional): Silence threshold in decibels. Defaults to -40.\n            min_silence_sec (float, optional): Ignore silences shorter than this, in seconds. Defaults to 0.8.\n\n        Returns:\n            int: Last point without silence.\n        \"\"\"\n        window_length = int(self.sample_rate * min_silence_sec)\n        hop_length = int(window_length / 4)\n        threshold = self._db_to_amp(threshold_db)\n        for x in range(hop_length, len(wav) - window_length, hop_length):\n            if np.max(wav[x : x + window_length]) < threshold:\n                return x + hop_length\n        return len(wav)\n\n    def trim_silence(self, wav):\n        \"\"\"Cut a 0.01 sec margin from both ends, then trim leading and trailing silence using `trim_db` as the threshold.\"\"\"\n        margin = int(self.sample_rate * 0.01)\n        wav = wav[margin:-margin]\n        return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[\n            0\n        ]\n\n    @staticmethod\n    def sound_norm(x: np.ndarray) -> np.ndarray:\n        \"\"\"Normalize the volume of an audio signal.\n\n        Args:\n            x (np.ndarray): Raw waveform.\n\n        Returns:\n            np.ndarray: Volume normalized waveform.\n        \"\"\"\n        return x / abs(x).max() * 0.95\n\n    ### save and load ###\n    def load_wav(self, filename: str, sr: int = None) -> np.ndarray:\n        \"\"\"Read a wav file and optionally resample it, trim silence, and normalize its volume.\n\n        Args:\n            filename (str): Path to the wav file.\n            sr (int, optional): Sampling rate for resampling. 
Defaults to None.\n\n        Returns:\n            np.ndarray: Loaded waveform.\n        \"\"\"\n        if self.resample:\n            x, sr = librosa.load(filename, sr=self.sample_rate)\n        elif sr is None:\n            x, sr = sf.read(filename)\n            assert self.sample_rate == sr, \"%s vs %s\" % (self.sample_rate, sr)\n        else:\n            x, sr = librosa.load(filename, sr=sr)\n        if self.do_trim_silence:\n            try:\n                x = self.trim_silence(x)\n            except ValueError:\n                print(f\" [!] File cannot be trimmed for silence - {filename}\")\n        if self.do_sound_norm:\n            x = self.sound_norm(x)\n        return x\n\n    def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:\n        \"\"\"Save a waveform to a 16-bit PCM wav file using SciPy.\n\n        Args:\n            wav (np.ndarray): Waveform to save.\n            path (str): Path to an output file.\n            sr (int, optional): Sampling rate used for saving to the file. 
Defaults to None.\n        \"\"\"\n        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))\n        scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16))\n\n    @staticmethod\n    def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:\n        mu = 2 ** qc - 1\n        # wav_abs = np.minimum(np.abs(wav), 1.0)\n        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)\n        # Quantize signal to the specified number of levels.\n        signal = (signal + 1) / 2 * mu + 0.5\n        return np.floor(\n            signal,\n        )\n\n    @staticmethod\n    def mulaw_decode(wav, qc):\n        \"\"\"Recovers waveform from quantized values.\"\"\"\n        mu = 2 ** qc - 1\n        x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)\n        return x\n\n    @staticmethod\n    def encode_16bits(x):\n        return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)\n\n    @staticmethod\n    def quantize(x: np.ndarray, bits: int) -> np.ndarray:\n        \"\"\"Quantize a waveform to a given number of bits.\n\n        Args:\n            x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.\n            bits (int): Number of quantization bits.\n\n        Returns:\n            np.ndarray: Quantized waveform.\n        \"\"\"\n        return (x + 1.0) * (2 ** bits - 1) / 2\n\n    @staticmethod\n    def dequantize(x, bits):\n        \"\"\"Dequantize a waveform from the given number of bits.\"\"\"\n        return 2 * x / (2 ** bits - 1) - 1\n\n\ndef _log(x, base):\n    if base == 10:\n        return np.log10(x)\n    return np.log(x)\n\n\ndef _exp(x, base):\n    if base == 10:\n        return np.power(10, x)\n    return np.exp(x)\n"
  },
  {
    "path": "speaker/utils/coqpit.py",
    "content": "import argparse\nimport functools\nimport json\nimport operator\nimport os\nfrom collections.abc import MutableMapping\nfrom dataclasses import MISSING as _MISSING\nfrom dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, get_type_hints\n\nT = TypeVar(\"T\")\nMISSING: Any = \"???\"\n\n\nclass _NoDefault(Generic[T]):\n    pass\n\n\nNoDefaultVar = Union[_NoDefault[T], T]\nno_default: NoDefaultVar = _NoDefault()\n\n\ndef is_primitive_type(arg_type: Any) -> bool:\n    \"\"\"Check if the input type is one of `int, float, str, bool`.\n\n    Args:\n        arg_type (typing.Any): input type to check.\n\n    Returns:\n        bool: True if input type is one of `int, float, str, bool`.\n    \"\"\"\n    try:\n        return isinstance(arg_type(), (int, float, str, bool))\n    except (AttributeError, TypeError):\n        return False\n\n\ndef is_list(arg_type: Any) -> bool:\n    \"\"\"Check if the input type is `list`\n\n    Args:\n        arg_type (typing.Any): input type.\n\n    Returns:\n        bool: True if input type is `list`\n    \"\"\"\n    try:\n        return arg_type is list or arg_type is List or arg_type.__origin__ is list or arg_type.__origin__ is List\n    except AttributeError:\n        return False\n\n\ndef is_dict(arg_type: Any) -> bool:\n    \"\"\"Check if the input type is `dict`\n\n    Args:\n        arg_type (typing.Any): input type.\n\n    Returns:\n        bool: True if input type is `dict`\n    \"\"\"\n    try:\n        return arg_type is dict or arg_type is Dict or arg_type.__origin__ is dict\n    except AttributeError:\n        return False\n\n\ndef is_union(arg_type: Any) -> bool:\n    \"\"\"Check if the input type is `Union`.\n\n    Args:\n        arg_type (typing.Any): input type.\n\n    Returns:\n        bool: True if input type is `Union`\n    \"\"\"\n    try:\n        return 
safe_issubclass(arg_type.__origin__, Union)\n    except AttributeError:\n        return False\n\n\ndef safe_issubclass(cls, classinfo) -> bool:\n    \"\"\"Check if the input type is a subclass of the given class.\n\n    Args:\n        cls (type): input type.\n        classinfo (type): parent class.\n\n    Returns:\n        bool: True if the input type is a subclass of the given class\n    \"\"\"\n    try:\n        r = issubclass(cls, classinfo)\n    except Exception:  # pylint: disable=broad-except\n        return cls is classinfo\n    else:\n        return r\n\n\ndef _coqpit_json_default(obj: Any) -> Any:\n    if isinstance(obj, Path):\n        return str(obj)\n    raise TypeError(f\"Can't encode object of type {type(obj).__name__}\")\n\n\ndef _default_value(x: Field):\n    \"\"\"Return the default value of the input Field.\n\n    Args:\n        x (Field): input Field.\n\n    Returns:\n        object: default value of the input Field.\n    \"\"\"\n    if x.default not in (MISSING, _MISSING):\n        return x.default\n    if x.default_factory not in (MISSING, _MISSING):\n        return x.default_factory()\n    return x.default\n\n\ndef _is_optional_field(field) -> bool:\n    \"\"\"Check if the input field is optional.\n\n    Args:\n        field (Field): input Field to check.\n\n    Returns:\n        bool: True if the input field is optional.\n    \"\"\"\n    # return isinstance(field.type, _GenericAlias) and type(None) in getattr(field.type, \"__args__\")\n    return type(None) in getattr(field.type, \"__args__\")\n\n\ndef my_get_type_hints(\n    cls,\n):\n    \"\"\"Custom `get_type_hints` dealing with https://github.com/python/typing/issues/737\n\n    Returns:\n        [dataclass]: dataclass to get the type hints of its fields.\n    \"\"\"\n    r_dict = {}\n    for base in cls.__class__.__bases__:\n        if base == object:\n            break\n        r_dict.update(my_get_type_hints(base))\n    r_dict.update(get_type_hints(cls))\n    return r_dict\n\n\ndef 
_serialize(x):\n    \"\"\"Pick the right serialization for the datatype of the given input.\n\n    Args:\n        x (object): input object.\n\n    Returns:\n        object: serialized object.\n    \"\"\"\n    if isinstance(x, Path):\n        return str(x)\n    if isinstance(x, dict):\n        return {k: _serialize(v) for k, v in x.items()}\n    if isinstance(x, list):\n        return [_serialize(xi) for xi in x]\n    if isinstance(x, Serializable) or issubclass(type(x), Serializable):\n        return x.serialize()\n    if isinstance(x, type) and issubclass(x, Serializable):\n        return x.serialize(x)\n    return x\n\n\ndef _deserialize_dict(x: Dict) -> Dict:\n    \"\"\"Deserialize dict.\n\n    Args:\n        x (Dict): value to deserialized.\n\n    Returns:\n        Dict: deserialized dictionary.\n    \"\"\"\n    out_dict = {}\n    for k, v in x.items():\n        if v is None:  # if {'key':None}\n            out_dict[k] = None\n        else:\n            out_dict[k] = _deserialize(v, type(v))\n    return out_dict\n\n\ndef _deserialize_list(x: List, field_type: Type) -> List:\n    \"\"\"Deserialize values for List typed fields.\n\n    Args:\n        x (List): value to be deserialized\n        field_type (Type): field type.\n\n    Raises:\n        ValueError: Coqpit does not support multi type-hinted lists.\n\n    Returns:\n        [List]: deserialized list.\n    \"\"\"\n    field_args = None\n    if hasattr(field_type, \"__args__\") and field_type.__args__:\n        field_args = field_type.__args__\n    elif hasattr(field_type, \"__parameters__\") and field_type.__parameters__:\n        # bandaid for python 3.6\n        field_args = field_type.__parameters__\n    if field_args:\n        if len(field_args) > 1:\n            raise ValueError(\" [!] 
Coqpit does not support multi-type hinted 'List'\")\n        field_arg = field_args[0]\n        # if field type is TypeVar set the current type by the value's type.\n        if isinstance(field_arg, TypeVar):\n            field_arg = type(x)\n        return [_deserialize(xi, field_arg) for xi in x]\n    return x\n\n\ndef _deserialize_union(x: Any, field_type: Type) -> Any:\n    \"\"\"Deserialize values for Union typed fields.\n\n    Args:\n        x (Any): value to be deserialized.\n        field_type (Type): field type.\n\n    Returns:\n        [Any]: deserialized value.\n    \"\"\"\n    for arg in field_type.__args__:\n        # stop after first matching type in Union\n        try:\n            x = _deserialize(x, arg)\n            break\n        except ValueError:\n            pass\n    return x\n\n\ndef _deserialize_primitive_types(x: Union[int, float, str, bool], field_type: Type) -> Union[int, float, str, bool]:\n    \"\"\"Deserialize Python primitive types (float, int, str, bool).\n    `inf` values are handled specially: they are kept as floats even for int fields, since int cannot represent infinity.\n\n    Args:\n        x (Union[int, float, str, bool]): value to be deserialized.\n        field_type (Type): field type.\n\n    Returns:\n        Union[int, float, str, bool]: deserialized value.\n    \"\"\"\n\n    if isinstance(x, (str, bool)):\n        return x\n    if isinstance(x, (int, float)):\n        if x == float(\"inf\") or x == float(\"-inf\"):\n            # if value type is inf return regardless.\n            return x\n        x = field_type(x)\n        return x\n    # TODO: Raise an error when x does not match the types.\n    return None\n\n\ndef _deserialize(x: Any, field_type: Any) -> Any:\n    \"\"\"Pick the right deserialization for the given object and the corresponding field type.\n\n    Args:\n        x (object): object to be deserialized.\n        field_type (type): expected type after deserialization.\n\n    Returns:\n        object: 
deserialized object\n\n    \"\"\"\n    # pylint: disable=too-many-return-statements\n    if is_dict(field_type):\n        return _deserialize_dict(x)\n    if is_list(field_type):\n        return _deserialize_list(x, field_type)\n    if is_union(field_type):\n        return _deserialize_union(x, field_type)\n    if issubclass(field_type, Serializable):\n        return field_type.deserialize_immutable(x)\n    if is_primitive_type(field_type):\n        return _deserialize_primitive_types(x, field_type)\n    raise ValueError(f\" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type.\")\n\n\n# Recursive setattr (supports dotted attr names)\ndef rsetattr(obj, attr, val):\n    def _setitem(obj, attr, val):\n        return operator.setitem(obj, int(attr), val)\n\n    pre, _, post = attr.rpartition(\".\")\n    setfunc = _setitem if post.isnumeric() else setattr\n\n    return setfunc(rgetattr(obj, pre) if pre else obj, post, val)\n\n\n# Recursive getattr (supports dotted attr names)\ndef rgetattr(obj, attr, *args):\n    def _getitem(obj, attr):\n        return operator.getitem(obj, int(attr), *args)\n\n    def _getattr(obj, attr):\n        getfunc = _getitem if attr.isnumeric() else getattr\n        return getfunc(obj, attr, *args)\n\n    return functools.reduce(_getattr, [obj] + attr.split(\".\"))\n\n\n# Recursive setitem (supports dotted attr names)\ndef rsetitem(obj, attr, val):\n    pre, _, post = attr.rpartition(\".\")\n    return operator.setitem(rgetitem(obj, pre) if pre else obj, post, val)\n\n\n# Recursive getitem (supports dotted attr names)\ndef rgetitem(obj, attr, *args):\n    def _getitem(obj, attr):\n        return operator.getitem(obj, int(attr) if attr.isnumeric() else attr, *args)\n\n    return functools.reduce(_getitem, [obj] + attr.split(\".\"))\n\n\n@dataclass\nclass Serializable:\n    \"\"\"Gives serialization ability to any inheriting dataclass.\"\"\"\n\n    def __post_init__(self):\n        self._validate_contracts()\n        for 
key, value in self.__dict__.items():\n            if value is no_default:\n                raise TypeError(f\"__init__ missing 1 required argument: '{key}'\")\n\n    def _validate_contracts(self):\n        dataclass_fields = fields(self)\n\n        for field in dataclass_fields:\n\n            value = getattr(self, field.name)\n\n            if value is None:\n                if not _is_optional_field(field):\n                    raise TypeError(f\"{field.name} is not optional\")\n\n            contract = field.metadata.get(\"contract\", None)\n\n            if contract is not None:\n                if value is not None and not contract(value):\n                    raise ValueError(f\"break the contract for {field.name}, {self.__class__.__name__}\")\n\n    def validate(self):\n        \"\"\"Validate that the object can be serialized and deserialized correctly.\"\"\"\n        self._validate_contracts()\n        if self != self.__class__.deserialize(  # pylint: disable=no-value-for-parameter\n            json.loads(json.dumps(self.serialize()))\n        ):\n            raise ValueError(\"could not be deserialized with same value\")\n\n    def to_dict(self) -> dict:\n        \"\"\"Transform serializable object to dict.\"\"\"\n        cls_fields = fields(self)\n        o = {}\n        for cls_field in cls_fields:\n            o[cls_field.name] = getattr(self, cls_field.name)\n        return o\n\n    def serialize(self) -> dict:\n        \"\"\"Serialize the object to a JSON-serializable representation.\"\"\"\n        if not is_dataclass(self):\n            raise TypeError(\"need to be decorated as dataclass\")\n\n        dataclass_fields = fields(self)\n\n        o = {}\n\n        for field in dataclass_fields:\n            value = getattr(self, field.name)\n            value = _serialize(value)\n            o[field.name] = value\n        return o\n\n    def deserialize(self, data: dict) -> \"Serializable\":\n        \"\"\"Parse input dictionary and deserialize its fields to a 
dataclass.\n\n        Returns:\n            self: deserialized `self`.\n        \"\"\"\n        if not isinstance(data, dict):\n            raise ValueError()\n        data = data.copy()\n        init_kwargs = {}\n        for field in fields(self):\n            # if field.name == 'dataset_config':\n            if field.name not in data:\n                if field.name in vars(self):\n                    init_kwargs[field.name] = vars(self)[field.name]\n                    continue\n                raise ValueError(f' [!] Missing required field \"{field.name}\"')\n            value = data.get(field.name, _default_value(field))\n            if value is None:\n                init_kwargs[field.name] = value\n                continue\n            if value == MISSING:\n                # instances have no `__name__`; use the class name so the intended ValueError is raised\n                raise ValueError(f\"deserialized with unknown value for {field.name} in {self.__class__.__name__}\")\n            value = _deserialize(value, field.type)\n            init_kwargs[field.name] = value\n        for k, v in init_kwargs.items():\n            setattr(self, k, v)\n        return self\n\n    @classmethod\n    def deserialize_immutable(cls, data: dict) -> \"Serializable\":\n        \"\"\"Parse input dictionary and deserialize its fields to a new dataclass instance.\n\n        Returns:\n            Newly created deserialized object.\n        \"\"\"\n        if not isinstance(data, dict):\n            raise ValueError()\n        data = data.copy()\n        init_kwargs = {}\n        for field in fields(cls):\n            # if field.name == 'dataset_config':\n            if field.name not in data:\n                if field.name in vars(cls):\n                    init_kwargs[field.name] = vars(cls)[field.name]\n                    continue\n                # if not in cls and the default value is not Missing use it\n                default_value = _default_value(field)\n                if default_value not in (MISSING, _MISSING):\n                    init_kwargs[field.name] = default_value\n                    
continue\n                raise ValueError(f' [!] Missing required field \"{field.name}\"')\n            value = data.get(field.name, _default_value(field))\n            if value is None:\n                init_kwargs[field.name] = value\n                continue\n            if value == MISSING:\n                raise ValueError(f\"Deserialized with unknown value for {field.name} in {cls.__name__}\")\n            value = _deserialize(value, field.type)\n            init_kwargs[field.name] = value\n        return cls(**init_kwargs)\n\n\n# ---------------------------------------------------------------------------- #\n#                        Argument Parsing from `argparse`                      #\n# ---------------------------------------------------------------------------- #\n\n\ndef _get_help(field):\n    try:\n        field_help = field.metadata[\"help\"]\n    except KeyError:\n        field_help = \"\"\n    return field_help\n\n\ndef _init_argparse(\n    parser,\n    field_name,\n    field_type,\n    field_default,\n    field_default_factory,\n    field_help,\n    arg_prefix=\"\",\n    help_prefix=\"\",\n    relaxed_parser=False,\n):\n    has_default = False\n    default = None\n    if field_default:\n        has_default = True\n        default = field_default\n    elif field_default_factory not in (None, _MISSING):\n        has_default = True\n        default = field_default_factory()\n\n    if not has_default and not is_primitive_type(field_type) and not is_list(field_type):\n        # aggregate types (fields with a Coqpit subclass as type) are not supported without None\n        return parser\n    arg_prefix = field_name if arg_prefix == \"\" else f\"{arg_prefix}.{field_name}\"\n    help_prefix = field_help if help_prefix == \"\" else f\"{help_prefix} - {field_help}\"\n    if is_dict(field_type):  # pylint: disable=no-else-raise\n        # NOTE: accept any string in json format as input to dict field.\n        parser.add_argument(\n            
f\"--{arg_prefix}\",\n            dest=arg_prefix,\n            default=json.dumps(field_default) if field_default else None,\n            type=json.loads,\n        )\n    elif is_list(field_type):\n        # TODO: We need a more clear help msg for lists.\n        if hasattr(field_type, \"__args__\"):  # if the list is hinted\n            if len(field_type.__args__) > 1 and not relaxed_parser:\n                raise ValueError(\" [!] Coqpit does not support multi-type hinted 'List'\")\n            list_field_type = field_type.__args__[0]\n        else:\n            raise ValueError(\" [!] Coqpit does not support un-hinted 'List'\")\n\n        # TODO: handle list of lists\n        if is_list(list_field_type) and relaxed_parser:\n            return parser\n\n        if not has_default or field_default_factory is list:\n            if not is_primitive_type(list_field_type) and not relaxed_parser:\n                raise NotImplementedError(\" [!] Empty list with non primitive inner type is currently not supported.\")\n\n            # If the list's default value is None, the user can specify the entire list by passing multiple parameters\n            parser.add_argument(\n                f\"--{arg_prefix}\",\n                nargs=\"*\",\n                type=list_field_type,\n                help=f\"Coqpit Field: {help_prefix}\",\n            )\n        else:\n            # If a default value is defined, just enable editing the values from argparse\n            # TODO: allow inserting a new value/obj to the end of the list.\n            for idx, fv in enumerate(default):\n                parser = _init_argparse(\n                    parser,\n                    str(idx),\n                    list_field_type,\n                    fv,\n                    field_default_factory,\n                    field_help=\"\",\n                    help_prefix=f\"{help_prefix} - \",\n                    arg_prefix=f\"{arg_prefix}\",\n                    
relaxed_parser=relaxed_parser,\n                )\n    elif is_union(field_type):\n        # TODO: currently I don't know how to handle Union type on argparse\n        if not relaxed_parser:\n            raise NotImplementedError(\n                \" [!] Parsing `Union` field from argparse is not yet implemented. Please create an issue.\"\n            )\n    elif issubclass(field_type, Serializable):\n        return default.init_argparse(\n            parser, arg_prefix=arg_prefix, help_prefix=help_prefix, relaxed_parser=relaxed_parser\n        )\n    elif isinstance(field_type(), bool):\n\n        def parse_bool(x):\n            if x not in (\"true\", \"false\"):\n                raise ValueError(f' [!] Value for boolean field must be either \"true\" or \"false\". Got \"{x}\".')\n            return x == \"true\"\n\n        parser.add_argument(\n            f\"--{arg_prefix}\",\n            type=parse_bool,\n            default=field_default,\n            help=f\"Coqpit Field: {help_prefix}\",\n            metavar=\"true/false\",\n        )\n    elif is_primitive_type(field_type):\n        parser.add_argument(\n            f\"--{arg_prefix}\",\n            default=field_default,\n            type=field_type,\n            help=f\"Coqpit Field: {help_prefix}\",\n        )\n    else:\n        if not relaxed_parser:\n            raise NotImplementedError(f\" [!] '{field_type}' is not supported by arg_parser. 
Please file a bug report.\")\n    return parser\n\n\n# ---------------------------------------------------------------------------- #\n#                               Main Coqpit Class                              #\n# ---------------------------------------------------------------------------- #\n\n\n@dataclass\nclass Coqpit(Serializable, MutableMapping):\n    \"\"\"Base class to be inherited by all Coqpit dataclasses.\n    It implements the Python `dict` interface and provides a `dict`-compatible API.\n    It also enables serializing/deserializing a dataclass to/from a JSON file, plus some semi-dynamic type and value checks.\n    Note that it does not support all datatypes and is likely to fail in some cases.\n    \"\"\"\n\n    _initialized = False\n\n    def _is_initialized(self):\n        \"\"\"Check if Coqpit is initialized. Useful to prevent running some aux functions\n        at the initialization when no attribute has been defined.\"\"\"\n        return \"_initialized\" in vars(self) and self._initialized\n\n    def __post_init__(self):\n        self._initialized = True\n        try:\n            self.check_values()\n        except AttributeError:\n            pass\n\n    ## `dict` API functions\n\n    def __iter__(self):\n        return iter(asdict(self))\n\n    def __len__(self):\n        return len(fields(self))\n\n    def __setitem__(self, arg: str, value: Any):\n        setattr(self, arg, value)\n\n    def __getitem__(self, arg: str):\n        \"\"\"Access class attributes with ``[arg]``.\"\"\"\n        return self.__dict__[arg]\n\n    def __delitem__(self, arg: str):\n        delattr(self, arg)\n\n    def _keytransform(self, key):  # pylint: disable=no-self-use\n        return key\n\n    ## end `dict` API functions\n\n    def __getattribute__(self, arg: str):  # pylint: disable=no-self-use\n        \"\"\"Check if the mandatory field is defined when accessing it.\"\"\"\n        value = super().__getattribute__(arg)\n        if isinstance(value, str) and 
value == \"???\":\n            raise AttributeError(f\" [!] MISSING field {arg} must be defined.\")\n        return value\n\n    def __contains__(self, arg: str):\n        return arg in self.to_dict()\n\n    def get(self, key: str, default: Any = None):\n        if self.has(key):\n            return asdict(self)[key]\n        return default\n\n    def items(self):\n        return asdict(self).items()\n\n    def merge(self, coqpits: Union[\"Coqpit\", List[\"Coqpit\"]]):\n        \"\"\"Merge a coqpit instance or a list of coqpit instances to self.\n        Note that it does not pass the fields and overrides attributes with\n        the last Coqpit instance in the given List.\n        TODO: find a way to merge instances with all the class internals.\n\n        Args:\n            coqpits (Union[Coqpit, List[Coqpit]]): coqpit instance or list of instances to be merged.\n        \"\"\"\n\n        def _merge(coqpit):\n            self.__dict__.update(coqpit.__dict__)\n            self.__annotations__.update(coqpit.__annotations__)\n            self.__dataclass_fields__.update(coqpit.__dataclass_fields__)\n\n        if isinstance(coqpits, list):\n            for coqpit in coqpits:\n                _merge(coqpit)\n        else:\n            _merge(coqpits)\n\n    def check_values(self):\n        pass\n\n    def has(self, arg: str) -> bool:\n        return arg in vars(self)\n\n    def copy(self):\n        return replace(self)\n\n    def update(self, new: dict, allow_new=False) -> None:\n        \"\"\"Update Coqpit fields by the input ```dict```.\n\n        Args:\n            new (dict): dictionary with new values.\n            allow_new (bool, optional): allow new fields to add. 
Defaults to False.\n        \"\"\"\n        for key, value in new.items():\n            if allow_new:\n                setattr(self, key, value)\n            else:\n                if hasattr(self, key):\n                    setattr(self, key, value)\n                else:\n                    raise KeyError(f\" [!] No key - {key}\")\n\n    def pprint(self) -> None:\n        \"\"\"Print Coqpit fields in a format.\"\"\"\n        pprint(asdict(self))\n\n    def to_dict(self) -> dict:\n        # return asdict(self)\n        return self.serialize()\n\n    def from_dict(self, data: dict) -> None:\n        self = self.deserialize(data)  # pylint: disable=self-cls-assignment\n\n    @classmethod\n    def new_from_dict(cls: Serializable, data: dict) -> \"Coqpit\":\n        return cls.deserialize_immutable(data)\n\n    def to_json(self) -> str:\n        \"\"\"Returns a JSON string representation.\"\"\"\n        return json.dumps(asdict(self), indent=4, default=_coqpit_json_default)\n\n    def save_json(self, file_name: str) -> None:\n        \"\"\"Save Coqpit to a json file.\n\n        Args:\n            file_name (str): path to the output json file.\n        \"\"\"\n        with open(file_name, \"w\", encoding=\"utf8\") as f:\n            json.dump(asdict(self), f, indent=4)\n\n    def load_json(self, file_name: str) -> None:\n        \"\"\"Load a json file and update matching config fields with type checking.\n        Non-matching parameters in the json file are ignored.\n\n        Args:\n            file_name (str): path to the json file.\n\n        Returns:\n            Coqpit: new Coqpit with updated config fields.\n        \"\"\"\n        with open(file_name, \"r\", encoding=\"utf8\") as f:\n            input_str = f.read()\n            dump_dict = json.loads(input_str)\n        # TODO: this looks stupid 💆\n        self = self.deserialize(dump_dict)  # pylint: disable=self-cls-assignment\n        self.check_values()\n\n    @classmethod\n    def init_from_argparse(\n    
    cls, args: Optional[Union[argparse.Namespace, List[str]]] = None, arg_prefix: str = \"coqpit\"\n    ) -> \"Coqpit\":\n        \"\"\"Create a new Coqpit instance from argparse input.\n\n        Args:\n            args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```.\n            arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed.\n        \"\"\"\n        if not args:\n            # If args was not specified, parse from sys.argv\n            parser = cls.init_argparse(cls, arg_prefix=arg_prefix)\n            args = parser.parse_args()  # pylint: disable=E1120, E1111\n        if isinstance(args, list):\n            # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace\n            parser = cls.init_argparse(cls, arg_prefix=arg_prefix)\n            args = parser.parse_args(args)  # pylint: disable=E1120, E1111\n\n        # Handle list and object attributes with defaults, which can be modified\n        # directly (eg. 
--coqpit.list.0.val_a 1), by constructing real objects\n        # from defaults and passing those to `cls.__init__`\n        args_with_lists_processed = {}\n        class_fields = fields(cls)\n        for field in class_fields:\n            has_default = False\n            default = None\n            field_default = field.default if field.default is not _MISSING else None\n            field_default_factory = field.default_factory if field.default_factory is not _MISSING else None\n            if field_default:\n                has_default = True\n                default = field_default\n            elif field_default_factory:\n                has_default = True\n                default = field_default_factory()\n\n            if has_default and (not is_primitive_type(field.type) or is_list(field.type)):\n                args_with_lists_processed[field.name] = default\n\n        args_dict = vars(args)\n        for k, v in args_dict.items():\n            # Remove argparse prefix (eg. \"--coqpit.\" if present)\n            if k.startswith(f\"{arg_prefix}.\"):\n                k = k[len(f\"{arg_prefix}.\") :]\n\n            rsetitem(args_with_lists_processed, k, v)\n\n        return cls(**args_with_lists_processed)\n\n    def parse_args(\n        self, args: Optional[Union[argparse.Namespace, List[str]]] = None, arg_prefix: str = \"coqpit\"\n    ) -> None:\n        \"\"\"Update config values from argparse arguments with some meta-programming ✨.\n\n        Args:\n            args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```.\n            arg_prefix: prefix to add to CLI parameters. 
Gets forwarded to ```init_argparse``` when ```args``` is not passed.\n        \"\"\"\n        if not args:\n            # If args was not specified, parse from sys.argv\n            parser = self.init_argparse(arg_prefix=arg_prefix)\n            args = parser.parse_args()\n        if isinstance(args, list):\n            # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace\n            parser = self.init_argparse(arg_prefix=arg_prefix)\n            args = parser.parse_args(args)\n\n        args_dict = vars(args)\n\n        for k, v in args_dict.items():\n            if k.startswith(f\"{arg_prefix}.\"):\n                k = k[len(f\"{arg_prefix}.\") :]\n            try:\n                rgetattr(self, k)\n            except (TypeError, AttributeError) as e:\n                raise Exception(f\" [!] '{k}' does not exist and cannot be overridden from argparse.\") from e\n\n            rsetattr(self, k, v)\n\n        self.check_values()\n\n    def parse_known_args(\n        self,\n        args: Optional[Union[argparse.Namespace, List[str]]] = None,\n        arg_prefix: str = \"coqpit\",\n        relaxed_parser=False,\n    ) -> List[str]:\n        \"\"\"Update config values from argparse arguments. Ignore unknown arguments.\n           This is analogous to argparse.ArgumentParser.parse_known_args (vs parse_args).\n\n        Args:\n            args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```.\n            arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed.\n            relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. 
Defaults to False.\n\n        Returns:\n            List of unknown parameters.\n        \"\"\"\n        unknown = []  # stays empty when an already parsed Namespace is passed in\n        if not args:\n            # If args was not specified, parse from sys.argv\n            parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser)\n            args, unknown = parser.parse_known_args()\n        if isinstance(args, list):\n            # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace\n            parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser)\n            args, unknown = parser.parse_known_args(args)\n\n        self.parse_args(args)\n        return unknown\n\n    def init_argparse(\n        self,\n        parser: Optional[argparse.ArgumentParser] = None,\n        arg_prefix=\"coqpit\",\n        help_prefix=\"\",\n        relaxed_parser=False,\n    ) -> argparse.ArgumentParser:\n        \"\"\"Pass Coqpit fields as argparse arguments. This allows editing values through the command line.\n\n        Args:\n            parser (argparse.ArgumentParser, optional): argparse.ArgumentParser instance. If unspecified a new one will be created.\n            arg_prefix (str, optional): Prefix to be used for the argument name. Defaults to 'coqpit'.\n            help_prefix (str, optional): Prefix to be used for the argument description. Defaults to ''.\n            relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. 
Defaults to False.\n\n        Returns:\n            argparse.ArgumentParser: parser instance with the new arguments.\n        \"\"\"\n        if not parser:\n            parser = argparse.ArgumentParser()\n        class_fields = fields(self)\n        for field in class_fields:\n            if field.name in vars(self):\n                # use the current value of the field\n                # prevent dropping the current value\n                field_default = vars(self)[field.name]\n            else:\n                # use the default value of the field\n                field_default = field.default if field.default is not _MISSING else None\n            field_type = field.type\n            field_default_factory = field.default_factory\n            field_help = _get_help(field)\n            _init_argparse(\n                parser,\n                field.name,\n                field_type,\n                field_default,\n                field_default_factory,\n                field_help,\n                arg_prefix,\n                help_prefix,\n                relaxed_parser,\n            )\n        return parser\n\n\ndef check_argument(\n    name,\n    c,\n    is_path: bool = False,\n    prerequest: str = None,\n    enum_list: list = None,\n    max_val: float = None,\n    min_val: float = None,\n    restricted: bool = False,\n    alternative: str = None,\n    allow_none: bool = True,\n) -> None:\n    \"\"\"Simple type and value checking for Coqpit.\n    It is intended to be used in ```__post_init__()``` of config dataclasses.\n\n    Args:\n        name (str): name of the field to be checked.\n        c (dict): config dictionary.\n        is_path (bool, optional): if ```True```, check that the path exists. Defaults to False.\n        prerequest (list or str, optional): a field name or a list of field names that must be defined for the target field.\n            Defaults to None.\n        enum_list (list, optional): list of possible values for the target field. 
Defaults to None.\n        max_val (float, optional): maximum possible value for the target field. Defaults to None.\n        min_val (float, optional): minimum possible value for the target field. Defaults to None.\n        restricted (bool, optional): if ```True``` the target field has to be defined. Defaults to False.\n        alternative (str, optional): a field name superseding the target field. Defaults to None.\n        allow_none (bool, optional): if ```True``` allow the target field to be ```None```. Defaults to True.\n\n\n    Example:\n        >>> c = {'num_mels': 80, 'fft_size': 1024}\n        >>> check_argument('num_mels', c, restricted=True, min_val=10, max_val=2056)\n        >>> check_argument('fft_size', c, restricted=True, min_val=128, max_val=4058)\n    \"\"\"\n    # check if None allowed\n    if allow_none and c[name] is None:\n        return\n    if not allow_none:\n        assert c[name] is not None, f\" [!] None value is not allowed for {name}.\"\n    # if restricted, check that the field exists\n    if isinstance(restricted, bool) and restricted:\n        assert name in c.keys(), f\" [!] {name} not defined in config.json\"\n    # check that all prerequisite fields are defined\n    if isinstance(prerequest, list):\n        assert all(\n            f in c.keys() for f in prerequest\n        ), f\" [!] prerequisite fields {prerequest} for {name} are not defined.\"\n    else:\n        assert (\n            prerequest is None or prerequest in c.keys()\n        ), f\" [!] prerequisite fields {prerequest} for {name} are not defined.\"\n    # check if the path exists\n    if is_path:\n        assert os.path.exists(c[name]), f' [!] path for {name} (\"{c[name]}\") does not exist.'\n    # skip the rest if the alternative field is defined.\n    if alternative in c.keys() and c[alternative] is not None:\n        return\n    # check value constraints\n    if name in c.keys():\n        if max_val is not None:\n            assert c[name] <= max_val, f\" [!] 
{name} is larger than max value {max_val}\"\n        if min_val is not None:\n            assert c[name] >= min_val, f\" [!] {name} is smaller than min value {min_val}\"\n        if enum_list is not None:\n            assert c[name].lower() in enum_list, f\" [!] {name} is not a valid value\"\n"
  },
  {
    "path": "speaker/utils/io.py",
    "content": "import datetime\nimport json\nimport os\nimport pickle as pickle_tts\nimport shutil\nfrom typing import Any, Callable, Dict, Union\n\nimport fsspec\nimport torch\nfrom .coqpit import Coqpit\n\n\nclass RenamingUnpickler(pickle_tts.Unpickler):\n    \"\"\"Overload default pickler to solve module renaming problem\"\"\"\n\n    def find_class(self, module, name):\n        return super().find_class(module.replace(\"mozilla_voice_tts\", \"TTS\"), name)\n\n\nclass AttrDict(dict):\n    \"\"\"A custom dict which converts dict keys\n    to class attributes\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.__dict__ = self\n\n\ndef copy_model_files(config: Coqpit, out_path, new_fields):\n    \"\"\"Copy config.json and other model files to training folder and add\n    new fields.\n\n    Args:\n        config (Coqpit): Coqpit config defining the training run.\n        out_path (str): output path to copy the file.\n        new_fields (dict): new fileds to be added or edited\n            in the config file.\n    \"\"\"\n    copy_config_path = os.path.join(out_path, \"config.json\")\n    # add extra information fields\n    config.update(new_fields, allow_new=True)\n    # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.\n    with fsspec.open(copy_config_path, \"w\", encoding=\"utf8\") as f:\n        json.dump(config.to_dict(), f, indent=4)\n\n    # copy model stats file if available\n    if config.audio.stats_path is not None:\n        copy_stats_path = os.path.join(out_path, \"scale_stats.npy\")\n        filesystem = fsspec.get_mapper(copy_stats_path).fs\n        if not filesystem.exists(copy_stats_path):\n            with fsspec.open(config.audio.stats_path, \"rb\") as source_file:\n                with fsspec.open(copy_stats_path, \"wb\") as target_file:\n                    shutil.copyfileobj(source_file, target_file)\n\n\ndef load_fsspec(\n    path: str,\n    map_location: Union[str, 
Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,\n    **kwargs,\n) -> Any:\n    \"\"\"Like torch.load but can load from other locations (e.g. s3:// , gs://).\n\n    Args:\n        path: Any path or url supported by fsspec.\n        map_location: torch.device or str.\n        **kwargs: Keyword arguments forwarded to torch.load.\n\n    Returns:\n        Object stored in path.\n    \"\"\"\n    with fsspec.open(path, \"rb\") as f:\n        return torch.load(f, map_location=map_location, **kwargs)\n\n\ndef load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin\n    try:\n        state = load_fsspec(checkpoint_path, map_location=torch.device(\"cpu\"))\n    except ModuleNotFoundError:\n        pickle_tts.Unpickler = RenamingUnpickler\n        state = load_fsspec(checkpoint_path, map_location=torch.device(\"cpu\"), pickle_module=pickle_tts)\n    model.load_state_dict(state[\"model\"])\n    if use_cuda:\n        model.cuda()\n    if eval:\n        model.eval()\n    return model, state\n\n\ndef save_fsspec(state: Any, path: str, **kwargs):\n    \"\"\"Like torch.save but can save to other locations (e.g. 
s3:// , gs://).\n\n    Args:\n        state: State object to save\n        path: Any path or url supported by fsspec.\n        **kwargs: Keyword arguments forwarded to torch.save.\n    \"\"\"\n    with fsspec.open(path, \"wb\") as f:\n        torch.save(state, f, **kwargs)\n\n\ndef save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):\n    if hasattr(model, \"module\"):\n        model_state = model.module.state_dict()\n    else:\n        model_state = model.state_dict()\n    if isinstance(optimizer, list):\n        optimizer_state = [optim.state_dict() for optim in optimizer]\n    else:\n        optimizer_state = optimizer.state_dict() if optimizer is not None else None\n\n    if isinstance(scaler, list):\n        scaler_state = [s.state_dict() for s in scaler]\n    else:\n        scaler_state = scaler.state_dict() if scaler is not None else None\n\n    if isinstance(config, Coqpit):\n        config = config.to_dict()\n\n    state = {\n        \"config\": config,\n        \"model\": model_state,\n        \"optimizer\": optimizer_state,\n        \"scaler\": scaler_state,\n        \"step\": current_step,\n        \"epoch\": epoch,\n        \"date\": datetime.date.today().strftime(\"%B %d, %Y\"),\n    }\n    state.update(kwargs)\n    save_fsspec(state, output_path)\n\n\ndef save_checkpoint(\n    config,\n    model,\n    optimizer,\n    scaler,\n    current_step,\n    epoch,\n    output_folder,\n    **kwargs,\n):\n    file_name = \"checkpoint_{}.pth.tar\".format(current_step)\n    checkpoint_path = os.path.join(output_folder, file_name)\n    print(\"\\n > CHECKPOINT : {}\".format(checkpoint_path))\n    save_model(\n        config,\n        model,\n        optimizer,\n        scaler,\n        current_step,\n        epoch,\n        checkpoint_path,\n        **kwargs,\n    )\n\n\ndef save_best_model(\n    current_loss,\n    best_loss,\n    config,\n    model,\n    optimizer,\n    scaler,\n    current_step,\n    epoch,\n    out_path,\n   
 keep_all_best=False,\n    keep_after=10000,\n    **kwargs,\n):\n    if current_loss < best_loss:\n        best_model_name = f\"best_model_{current_step}.pth.tar\"\n        checkpoint_path = os.path.join(out_path, best_model_name)\n        print(\" > BEST MODEL : {}\".format(checkpoint_path))\n        save_model(\n            config,\n            model,\n            optimizer,\n            scaler,\n            current_step,\n            epoch,\n            checkpoint_path,\n            model_loss=current_loss,\n            **kwargs,\n        )\n        fs = fsspec.get_mapper(out_path).fs\n        # only delete previous if current is saved successfully\n        if not keep_all_best or (current_step < keep_after):\n            model_names = fs.glob(os.path.join(out_path, \"best_model*.pth.tar\"))\n            for model_name in model_names:\n                if os.path.basename(model_name) != best_model_name:\n                    fs.rm(model_name)\n        # create a shortcut which always points to the currently best model\n        shortcut_name = \"best_model.pth.tar\"\n        shortcut_path = os.path.join(out_path, shortcut_name)\n        fs.copy(checkpoint_path, shortcut_path)\n        best_loss = current_loss\n    return best_loss\n"
  },
  {
    "path": "speaker/utils/shared_configs.py",
    "content": "from dataclasses import asdict, dataclass\nfrom typing import List\n\nfrom .coqpit import Coqpit, check_argument\n\n\n@dataclass\nclass BaseAudioConfig(Coqpit):\n    \"\"\"Base config to definge audio processing parameters. It is used to initialize\n    ```TTS.utils.audio.AudioProcessor.```\n\n    Args:\n        fft_size (int):\n            Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024.\n\n        win_length (int):\n            Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match\n            ```fft_size```. Defaults to 1024.\n\n        hop_length (int):\n            Number of audio samples between adjacent STFT columns. Defaults to 1024.\n\n        frame_shift_ms (int):\n            Set ```hop_length``` based on milliseconds and sampling rate.\n\n        frame_length_ms (int):\n            Set ```win_length``` based on milliseconds and sampling rate.\n\n        stft_pad_mode (str):\n            Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.\n\n        sample_rate (int):\n            Audio sampling rate. Defaults to 22050.\n\n        resample (bool):\n            Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.\n\n        preemphasis (float):\n            Preemphasis coefficient. Defaults to 0.0.\n\n        ref_level_db (int): 20\n            Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air.\n            Defaults to 20.\n\n        do_sound_norm (bool):\n            Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.\n\n        log_func (str):\n            Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'.\n\n        do_trim_silence (bool):\n            Enable / Disable trimming silences at the beginning and the end of the audio clip. 
Defaults to ```True```.\n\n        do_amp_to_db_linear (bool, optional):\n            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.\n\n        do_amp_to_db_mel (bool, optional):\n            enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.\n\n        trim_db (int):\n            Silence threshold used for silence trimming. Defaults to 45.\n\n        power (float):\n            Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the\n            artifacts in the synthesized voice. Defaults to 1.5.\n\n        griffin_lim_iters (int):\n            Number of Griffin-Lim iterations. Defaults to 60.\n\n        num_mels (int):\n            Number of mel-basis filters, which defines the size of each mel-spectrogram frame. Defaults to 80.\n\n        mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.\n            It needs to be adjusted for a dataset. Defaults to 0.\n\n        mel_fmax (float):\n            Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset. Defaults to None.\n\n        spec_gain (int):\n            Gain applied when converting amplitude to dB. Defaults to 20.\n\n        signal_norm (bool):\n            enable/disable signal normalization. Defaults to True.\n\n        min_level_db (int):\n            minimum dB threshold for the computed mel spectrograms. Defaults to -100.\n\n        symmetric_norm (bool):\n            enable/disable symmetric normalization. If True, normalization is performed in the range [-k, k], else\n            [0, k]. Defaults to True.\n\n        max_norm (float):\n            ```k``` defining the normalization range. Defaults to 4.0.\n\n        clip_norm (bool):\n            enable/disable clipping the out-of-range values in the normalized audio signal. 
Defaults to True.\n\n        stats_path (str):\n            Path to the computed stats file. Defaults to None.\n    \"\"\"\n\n    # stft parameters\n    fft_size: int = 1024\n    win_length: int = 1024\n    hop_length: int = 256\n    frame_shift_ms: int = None\n    frame_length_ms: int = None\n    stft_pad_mode: str = \"reflect\"\n    # audio processing parameters\n    sample_rate: int = 22050\n    resample: bool = False\n    preemphasis: float = 0.0\n    ref_level_db: int = 20\n    do_sound_norm: bool = False\n    log_func: str = \"np.log10\"\n    # silence trimming\n    do_trim_silence: bool = True\n    trim_db: int = 45\n    # griffin-lim params\n    power: float = 1.5\n    griffin_lim_iters: int = 60\n    # mel-spec params\n    num_mels: int = 80\n    mel_fmin: float = 0.0\n    mel_fmax: float = None\n    spec_gain: int = 20\n    do_amp_to_db_linear: bool = True\n    do_amp_to_db_mel: bool = True\n    # normalization params\n    signal_norm: bool = True\n    min_level_db: int = -100\n    symmetric_norm: bool = True\n    max_norm: float = 4.0\n    clip_norm: bool = True\n    stats_path: str = None\n\n    def check_values(\n        self,\n    ):\n        \"\"\"Check config fields\"\"\"\n        c = asdict(self)\n        check_argument(\"num_mels\", c, restricted=True, min_val=10, max_val=2056)\n        check_argument(\"fft_size\", c, restricted=True, min_val=128, max_val=4058)\n        check_argument(\"sample_rate\", c, restricted=True, min_val=512, max_val=100000)\n        check_argument(\n            \"frame_length_ms\",\n            c,\n            restricted=True,\n            min_val=10,\n            max_val=1000,\n            alternative=\"win_length\",\n        )\n        check_argument(\"frame_shift_ms\", c, restricted=True, min_val=1, max_val=1000, alternative=\"hop_length\")\n        check_argument(\"preemphasis\", c, restricted=True, min_val=0, max_val=1)\n        check_argument(\"min_level_db\", c, restricted=True, min_val=-1000, max_val=10)\n        
check_argument(\"ref_level_db\", c, restricted=True, min_val=0, max_val=1000)\n        check_argument(\"power\", c, restricted=True, min_val=1, max_val=5)\n        check_argument(\"griffin_lim_iters\", c, restricted=True, min_val=10, max_val=1000)\n\n        # normalization parameters\n        check_argument(\"signal_norm\", c, restricted=True)\n        check_argument(\"symmetric_norm\", c, restricted=True)\n        check_argument(\"max_norm\", c, restricted=True, min_val=0.1, max_val=1000)\n        check_argument(\"clip_norm\", c, restricted=True)\n        check_argument(\"mel_fmin\", c, restricted=True, min_val=0.0, max_val=1000)\n        check_argument(\"mel_fmax\", c, restricted=True, min_val=500.0, allow_none=True)\n        check_argument(\"spec_gain\", c, restricted=True, min_val=1, max_val=100)\n        check_argument(\"do_trim_silence\", c, restricted=True)\n        check_argument(\"trim_db\", c, restricted=True)\n\n\n@dataclass\nclass BaseDatasetConfig(Coqpit):\n    \"\"\"Base config for TTS datasets.\n\n    Args:\n        name (str):\n            Dataset name that defines the preprocessor in use. Defaults to None.\n\n        path (str):\n            Root path to the dataset files. Defaults to None.\n\n        meta_file_train (str):\n            Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.\n            Defaults to None.\n\n        unused_speakers (List):\n            List of speakers IDs that are not used at the training. 
Defaults to None.\n\n        meta_file_val (str):\n            Name of the dataset meta file that defines the instances used at validation.\n\n        meta_file_attn_mask (str):\n            Path to the file that lists the attention mask files used with models that require attention masks to\n            train the duration predictor.\n    \"\"\"\n\n    name: str = \"\"\n    path: str = \"\"\n    meta_file_train: str = \"\"\n    ununsed_speakers: List[str] = None\n    meta_file_val: str = \"\"\n    meta_file_attn_mask: str = \"\"\n\n    def check_values(\n        self,\n    ):\n        \"\"\"Check config fields\"\"\"\n        c = asdict(self)\n        check_argument(\"name\", c, restricted=True)\n        check_argument(\"path\", c, restricted=True)\n        check_argument(\"meta_file_train\", c, restricted=True)\n        check_argument(\"meta_file_val\", c, restricted=False)\n        check_argument(\"meta_file_attn_mask\", c, restricted=False)\n\n\n@dataclass\nclass BaseTrainingConfig(Coqpit):\n    \"\"\"Base config to define the basic training parameters that are shared\n    among all the models.\n\n    Args:\n        model (str):\n            Name of the model that is used in the training.\n\n        run_name (str):\n            Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`.\n\n        run_description (str):\n            Short description of the experiment.\n\n        epochs (int):\n            Number of training epochs. Defaults to 10000.\n\n        batch_size (int):\n            Training batch size.\n\n        eval_batch_size (int):\n            Validation batch size.\n\n        mixed_precision (bool):\n            Enable / Disable mixed precision training. 
It reduces the VRAM use and allows larger batch sizes; however,\n            it may also cause numerical instability in some cases.\n\n        scheduler_after_epoch (bool):\n            If true, run the scheduler step after each epoch; otherwise, run it after each model step.\n\n        run_eval (bool):\n            Enable / Disable evaluation (validation) run. Defaults to True.\n\n        test_delay_epochs (int):\n            Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful\n            results, hence waiting for a couple of epochs might save some time.\n\n        print_eval (bool):\n            Enable / Disable console logging for evaluation steps. If disabled, it only shows the final values at\n            the end of the evaluation. Defaults to ```False```.\n\n        print_step (int):\n            Number of steps required to print the next training log.\n\n        dashboard_logger (str):\n            Experiment tracking tool, \"tensorboard\" or \"wandb\". Defaults to \"tensorboard\".\n\n        plot_step (int):\n            Number of steps required to log training on Tensorboard.\n\n        model_param_stats (bool):\n            Enable / Disable logging internal model stats for model diagnostics. It might be useful for model debugging.\n            Defaults to ```False```.\n\n        project_name (str):\n            Name of the project. Defaults to config.model.\n\n        wandb_entity (str):\n            Name of W&B entity/team. Enables collaboration across a team or org.\n\n        log_model_step (int):\n            Number of steps required to log a checkpoint as a W&B artifact.\n\n        save_step (int):\n            Number of steps required to save the next checkpoint.\n\n        checkpoint (bool):\n            Enable / Disable checkpointing.\n\n        keep_all_best (bool):\n            Enable / Disable keeping all the saved best models instead of overwriting the previous one. 
Defaults\n            to ```False```.\n\n        keep_after (int):\n            Number of steps to wait before saving all the best models. Used only if ```keep_all_best == True```. Defaults\n            to 10000.\n\n        num_loader_workers (int):\n            Number of workers for the training-time dataloader.\n\n        num_eval_loader_workers (int):\n            Number of workers for the evaluation-time dataloader.\n\n        output_path (str):\n            Path to the training output folder, either a local file path or other\n            URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or\n            S3 (s3://) paths. Any nonexistent part of the given path is created\n            automatically. All training artefacts are saved there.\n    \"\"\"\n\n    model: str = None\n    run_name: str = \"coqui_tts\"\n    run_description: str = \"\"\n    # training params\n    epochs: int = 10000\n    batch_size: int = None\n    eval_batch_size: int = None\n    mixed_precision: bool = False\n    scheduler_after_epoch: bool = False\n    # eval params\n    run_eval: bool = True\n    test_delay_epochs: int = 0\n    print_eval: bool = False\n    # logging\n    dashboard_logger: str = \"tensorboard\"\n    print_step: int = 25\n    plot_step: int = 100\n    model_param_stats: bool = False\n    project_name: str = None\n    log_model_step: int = None\n    wandb_entity: str = None\n    # checkpointing\n    save_step: int = 10000\n    checkpoint: bool = True\n    keep_all_best: bool = False\n    keep_after: int = 10000\n    # dataloading\n    num_loader_workers: int = 0\n    num_eval_loader_workers: int = 0\n    use_noise_augment: bool = False\n    # paths\n    output_path: str = None\n    # distributed\n    distributed_backend: str = \"nccl\"\n    distributed_url: str = \"tcp://localhost:54321\"\n"
  },
  {
    "path": "speaker_pretrain/README.md",
    "content": "Path for:\n\n    best_model.pth.tar\n\n    config.json"
  },
  {
    "path": "speaker_pretrain/config.json",
    "content": "{\n    \"model_name\": \"lstm\",\n    \"run_name\": \"mueller91\",\n    \"run_description\": \"train speaker encoder with voxceleb1, voxceleb2 and libriSpeech \",\n    \"audio\":{\n        // Audio processing parameters\n        \"num_mels\": 80,         // size of the mel spec frame.\n        \"fft_size\": 1024,       // number of stft frequency levels. Size of the linear spectrogram frame.\n        \"sample_rate\": 16000,   // DATASET-RELATED: wav sample-rate. If different from the original data, it is resampled.\n        \"win_length\": 1024,     // stft window length in samples.\n        \"hop_length\": 256,      // stft window hop-length in samples.\n        \"frame_length_ms\": null,  // stft window length in ms. If null, 'win_length' is used.\n        \"frame_shift_ms\": null,   // stft window hop-length in ms. If null, 'hop_length' is used.\n        \"preemphasis\": 0.98,    // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.\n        \"min_level_db\": -100,   // normalization range\n        \"ref_level_db\": 20,     // reference level db, theoretically 20db is the sound of air.\n        \"power\": 1.5,           // value to sharpen wav signals after GL algorithm.\n        \"griffin_lim_iters\": 60,// #griffin-lim iterations. 30-60 is a good range. The larger the value, the slower the generation.\n        // Normalization parameters\n        \"signal_norm\": true,    // normalize the spec values in range [0, 1]\n        \"symmetric_norm\": true, // move normalization to range [-1, 1]\n        \"max_norm\": 4.0,          // scale normalization to range [-max_norm, max_norm] or [0, max_norm]\n        \"clip_norm\": true,      // clip normalized values into the range.\n        \"mel_fmin\": 0.0,         // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!\n        \"mel_fmax\": 8000.0,        // maximum freq level for mel-spec. Tune for dataset!!\n        \"do_trim_silence\": true,  // enable trimming of silence from audio as you load it. LJspeech (false), TWEB (false), Nancy (true)\n        \"trim_db\": 60          // threshold for trimming silence. Set this according to your dataset.\n    },\n    \"reinit_layers\": [],\n    \"loss\": \"angleproto\", // \"ge2e\" to use Generalized End-to-End loss and \"angleproto\" to use Angular Prototypical loss (new SOTA)\n    \"grad_clip\": 3.0, // upper limit for gradients for clipping.\n    \"epochs\": 1000, // total number of epochs to train.\n    \"lr\": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.\n    \"lr_decay\": false, // if true, Noam learning rate decaying is applied throughout training.\n    \"warmup_steps\": 4000, // Noam decay steps to increase the learning rate from 0 to \"lr\"\n    \"tb_model_param_stats\": false, // true, plots param stats per layer on tensorboard. Might be memory-consuming, but good for debugging. \n    \"steps_plot_stats\": 10, // number of steps to plot embeddings.\n    \"num_speakers_in_batch\": 64, // Batch size for training. Values lower than 32 might make attention hard to learn. It is overwritten by 'gradual_training'.\n    \"voice_len\": 2.0, // size of the voice\n    \"num_utters_per_speaker\": 10,  //\n    \"num_loader_workers\": 8,        // number of training data loader processes. Don't set it too big. 4-8 are good values.\n    \"wd\": 0.000001, // Weight decay coefficient.\n    \"checkpoint\": true, // If true, it saves checkpoints per \"save_step\"\n    \"save_step\": 1000, // Number of training steps between saving training stats and checkpoints.\n    \"print_step\": 20, // Number of steps between training logs on the console.\n    \"output_path\": \"../../OutputsMozilla/checkpoints/speaker_encoder/\", // DATASET-RELATED: output path for all training outputs.\n    \"model\": {\n        \"input_dim\": 80,\n        \"proj_dim\": 256,\n        \"lstm_dim\": 768,\n        \"num_lstm_layers\": 3,\n        \"use_lstm_with_projection\": true\n    },\n    \"storage\": {\n        \"sample_from_storage_p\": 0.9,  // the probability with which we'll sample from the DataSet in-memory storage\n        \"storage_size\": 25,   // the size of the in-memory storage with respect to a single batch\n        \"additive_noise\": 1e-5   // add very small Gaussian noise to the data in order to increase robustness\n    },\n    \"datasets\": \n        [\n            {\n                \"name\": \"vctk_slim\",\n                \"path\": \"../../../audio-datasets/en/VCTK-Corpus/\",\n                \"meta_file_train\": null,\n                \"meta_file_val\": null\n            },\n            {\n                \"name\": \"libri_tts\",\n                \"path\": \"../../../audio-datasets/en/LibriTTS/train-clean-100\",\n                \"meta_file_train\": null,\n                \"meta_file_val\": null\n            },\n            {\n                \"name\": \"libri_tts\",\n                \"path\": \"../../../audio-datasets/en/LibriTTS/train-clean-360\",\n                \"meta_file_train\": null,\n                \"meta_file_val\": null\n            },\n            {\n                \"name\": \"libri_tts\",\n                \"path\": \"../../../audio-datasets/en/LibriTTS/train-other-500\",\n                \"meta_file_train\": null,\n                \"meta_file_val\": null\n            },\n            {\n                \"name\": \"voxceleb1\",\n                \"path\": \"../../../audio-datasets/en/voxceleb1/\",\n                \"meta_file_train\": null,\n                \"meta_file_val\": null\n            },\n            {\n                \"name\": \"voxceleb2\",\n                \"path\": \"../../../audio-datasets/en/voxceleb2/\",\n                \"meta_file_train\": null,\n                \"meta_file_val\": null\n            },\n            {\n                \"name\": \"common_voice\",\n                \"path\": \"../../../audio-datasets/en/MozillaCommonVoice\",\n                \"meta_file_train\": \"train.tsv\",\n                \"meta_file_val\": \"test.tsv\"\n            }\n        ]\n}"
  },
  {
    "path": "spec/inference.py",
    "content": "import argparse\nimport torch\nimport torch.utils.data\nimport numpy as np\nimport librosa\nfrom omegaconf import OmegaConf\nfrom librosa.filters import mel as librosa_mel_fn\n\n\nMAX_WAV_VALUE = 32768.0\n\n\ndef load_wav_to_torch(full_path, sample_rate):\n    wav, _ = librosa.load(full_path, sr=sample_rate)\n    wav = wav / np.abs(wav).max() * 0.6\n    return torch.FloatTensor(wav)\n\n\ndef dynamic_range_compression(x, C=1, clip_val=1e-5):\n    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)\n\n\ndef dynamic_range_decompression(x, C=1):\n    return np.exp(x) / C\n\n\ndef dynamic_range_compression_torch(x, C=1, clip_val=1e-5):\n    return torch.log(torch.clamp(x, min=clip_val) * C)\n\n\ndef dynamic_range_decompression_torch(x, C=1):\n    return torch.exp(x) / C\n\n\ndef spectral_normalize_torch(magnitudes):\n    output = dynamic_range_compression_torch(magnitudes)\n    return output\n\n\ndef spectral_de_normalize_torch(magnitudes):\n    output = dynamic_range_decompression_torch(magnitudes)\n    return output\n\n\nmel_basis = {}\nhann_window = {}\n\n\ndef mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):\n    if torch.min(y) < -1.:\n        print('min value is ', torch.min(y))\n    if torch.max(y) > 1.:\n        print('max value is ', torch.max(y))\n\n    global mel_basis, hann_window\n    # cache the mel filterbank and Hann window per (fmax, device) so they are built only once\n    mel_key = str(fmax) + '_' + str(y.device)\n    if mel_key not in mel_basis:\n        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)\n        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)\n        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)\n\n    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')\n    y = y.squeeze(1)\n\n    # complex tensor as default, then use view_as_real for future pytorch compatibility\n    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],\n                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)\n    spec = torch.view_as_real(spec)\n    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))\n\n    spec = torch.matmul(mel_basis[mel_key], spec)\n    spec = spectral_normalize_torch(spec)\n\n    return spec\n\n\ndef mel_spectrogram_file(path, hps):\n    audio = load_wav_to_torch(path, hps.data.sampling_rate)\n    audio = audio.unsqueeze(0)\n\n    # trim the audio so its length is an exact multiple of hop_length\n    if (audio.size(1) % hps.data.hop_length) != 0:\n        audio = audio[:, :-(audio.size(1) % hps.data.hop_length)]\n    mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.mel_channels, hps.data.sampling_rate,\n                          hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin, hps.data.mel_fmax, center=False)\n    return mel\n\n\ndef print_mel(mel, path=\"mel.png\"):\n    import matplotlib.pyplot as plt\n    fig = plt.figure(figsize=(12, 4))\n    if isinstance(mel, torch.Tensor):\n        mel = mel.cpu().numpy()\n    plt.pcolor(mel)\n    plt.savefig(path, format=\"png\")\n    plt.close(fig)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-w\", \"--wav\", help=\"input wav file\", dest=\"wav\")\n    parser.add_argument(\"-m\", \"--mel\", help=\"output path for the mel tensor (saved with torch.save)\", dest=\"mel\")\n    args = parser.parse_args()\n    print(args.wav)\n    print(args.mel)\n\n    hps = OmegaConf.load(f\"./configs/base.yaml\")\n\n    mel = mel_spectrogram_file(args.wav, hps)\n    # TODO\n    mel = torch.squeeze(mel, 0)\n    # [100, length]\n    torch.save(mel, args.mel)\n    print_mel(mel, \"debug.mel.png\")\n"
  }
]