Full Code of Plachtaa/VALL-E-X for AI

Repository: Plachtaa/VALL-E-X
Branch: master
Commit: 3faaf8ccadb1
Files: 126
Total size: 18.9 MB

Directory structure:
gitextract_bmerrgqi/

├── LICENSE
├── README-ZH.md
├── README.md
├── customs/
│   └── ph.txt
├── data/
│   ├── __init__.py
│   ├── collation.py
│   ├── datamodule.py
│   ├── dataset.py
│   ├── fbank.py
│   ├── input_strategies.py
│   └── tokenizer.py
├── descriptions.py
├── examples.py
├── launch-ui.py
├── macros.py
├── model-card.md
├── models/
│   ├── __init__.py
│   ├── macros.py
│   ├── transformer.py
│   ├── vallex.py
│   └── visualizer.py
├── modules/
│   ├── __init__.py
│   ├── activation.py
│   ├── embedding.py
│   ├── optim.py
│   ├── scaling.py
│   ├── scheduler.py
│   └── transformer.py
├── nltk_data/
│   └── tokenizers/
│       └── punkt/
│           ├── PY3/
│           │   ├── README
│           │   ├── czech.pickle
│           │   ├── danish.pickle
│           │   ├── dutch.pickle
│           │   ├── english.pickle
│           │   ├── estonian.pickle
│           │   ├── finnish.pickle
│           │   ├── french.pickle
│           │   ├── german.pickle
│           │   ├── greek.pickle
│           │   ├── italian.pickle
│           │   ├── malayalam.pickle
│           │   ├── norwegian.pickle
│           │   ├── polish.pickle
│           │   ├── portuguese.pickle
│           │   ├── russian.pickle
│           │   ├── slovene.pickle
│           │   ├── spanish.pickle
│           │   ├── swedish.pickle
│           │   └── turkish.pickle
│           ├── README
│           ├── czech.pickle
│           ├── danish.pickle
│           ├── dutch.pickle
│           ├── english.pickle
│           ├── estonian.pickle
│           ├── finnish.pickle
│           ├── french.pickle
│           ├── german.pickle
│           ├── greek.pickle
│           ├── italian.pickle
│           ├── malayalam.pickle
│           ├── norwegian.pickle
│           ├── polish.pickle
│           ├── portuguese.pickle
│           ├── russian.pickle
│           ├── slovene.pickle
│           ├── spanish.pickle
│           ├── swedish.pickle
│           └── turkish.pickle
├── presets/
│   ├── acou_1.npz
│   ├── acou_2.npz
│   ├── acou_3.npz
│   ├── acou_4.npz
│   ├── alan.npz
│   ├── amused.npz
│   ├── anger.npz
│   ├── babara.npz
│   ├── bronya.npz
│   ├── cafe.npz
│   ├── dingzhen.npz
│   ├── disgust.npz
│   ├── emo_amused.npz
│   ├── emo_anger.npz
│   ├── emo_neutral.npz
│   ├── emo_sleepy.npz
│   ├── emotion_sleepiness.npz
│   ├── en2zh_tts_1.npz
│   ├── en2zh_tts_2.npz
│   ├── en2zh_tts_3.npz
│   ├── en2zh_tts_4.npz
│   ├── esta.npz
│   ├── fuxuan.npz
│   ├── librispeech_1.npz
│   ├── librispeech_2.npz
│   ├── librispeech_3.npz
│   ├── librispeech_4.npz
│   ├── neutral.npz
│   ├── paimon.npz
│   ├── rosalia.npz
│   ├── seel.npz
│   ├── sleepiness.npz
│   ├── vctk_1.npz
│   ├── vctk_2.npz
│   ├── vctk_3.npz
│   ├── vctk_4.npz
│   ├── yaesakura.npz
│   ├── zh2en_tts_1.npz
│   ├── zh2en_tts_2.npz
│   ├── zh2en_tts_3.npz
│   └── zh2en_tts_4.npz
├── prompts/
│   ├── ja-2.ogg
│   └── ph.txt
├── requirements.txt
└── utils/
    ├── __init__.py
    ├── download.py
    ├── g2p/
    │   ├── __init__.py
    │   ├── bpe_1024.json
    │   ├── bpe_69.json
    │   ├── cleaners.py
    │   ├── english.py
    │   ├── japanese.py
    │   ├── mandarin.py
    │   └── symbols.py
    ├── generation.py
    ├── prompt_making.py
    ├── sentence_cutter.py
    └── symbol_table.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 Songting

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README-ZH.md
================================================
# VALL-E X: Multilingual Text-to-Speech Synthesis and Voice Cloning 🔊
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/qCBRmAnTxg)
<br>
[English](README.md) | 中文
<br>
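An open source implementation of Microsoft's [VALL-E X](https://arxiv.org/pdf/2303.03926) zero-shot speech synthesis model.<br>
**The pretrained model is now available to the public for research or application use.**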
![vallex-framework](/images/vallex_framework.jpg "VALL-E X framework")

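VALL-E X is a powerful and innovative multilingual text-to-speech (TTS) model originally published by Microsoft. Although Microsoft introduced the concept in their research paper, they did not release any code or pretrained models. Recognizing the potential and value of this technology, we reproduced and trained an openly available VALL-E X model. We are happy to share our pretrained model with the community so that everyone can experience the power of next-generation TTS. 🎧
<br>
For more details, see the [model card](./model-card.md).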

## 📖 Table of Contents
* [🚀 Updates](#-更新日志)
* [📢 Features](#-功能特点)
* [💻 Local Installation](#-本地安装)
* [🎧 Online Demo](#-在线Demo)
* [🐍 Usage](#-Python中的使用方法)
* [❓ FAQ](#-faq)
* [🧠 TODO](#-todo)

## 🚀 Updates
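**2023.09.10**
- Added batch decoding for the AR decoder for more stable generation results.

**2023.08.30**
- Replaced the EnCodec decoder with the Vocos decoder, improving audio quality. (Thanks to [@v0xie](https://github.com/v0xie))

**2023.08.23**
- Added long text generation.

**2023.08.20**
- Added the Chinese README.

**2023.08.14**
- The pretrained model weights have been released. Download them [here](https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing).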

## 💻 Local Installation
### Install with pip; Python 3.10, CUDA 11.7 ~ 12.0, and PyTorch 2.0+ are required
```commandline
git clone https://github.com/Plachtaa/VALL-E-X.git
cd VALL-E-X
pip install -r requirements.txt
```

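> Note: To make prompts, you need to install ffmpeg and add its folder to the PATH environment variable.

When you run the program for the first time, it automatically downloads the required models. If the download fails with an error, follow the steps below to download the models manually.

(Pay attention to the capitalization of directory and folder names)

1. Check whether a `checkpoints` folder exists in the installation directory. If not, create it manually (`./checkpoints/`).

2. Check whether the `checkpoints` folder contains `vallex-checkpoint.pt`. If not, download `vallex-checkpoint.pt` manually from [here](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt)
and put it in the `checkpoints` folder.

3. Check whether a `whisper` folder exists in the installation directory. If not, create it manually (`./whisper/`).

4. Check whether the `whisper` folder contains `medium.pt`. If not, download `medium.pt` manually from [here](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt)
and put it in the `whisper` folder.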

##  🎧 Online Demo
If you don't want to install locally, you can try VALL-E X online. Click either link below to get started.
<br>
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/Plachta/VALL-E-X)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1yyD_sz531QntLKowMHo-XxorsFBCfKul?usp=sharing)


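## 📢 Features

VALL-E X comes with a range of cutting-edge features:

1. **Multilingual TTS**: Natural, expressive speech synthesis in three languages: English, Chinese, and Japanese.

2. **Zero-shot Voice Cloning**: From just a short 3~10 second recording of any speaker, VALL-E X can generate personalized, high-quality speech that faithfully reproduces their voice.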

<details>
  <summary><h5>See example</h5></summary>

[prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/a7baa51d-a53a-41cc-a03d-6970f25fcca7)


[output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/b895601a-d126-4138-beff-061aabdc7985)

</details>

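3. **Speech Emotion Control**: VALL-E X can synthesize speech with the same emotion as the given speaker recording, adding more expressiveness to the audio.

<details>
  <summary><h5>See example</h5></summary>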

https://github.com/Plachtaa/VALL-E-X/assets/112609742/56fa9988-925e-4757-82c5-83ecb0df6266


https://github.com/Plachtaa/VALL-E-X/assets/112609742/699c47a3-d502-4801-8364-bd89bcc0b8f1

</details>

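4. **Zero-shot Cross-Lingual Speech Synthesis**: VALL-E X can synthesize speech in a language different from the given speaker's native language, preserving the speaker's timbre and emotion without compromising accent or fluency. Below is an example of English and Chinese synthesis using a native Japanese speaker: 🇯🇵 🗣

<details>
  <summary><h5>See example</h5></summary>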

[jp-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ea6e2ee4-139a-41b4-837e-0bd04dda6e19)


[en-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/db8f9782-923f-425e-ba94-e8c1bd48f207)


[zh-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/15829d79-e448-44d3-8965-fafa7a3f8c28)

</details>

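5. **Accent Control**: VALL-E X lets you control the accent of the synthesized audio, such as speaking Chinese with an English accent or vice versa. 🇨🇳 💬

<details>
  <summary><h5>See example</h5></summary>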

[en-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/f688d7f6-70ef-46ec-b1cc-355c31e78b3b)


[zh-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/be59c7ca-b45b-44ca-a30d-4d800c950ccc)


[en-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/8b4f4f9b-f299-4ea4-a548-137437b71738)

</details>

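6. **Acoustic Environment Maintenance**: When the given speaker recording was made in a particular acoustic environment, VALL-E X can preserve that environment, making the synthesized speech sound more natural.

<details>
  <summary><h5>See example</h5></summary>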

[noise-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/68986d88-abd0-4d1d-96e4-4f893eb9259e)


[noise-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/96c4c612-4516-4683-8804-501b70938608)

</details>


Visit our [demo page](https://plachtaa.github.io/) to browse more examples!

## 💻 Usage in Python

<details open>
  <summary><h3>🪑 Basic Usage</h3></summary>

```python
from utils.generation import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav
from IPython.display import Audio

# download and load all models
preload_models()

# generate audio from text
text_prompt = """
Hello, my name is Nose. And uh, and I like hamburger. Hahaha... But I also have other interests such as playing tactic toast.
"""
audio_array = generate_audio(text_prompt)

# save audio to disk
write_wav("vallex_generation.wav", SAMPLE_RATE, audio_array)

# play audio in notebook
Audio(audio_array, rate=SAMPLE_RATE)
```

[hamburger.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/578d7bbe-cda9-483e-898c-29646edc8f2e)

</details>

<details open>
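  <summary><h3>🌎 Multilingual</h3></summary>
<br>
This VALL-E X implementation supports three languages: English, Chinese, and Japanese. You can specify the language by setting the `language` parameter. By default, the model detects the language automatically.
<br>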

```python

text_prompt = """
    チュソクは私のお気に入りの祭りです。 私は数日間休んで、友人や家族との時間を過ごすことができます。
"""
audio_array = generate_audio(text_prompt)
```

[vallex_japanese.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ee57a688-3e83-4be5-b0fe-019d16eec51c)

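*Note: VALL-E X controls accent perfectly even when multiple languages are mixed in one sentence, but you need to manually mark the language of each sentence so that our G2P tool can recognize them.*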
```python
text_prompt = """
    [EN]The Thirty Years' War was a devastating conflict that had a profound impact on Europe.[EN]
    [ZH]这是历史的开始。 如果您想听更多,请继续。[ZH]
"""
audio_array = generate_audio(text_prompt, language='mix')
```

[vallex_codeswitch.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d8667abf-bd08-499f-a383-a861d852f98a)

</details>

<details open>
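<summary><h3>📼 Voice Presets</h3></summary>
  
We provide more than ten speaker voices that can be used directly with VALL-E X! Browse all available voices [here](/presets).

> VALL-E X tries to match the tone, pitch, emotion and prosody of a given preset. The model also attempts to preserve music, ambient noise, etc.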
```python
text_prompt = """
I am an innocent boy with a smoky voice. It is a great honor for me to speak at the United Nations today.
"""
audio_array = generate_audio(text_prompt, prompt="dingzhen")
```

[smoky.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d3f55732-b1cd-420f-87d6-eab60db14dc5)

</details>

<details open>
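<summary><h3>🎙Voice Cloning</h3></summary>
  
VALL-E X supports voice cloning! You can make an audio prompt from the voice of any person, any character, or even yourself. When you use that audio prompt, VALL-E X synthesizes the text in a similar voice.
<br>
To make an audio prompt, you need to provide a speech clip 3~10 seconds long along with its transcript. You can also leave the transcript empty and let the [Whisper](https://github.com/openai/whisper) model generate it for you.
> VALL-E X tries to match the tone, pitch, emotion and prosody of a given audio prompt. The model also attempts to preserve music, ambient noise, etc.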

```python
from utils.prompt_making import make_prompt

### Use given transcript
make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav",
                transcript="Just, what was that? Paimon thought we were gonna get eaten.")

### Alternatively, use whisper
make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav")
```
Now try out the audio prompt you just made!
```python
from utils.generation import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav

# download and load all models
preload_models()

text_prompt = """
Hey, Traveler, Listen to this, This machine has taken my voice, and now it can talk just like me!
"""
audio_array = generate_audio(text_prompt, prompt="paimon")

write_wav("paimon_cloned.wav", SAMPLE_RATE, audio_array)

```

[paimon_prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/e7922859-9d12-4e2a-8651-e156e4280311)


[paimon_cloned.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/60d3b7e9-5ead-4024-b499-a897ce5f3d5e)


</details>


<details open>
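<summary><h3>🎢User Interface</h3></summary>

If you are not comfortable with code, we have also created a user-friendly graphical interface for VALL-E X. It lets you interact with the model easily, making voice cloning and multilingual speech synthesis a breeze.
<br>
Launch the user interface with the following command: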
```commandline
python -X utf8 launch-ui.py
```
</details>

## 🛠️ Hardware Requirements and Inference Speed

VALL-E X can run on CPU or GPU (`pytorch 2.0+`, CUDA 11.7 ~ CUDA 12.0).

To run on a GPU, you need at least 6GB of VRAM.

## ⚙️ Details

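VALL-E X is similar to [Bark](https://github.com/suno-ai/bark), [VALL-E](https://arxiv.org/abs/2301.02111) and [AudioLM](https://arxiv.org/abs/2209.03143): it uses a GPT-style model to predict quantized audio tokens autoregressively, which are then decoded by [EnCodec](https://github.com/facebookresearch/encodec).
<br>
Compared to [Bark](https://github.com/suno-ai/bark):
- ✔ **Lightweight**: 3️⃣ ✖ smaller,
- ✔ **Fast**: 4️⃣ ✖ faster, 
- ✔ **Higher quality for Chinese & Japanese**
- ✔ **No foreign accent in cross-lingual synthesis**
- ✔ **Open and easy-to-use voice cloning**
- ❌ **Fewer supported languages**
- ❌ **No tokens for synthesizing music or special sound effects**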

### Supported Languages

| Language | Status |
|---------| :---: |
| English (en) | ✅ |
| Japanese (ja) | ✅ |
| Chinese (zh) | ✅ |

## ❓ FAQ

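#### Where can I download the checkpoint?
* When you run the program for the first time, we use `wget` to download the model into the `./checkpoints/` directory.
* If the download fails on the first run, download the model manually from [here](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt) and put the file in `./checkpoints/`.

#### How much VRAM do I need?
* 6GB of GPU VRAM. Almost all NVIDIA GPUs meet this requirement.

#### Why can't the model generate long text?
* The computational complexity of the Transformer grows quadratically with sequence length. For this reason, all training audio was kept under 22 seconds. Make sure the total length of the audio prompt and the generated audio is under 22 seconds to ensure acceptable performance.

#### More...

## 🧠 TODO
- [x] Add Chinese README
- [x] Long text generation
- [x] Replace Encodec decoder with Vocos decoder
- [ ] Fine-tuning for better voice adaptation
- [ ] `.bat` scripts for non-Python users
- [ ] More...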

## 🙏 Appreciation
- [VALL-E X paper](https://arxiv.org/pdf/2303.03926) for the brilliant idea
- [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) for related training code
- [bark](https://github.com/suno-ai/bark) for the amazing pioneering work in neural codec TTS models

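## ⭐️ Show Your Support

If you find VALL-E X interesting and useful, please give us a star on GitHub! ⭐️ It encourages us to keep improving the model and adding exciting features.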

## 📜 License

VALL-E X is licensed under the [MIT License](./LICENSE).

---

Have questions or need help? Feel free to [open an issue](https://github.com/Plachtaa/VALL-E-X/issues/new) or join our [Discord](https://discord.gg/qCBRmAnTxg)

Happy voice cloning! 🎤


================================================
FILE: README.md
================================================
# VALL-E X: Multilingual Text-to-Speech Synthesis and Voice Cloning 🔊
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/qCBRmAnTxg)
<br>
English | [中文](README-ZH.md)
<br>
An open source implementation of Microsoft's [VALL-E X](https://arxiv.org/pdf/2303.03926) zero-shot TTS model.<br>
**We release our trained model to the public for research or application usage.**

![vallex-framework](/images/vallex_framework.jpg "VALL-E X framework")

VALL-E X is an amazing multilingual text-to-speech (TTS) model proposed by Microsoft. Microsoft described the model in a research paper but did not release any code or pretrained models. Recognizing the potential and value of this technology, our team took on the challenge of reproducing the results and training our own model. We are glad to share our trained VALL-E X model with the community, allowing everyone to experience the power of next-generation TTS! 🎧
<br>
<br>
More details about the model are presented in [model card](./model-card.md).

## 📖 Quick Index
* [🚀 Updates](#-updates)
* [📢 Features](#-features)
* [💻 Installation](#-installation)
* [🎧 Demos](#-demos)
* [🐍 Usage](#-usage-in-python)
* [❓ FAQ](#-faq)
* [🧠 TODO](#-todo)

## 🚀 Updates
**2023.09.10**
- Added AR decoder batch decoding for more stable generation results.

**2023.08.30**
- Replaced EnCodec decoder with Vocos decoder, improved audio quality. (Thanks to [@v0xie](https://github.com/v0xie))

**2023.08.23**
- Added long text generation.

**2023.08.20**
- Added [Chinese README](README-ZH.md).

**2023.08.14**
- The pretrained VALL-E X checkpoint is now released. Download it [here](https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing).

## 💻 Installation
### Install with pip, Python 3.10, CUDA 11.7 ~ 12.0, PyTorch 2.0+
```commandline
git clone https://github.com/Plachtaa/VALL-E-X.git
cd VALL-E-X
pip install -r requirements.txt
```

> Note: To make prompts, you need to install ffmpeg and add its folder to the PATH environment variable.

When you run the program for the first time, it will automatically download the corresponding model. 

If the download fails and reports an error, please follow the steps below to manually download the model.

(Please pay attention to the capitalization of folders)

1. Check whether there is a `checkpoints` folder in the installation directory. 
If not, manually create a `checkpoints` folder (`./checkpoints/`) in the installation directory.

2. Check whether there is a `vallex-checkpoint.pt` file in the `checkpoints` folder. 
If not, please manually download the `vallex-checkpoint.pt` file from [here](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt) and put it in the `checkpoints` folder.

3. Check whether there is a `whisper` folder in the installation directory. 
If not, manually create a `whisper` folder (`./whisper/`) in the installation directory.

4. Check whether there is a `medium.pt` file in the `whisper` folder. 
If not, please manually download the `medium.pt` file from [here](https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt) and put it in the `whisper` folder.
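
The four manual steps above amount to checking two folders and two files. A minimal sketch of that check is shown below; the helper name `missing_model_files` is hypothetical and is not part of this repository, and the sketch only reports what is missing rather than downloading anything.

```python
from pathlib import Path

# Folder and file names taken from the manual-download steps above.
REQUIRED_FILES = {
    "checkpoints": "vallex-checkpoint.pt",
    "whisper": "medium.pt",
}


def missing_model_files(install_dir="."):
    """Create the expected folders and return the model files still missing.

    Hypothetical helper for illustration; not part of the repository.
    """
    root = Path(install_dir)
    missing = []
    for folder, filename in REQUIRED_FILES.items():
        target_dir = root / folder
        # Folder names are case-sensitive on most non-Windows systems.
        target_dir.mkdir(parents=True, exist_ok=True)
        target = target_dir / filename
        if not target.is_file():
            missing.append(target)
    return missing
```

Calling `missing_model_files(".")` from the installation directory returns the paths that still need a manual download; an empty list suggests both files are in place.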

##  🎧 Demos
Not ready to set up the environment on your local machine just yet? No problem! We've got you covered with our online demos. You can try out VALL-E X directly on Hugging Face or Google Colab, experiencing the model's capabilities hassle-free!
<br>
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/Plachta/VALL-E-X)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1yyD_sz531QntLKowMHo-XxorsFBCfKul?usp=sharing)


## 📢 Features

VALL-E X comes packed with cutting-edge functionalities:

1. **Multilingual TTS**: Speak in three languages - English, Chinese, and Japanese - with natural and expressive speech synthesis.

2. **Zero-shot Voice Cloning**: Enroll a short 3~10 second recording of an unseen speaker, and watch VALL-E X create personalized, high-quality speech that sounds just like them!

<details>
  <summary><h5>see example</h5></summary>

[prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/a7baa51d-a53a-41cc-a03d-6970f25fcca7)


[output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/b895601a-d126-4138-beff-061aabdc7985)

</details>

3. **Speech Emotion Control**: Experience the power of emotions! VALL-E X can synthesize speech with the same emotion as the acoustic prompt provided, adding an extra layer of expressiveness to your audio.

<details>
  <summary><h5>see example</h5></summary>

https://github.com/Plachtaa/VALL-E-X/assets/112609742/56fa9988-925e-4757-82c5-83ecb0df6266


https://github.com/Plachtaa/VALL-E-X/assets/112609742/699c47a3-d502-4801-8364-bd89bcc0b8f1

</details>

4. **Zero-shot Cross-Lingual Speech Synthesis**: Take monolingual speakers on a linguistic journey! VALL-E X can produce personalized speech in another language without compromising on fluency or accent. Below, a Japanese speaker talks in Chinese and English. 🇯🇵 🗣

<details>
  <summary><h5>see example</h5></summary>

[jp-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ea6e2ee4-139a-41b4-837e-0bd04dda6e19)


[en-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/db8f9782-923f-425e-ba94-e8c1bd48f207)


[zh-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/15829d79-e448-44d3-8965-fafa7a3f8c28)

</details>

5. **Accent Control**: Get creative with accents! VALL-E X allows you to experiment with different accents, like speaking Chinese with an English accent or vice versa. 🇨🇳 💬

<details>
  <summary><h5>see example</h5></summary>

[en-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/f688d7f6-70ef-46ec-b1cc-355c31e78b3b)


[zh-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/be59c7ca-b45b-44ca-a30d-4d800c950ccc)


[en-accent-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/8b4f4f9b-f299-4ea4-a548-137437b71738)

</details>

6. **Acoustic Environment Maintenance**: No need for perfectly clean audio prompts! VALL-E X adapts to the acoustic environment of the input, making speech generation feel natural and immersive.

<details>
  <summary><h5>see example</h5></summary>

[noise-prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/68986d88-abd0-4d1d-96e4-4f893eb9259e)


[noise-output.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/96c4c612-4516-4683-8804-501b70938608)

</details>


Explore our [demo page](https://plachtaa.github.io/) for a lot more examples!

## 🐍 Usage in Python

<details open>
  <summary><h3>🪑 Basics</h3></summary>

```python
from utils.generation import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav
from IPython.display import Audio

# download and load all models
preload_models()

# generate audio from text
text_prompt = """
Hello, my name is Nose. And uh, and I like hamburger. Hahaha... But I also have other interests such as playing tactic toast.
"""
audio_array = generate_audio(text_prompt)

# save audio to disk
write_wav("vallex_generation.wav", SAMPLE_RATE, audio_array)

# play audio in notebook
Audio(audio_array, rate=SAMPLE_RATE)
```

[hamburger.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/578d7bbe-cda9-483e-898c-29646edc8f2e)

</details>

<details open>
  <summary><h3>🌎 Foreign Language</h3></summary>
<br>
This VALL-E X implementation also supports Chinese and Japanese. All three languages have equally awesome performance!
<br>

```python

text_prompt = """
    チュソクは私のお気に入りの祭りです。 私は数日間休んで、友人や家族との時間を過ごすことができます。
"""
audio_array = generate_audio(text_prompt)
```

[vallex_japanese.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/ee57a688-3e83-4be5-b0fe-019d16eec51c)

*Note: VALL-E X controls accent perfectly even when synthesizing code-switched text. However, you need to manually mark the language of each sentence, since our g2p tool is rule-based.*
```python
text_prompt = """
    [EN]The Thirty Years' War was a devastating conflict that had a profound impact on Europe.[EN]
    [ZH]这是历史的开始。 如果您想听更多,请继续。[ZH]
"""
audio_array = generate_audio(text_prompt, language='mix')
```

[vallex_codeswitch.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d8667abf-bd08-499f-a383-a861d852f98a)
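
Writing the language markers by hand gets tedious for longer code-switched passages. The sketch below builds the tagged text from `(language, text)` pairs. The function `tag_segments` is hypothetical, and the `[JA]` marker is an assumption made by analogy with the `[EN]` and `[ZH]` markers shown above.

```python
def tag_segments(segments):
    """Wrap each (language, text) pair in matching language markers.

    Hypothetical helper; `[JA]` is assumed by analogy with `[EN]` / `[ZH]`.
    """
    supported = {"en", "zh", "ja"}
    lines = []
    for lang, text in segments:
        code = lang.lower()
        if code not in supported:
            raise ValueError(f"unsupported language: {lang!r}")
        tag = f"[{code.upper()}]"
        # The same marker opens and closes the segment, as in the example above.
        lines.append(f"{tag}{text.strip()}{tag}")
    return "\n".join(lines)
```

The returned string can then be passed as `text_prompt` to `generate_audio(text_prompt, language='mix')`, as in the example above.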

</details>

<details open>
<summary><h3>📼 Voice Presets</h3></summary>
  
VALL-E X provides dozens of speaker voices that you can use directly for inference! Browse all voices in the [presets folder](/presets).

> VALL-E X tries to match the tone, pitch, emotion and prosody of a given preset. The model also attempts to preserve music, ambient noise, etc.

```python
text_prompt = """
I am an innocent boy with a smoky voice. It is a great honor for me to speak at the United Nations today.
"""
audio_array = generate_audio(text_prompt, prompt="dingzhen")
```

[smoky.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/d3f55732-b1cd-420f-87d6-eab60db14dc5)
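
To see which preset names are available locally, the `presets` folder can be listed. This sketch assumes that the name passed to `prompt=` matches the file stem of a `.npz` file in `presets/` (as `"dingzhen"` matches `presets/dingzhen.npz`); the helper `list_presets` is hypothetical.

```python
from pathlib import Path


def list_presets(preset_dir="presets"):
    """Return the sorted file stems of all *.npz files in `preset_dir`.

    Hypothetical helper; assumes preset names equal the .npz file stems.
    """
    return sorted(p.stem for p in Path(preset_dir).glob("*.npz"))
```

For example, `list_presets()` run from the repository root would list names such as `dingzhen` or `paimon` if those files are present.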

</details>

<details open>
<summary><h3>🎙Voice Cloning</h3></summary>
  
VALL-E X supports voice cloning! You can make a voice prompt from any person, character or even your own voice, and use it like the other voice presets.<br>
To make a voice prompt, you need to provide a speech clip 3~10 seconds long, as well as its transcript. 
You can also leave the transcript blank and let the [Whisper](https://github.com/openai/whisper) model generate it.
> VALL-E X tries to match the tone, pitch, emotion and prosody of a given prompt. The model also attempts to preserve music, ambient noise, etc.

```python
from utils.prompt_making import make_prompt

### Use given transcript
make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav",
                transcript="Just, what was that? Paimon thought we were gonna get eaten.")

### Alternatively, use whisper
make_prompt(name="paimon", audio_prompt_path="paimon_prompt.wav")
```
Now let's try out the prompt we've just made!
```python
from utils.generation import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav

# download and load all models
preload_models()

text_prompt = """
Hey, Traveler, Listen to this, This machine has taken my voice, and now it can talk just like me!
"""
audio_array = generate_audio(text_prompt, prompt="paimon")

write_wav("paimon_cloned.wav", SAMPLE_RATE, audio_array)

```

[paimon_prompt.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/e7922859-9d12-4e2a-8651-e156e4280311)


[paimon_cloned.webm](https://github.com/Plachtaa/VALL-E-X/assets/112609742/60d3b7e9-5ead-4024-b499-a897ce5f3d5e)
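
Because the prompt recording should be 3~10 seconds long, it can help to check a clip before calling `make_prompt`. The sketch below uses only the standard-library `wave` module, so it handles uncompressed WAV files only; the function names are hypothetical and not part of this repository.

```python
import wave


def wav_duration_seconds(path):
    """Return the duration of an uncompressed WAV file in seconds."""
    with wave.open(str(path), "rb") as wav:
        return wav.getnframes() / wav.getframerate()


def is_suitable_prompt(path, min_s=3.0, max_s=10.0):
    """True when the clip length falls within the recommended 3~10 s range."""
    return min_s <= wav_duration_seconds(path) <= max_s
```

A call such as `is_suitable_prompt("paimon_prompt.wav")` returns `False` for clips that are too short or too long, which is a cue to trim or re-record before making the prompt.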


</details>


<details open>
<summary><h3>🎢User Interface</h3></summary>

Not comfortable with code? No problem! We've also created a user-friendly graphical interface for VALL-E X. It allows you to interact with the model effortlessly, making voice cloning and multilingual speech synthesis a breeze.
<br>
You can launch the UI with the following command:
```commandline
python -X utf8 launch-ui.py
```
</details>

## 🛠️ Hardware and Inference Speed

VALL-E X works well on both CPU and GPU (`pytorch 2.0+`, CUDA 11.7 ~ 12.0).

A GPU VRAM of 6GB is enough for running VALL-E X without offloading.

## ⚙️ Details

VALL-E X is similar to [Bark](https://github.com/suno-ai/bark), [VALL-E](https://arxiv.org/abs/2301.02111) and [AudioLM](https://arxiv.org/abs/2209.03143), which generate audio in a GPT style by predicting audio tokens quantized by [EnCodec](https://github.com/facebookresearch/encodec).
<br>
Compared to [Bark](https://github.com/suno-ai/bark):
- ✔ **Lightweight**: 3️⃣ ✖ smaller,
- ✔ **Efficient**: 4️⃣ ✖ faster, 
- ✔ **Better quality on Chinese & Japanese**
- ✔ **Cross-lingual speech without foreign accent**
- ✔ **Easy voice-cloning**
- ❌ **Fewer languages**
- ❌ **No special tokens for music / sound effects**
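
The GPT-style generation described at the top of this section means tokens are produced one at a time, each conditioned on everything generated so far. The toy loop below illustrates only that idea. It is not the repository's decoder: `toy_model` is a made-up stand-in for a trained network, and the vocabulary size and stop token are arbitrary.

```python
import random


def sample_autoregressively(next_token_probs, prompt, eos_id,
                            max_new_tokens=50, seed=0):
    """Extend `prompt` one token at a time until `eos_id` or the length cap."""
    rng = random.Random(seed)
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        # The model sees the whole sequence so far and returns a distribution.
        probs = next_token_probs(tokens)
        next_id = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        if next_id == eos_id:
            break
        tokens.append(next_id)
    # Return only the newly generated tokens, without the prompt.
    return tokens[len(prompt):]


def toy_model(tokens, vocab_size=5):
    """Illustrative stand-in: favors (last token + 1) mod vocab_size.

    Requires a non-empty `tokens` list.
    """
    probs = [0.05] * vocab_size
    probs[(tokens[-1] + 1) % vocab_size] = 0.8
    return probs
```

In a real system the returned ids would be audio codec tokens that a decoder turns back into a waveform; here they are just small integers, for example from `sample_autoregressively(toy_model, prompt=[0], eos_id=4)`.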

### Supported Languages

| Language | Status |
| --- | :---: |
| English (en) | ✅ |
| Japanese (ja) | ✅ |
| Chinese, simplified (zh) | ✅ |

## ❓ FAQ

#### Where is code for training?
* [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) has almost everything. There is no plan to release our training code because it does not differ from lifeiteng's implementation.

#### Where can I download the model checkpoint?
* We use `wget` to download the model to directory `./checkpoints/` when you run the program for the first time.
* If the download fails on the first run, please manually download from [this link](https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt), and put the file under directory `./checkpoints/`.

#### How much VRAM do I need?
* 6GB GPU VRAM - Almost all NVIDIA GPUs satisfy the requirement.

#### Why does the model fail to generate long text?
* The Transformer's computational complexity grows quadratically with sequence length. Hence, all training audio clips are kept under 22 seconds. Please make sure the total length of the audio prompt and the generated audio is less than 22 seconds to ensure acceptable performance.
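
To make the quadratic growth concrete, the sketch below compares self-attention cost at different audio lengths and checks a prompt against the 22-second limit. The rate of 75 tokens per second is an assumption (the frame rate of the 24 kHz EnCodec model) and is not stated in this README; the function names are hypothetical.

```python
FRAMES_PER_SECOND = 75  # assumed token rate; not stated in this README
MAX_SECONDS = 22        # training clip limit quoted above


def sequence_length(seconds):
    """Approximate number of audio tokens per codebook for a clip."""
    return round(seconds * FRAMES_PER_SECOND)


def relative_attention_cost(seconds, reference_seconds=MAX_SECONDS):
    """Self-attention cost relative to a reference length, assuming O(n^2)."""
    return (sequence_length(seconds) / sequence_length(reference_seconds)) ** 2


def fits_training_window(prompt_seconds, generated_seconds, limit=MAX_SECONDS):
    """True when prompt plus generated audio stays under the limit."""
    return prompt_seconds + generated_seconds < limit
```

Under these assumptions, doubling the audio length from 22 to 44 seconds quadruples the attention cost, and an 8-second prompt leaves a little under 14 seconds for generated audio.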


#### MORE TO BE ADDED...

## 🧠 TODO
- [x] Add Chinese README
- [x] Long text generation
- [x] Replace Encodec decoder with Vocos decoder
- [ ] Fine-tuning for better voice adaptation
- [ ] `.bat` scripts for non-python users
- [ ] To be added...

## 🙏 Appreciation
- [VALL-E X paper](https://arxiv.org/pdf/2303.03926) for the brilliant idea
- [lifeiteng's vall-e](https://github.com/lifeiteng/vall-e) for related training code
- [bark](https://github.com/suno-ai/bark) for the amazing pioneering work in neural codec TTS models

## ⭐️ Show Your Support

If you find VALL-E X interesting and useful, give us a star on GitHub! ⭐️ It encourages us to keep improving the model and adding exciting features.

## 📜 License

VALL-E X is licensed under the [MIT License](./LICENSE).

---

Have questions or need assistance? Feel free to [open an issue](https://github.com/Plachtaa/VALL-E-X/issues/new) or join our [Discord](https://discord.gg/qCBRmAnTxg)

Happy voice cloning! 🎤


================================================
FILE: customs/ph.txt
================================================


================================================
FILE: data/__init__.py
================================================
# from .datamodule import *
# from .tokenizer import *
from .collation import *


================================================
FILE: data/collation.py
================================================
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch

from utils import SymbolTable


class TextTokenCollater:
    """Collate list of text tokens

    Map sentences to integers. Sentences are padded to equal length.
    Beginning and end-of-sequence symbols can be added.

    Example:
        >>> token_collater = TextTokenCollater(text_tokens)
        >>> tokens_batch, tokens_lens = token_collater(text)

    Returns:
        tokens_batch: IntTensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>
            but before padding.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = [token for token in unique_tokens]

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            # Every token must already be in the vocabulary.
            assert all(s in self.token2idx for s in tokens)
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        max_len = max(seq_lens)
        for seq, seq_len in zip(seqs, seq_lens):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        tokens_seqs = [[p for p in text] for text in texts]
        max_len = len(max(tokens_seqs, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in tokens_seqs
        ]

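        # NOTE: no token2idx lookup happens here, so every element must already
        # be an integer token id; string symbols (bos/eos/pad) would fail the
        # int64 conversion below.
        tokens_batch = torch.from_numpy(np.array(seqs, dtype=np.int64))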

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in tokens_seqs
            ]
        )

        return tokens_batch, tokens_lens


def get_text_token_collater() -> TextTokenCollater:
    collater = TextTokenCollater(
        ['0'], add_bos=False, add_eos=False
    )
    return collater
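

# Illustrative sketch (not part of the original module): a dependency-free view
# of the right-padding that TextTokenCollater applies to a batch of sequences.
# The helper name `pad_token_ids` and the default `pad_id=0` are assumptions
# made for this example only; the collater itself returns torch tensors.

```python
from typing import List, Sequence, Tuple


def pad_token_ids(
    seqs: Sequence[Sequence[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad integer id sequences to the length of the longest one."""
    # Lengths are recorded before padding, like tokens_lens in the collater.
    lens = [len(seq) for seq in seqs]
    max_len = max(lens, default=0)
    batch = [list(seq) + [pad_id] * (max_len - len(seq)) for seq in seqs]
    return batch, lens


batch, lens = pad_token_ids([[5, 6, 7], [8]])
print(batch)  # [[5, 6, 7], [8, 0, 0]]
print(lens)   # [3, 1]
```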


================================================
FILE: data/datamodule.py
================================================
# Copyright      2023                          (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
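# NOTE: the icefall/lhotse imports below are commented out, yet the code
# further down still references their names (str2bool, CutSet,
# DynamicBucketingSampler, ...), so this module cannot be imported as-is.
# from icefall.utils import str2bool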
# from lhotse import CutSet, load_manifest_lazy
# from lhotse.dataset import (
#     CutConcatenate,
#     DynamicBucketingSampler,
#     PrecomputedFeatures,
#     SingleCutSampler,
#     SpecAugment,
# )
# from lhotse.dataset.input_strategies import OnTheFlyFeatures
# from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from data.collation import get_text_token_collater
# from data.dataset import SpeechSynthesisDataset
from data.fbank import get_fbank_extractor
from data.input_strategies import PromptedPrecomputedFeatures

# PrecomputedFeatures = PrecomputedFeatures


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


def _get_input_strategy(input_strategy, dataset, cuts):
    if input_strategy == "PromptedPrecomputedFeatures":
        return PromptedPrecomputedFeatures(dataset, cuts)

    return eval(input_strategy)()


class TtsDataModule:
    """
    DataModule for VALL-E TTS experiments.
    It assumes there is always one train and valid dataloader.

    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation (not yet used or tested),
    - augmentation (not yet used or tested),
    - on-the-fly feature extraction (not yet used or tested)

    This class should be derived for specific corpora used in TTS tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSets; they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/tokenized"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=float,
            default=40.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=10,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=0.1,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=False,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=False,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
        )

        group.add_argument(
            "--dataset",
            type=str,
            default="ljspeech",
            help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
        )

        parser.add_argument(
            "--text-tokens",
            type=str,
            default="data/tokenized/unique_text_tokens.k2symbols",
            help="Path to the unique text tokens file",
        )

        parser.add_argument(
            "--sampling-rate",
            type=int,
            default=24000,
            help="""Audio sampling rate.""",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        transforms = []

        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(
                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
            )
            # Set num_frame_masks according to the installed Lhotse version,
            # because its default value differs between versions.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
            # Drop feats to be on the safe side.
            train = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                cut_transforms=transforms,
                feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
                feature_transforms=input_transforms,
            )
        else:
            train = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=_get_input_strategy(
                    self.args.input_strategy, self.args.dataset, cuts_train
                ),
                cut_transforms=transforms,
                feature_transforms=input_transforms,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info(
                "Using SingleCutSampler and sort by duration(ascending=True)."
            )
            cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
                cut_transforms=[],
            )
        else:
            validate = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=_get_input_strategy(
                    self.args.input_strategy, self.args.dataset, cuts_valid
                ),
                cut_transforms=[],
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=4,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = SpeechSynthesisDataset(
            get_text_token_collater(self.args.text_tokens),
            feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
            if self.args.on_the_fly_feats
            else _get_input_strategy(
                self.args.input_strategy, self.args.dataset, cuts
            ),
            cut_transforms=[],
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "cuts_train.jsonl.gz"
        )

    @lru_cache()
    def dev_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
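

# Illustrative sketch (not part of the original module): the argument
# definitions above pass `str2bool` as an argparse `type=`, but its icefall
# import is commented out. The function below is a minimal stand-in written
# for this example; the accepted spellings are an assumption, and icefall's
# own implementation may differ.

```python
import argparse


def str2bool(value):
    """Parse common textual booleans for use as an argparse `type=`."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    # argparse turns this exception into a regular usage error.
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--shuffle", type=str2bool, default=True)
print(parser.parse_args(["--shuffle", "false"]).shuffle)  # False
```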


================================================
FILE: data/dataset.py
================================================
# Copyright      2023                           (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
modified from lhotse.dataset.speech_synthesis.py
"""

import torch
import math
import h5py
from tokenizers import Tokenizer
from typing import Union, List
import numpy as np
from tqdm import tqdm

_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
symbols = [_pad] + list(_punctuation) + list(_letters)

language_dict = {
    'en': 0,
    'zh': 1,
    'ja': 2,
}


def seq2phone(tokens: Union[List, np.ndarray]):
    """
    Convert tokenized phoneme ID sequence back to phoneme string
    :param tokens: phoneme tokens
    :return: recovered phoneme sequence
    """
    phones = "".join([symbols[i] for i in tokens])
    return phones

class DynamicBatchSampler(torch.utils.data.Sampler):
    def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
                 max_tokens=None, max_sentences=None, drop_last=False):
        """

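        :param sampler: underlying sampler that yields sample indices
        :param num_tokens_fn: function that returns the length of the sample at a given index
        :param num_buckets: number of buckets; samples of similar length are grouped into the same bucket so that they end up in the same batch
        :param min_size: minimum sample length; shorter samples are filtered out. The buckets are laid out starting from this value
        :param max_size: maximum sample length; longer samples are filtered out
        :param max_tokens: upper bound on (number of samples in a batch) * (longest sample length in that batch)
        :param max_sentences: batch size; the final batch size is controlled jointly by max_sentences and max_tokens
        :param drop_last: whether to drop the final leftover batch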
        """
        super(DynamicBatchSampler, self).__init__(sampler)
        self.sampler = sampler
        self.num_tokens_fn = num_tokens_fn
        self.num_buckets = num_buckets

        self.min_size = min_size
        self.max_size = max_size

        assert max_tokens is not None or max_sentences is not None, \
            "max_tokens and max_sentences must not both be None; specify at least one of them"
        # Compare only when max_tokens is given; comparing against None would raise TypeError.
        assert max_tokens is None or max_size <= max_tokens, \
            "max_size should not exceed max_tokens"
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        self.drop_last = drop_last

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def is_batch_full(self, num_tokens, batch):
        if len(batch) == 0:
            return False
        if len(batch) == self.max_sentences:
            return True
        if num_tokens > self.max_tokens:
            return True
        return False

    def __iter__(self):
        buckets = [[] for _ in range(self.num_buckets)]
        sample_len = [0] * self.num_buckets

        for idx in self.sampler:
            idx_length = self.num_tokens_fn(idx)
            if not (self.min_size <= idx_length <= self.max_size):
                print("sentence at index {} of size {} is outside [min_size, max_size], the sentence is ignored".format(idx, idx_length))
                continue

            index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
                                       * self.num_buckets)
            sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)

            num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
            if self.is_batch_full(num_tokens, buckets[index_buckets]):
                # yield this batch
                yield buckets[index_buckets]
                buckets[index_buckets] = []
                sample_len[index_buckets] = 0

            buckets[index_buckets].append(idx)

        # process left-over
        leftover_batch = []
        leftover_sample_len = 0
        leftover = [idx for bucket in buckets for idx in bucket]
        for idx in leftover:
            idx_length = self.num_tokens_fn(idx)
            leftover_sample_len = max(leftover_sample_len, idx_length)
            num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
            if self.is_batch_full(num_tokens, leftover_batch):
                yield leftover_batch
                leftover_batch = []
                leftover_sample_len = 0
            leftover_batch.append(idx)

        if len(leftover_batch) > 0 and not self.drop_last:
            yield leftover_batch

    def __len__(self):
        # The number of batches is not known in advance, so do not call len(dataloader).
        pass


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, h5_path, ann_path, tokenizer_path):
        self.h5_path = h5_path
        with open(ann_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        ls = [l.split("|") for l in lines]
        ls_T = list(zip(*ls))
        del ls_T[-1]
        self.h5_paths, self.durations, self.langs, self.texts = \
            list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
        self.durations = [float(dur) for dur in self.durations]
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        self._archive = None

    def __len__(self):
        return len(self.h5_paths)

    def get_dur(self, idx):
        return self.durations[idx]

    @property
    def archive(self):
        if self._archive is None:  # lazy loading here!
            self._archive = h5py.File(self.h5_path, "r")
        return self._archive

    def __getitem__(self, idx):
        archive = self.archive
        h5_path = self.h5_paths[idx]
        sub = archive[h5_path]
        audio_tokens = sub['audio'][()]
        phone_tokens = sub['text'][()]
        dur = self.durations[idx]
        lang = self.langs[idx]
        text = self.texts[idx]
        # tokenization should be done within dataloader
        phones = seq2phone(phone_tokens)
        phones = phones.replace(" ", "_")
        if not len(phones):
            cptpho_tokens = self.tokenizer.encode(text).ids
        else:
            cptpho_tokens = self.tokenizer.encode(phones).ids
        assert len(cptpho_tokens)
        return {
            'utt_id': h5_path,
            'text': text,
            'audio': None,
            'audio_lens': None,
            'audio_features': audio_tokens,
            'audio_features_lens': len(audio_tokens.T),
            'text_tokens': np.array(cptpho_tokens),
            'text_tokens_lens': len(cptpho_tokens),
            'language': language_dict[lang],
        }

def collate(batch):
    utt_id_s = [b['utt_id'] for b in batch]
    text_s = [b['text'] for b in batch]

    audio_s = [b['audio'] for b in batch]
    audio_lens_s = [b['audio_lens'] for b in batch]

    audio_features_lens_s = [b['audio_features_lens'] for b in batch]
    # create an empty tensor with maximum audio feature length
    audio_features_s = torch.zeros([len(batch), max(audio_features_lens_s), 8], dtype=torch.int64) - 1 # audio pad with -1

    text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
    # create an empty tensor with maximum text tokens length
    text_tokens_s = torch.zeros([len(batch), max(text_tokens_lens_s)], dtype=torch.int64) + 3 # [PAD] token id 3

    language_s = [b['language'] for b in batch]

    for i, b in enumerate(batch):
        audio_features = b['audio_features']
        audio_features_lens = b['audio_features_lens']
        audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features.T)

        text_tokens = b['text_tokens']
        text_tokens_lens = b['text_tokens_lens']
        text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)

    batch = {
        'utt_id': utt_id_s,
        'text': text_s,
        'audio': audio_s,
        'audio_lens': audio_lens_s,
        'audio_features': audio_features_s,
        'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
        'text_tokens': text_tokens_s,
        'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
        'languages': torch.LongTensor(np.array(language_s)),
    }
    return batch

def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
    train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5",
                                 ann_path=f"{data_dir}/audio_ann_sum.txt",
                                 tokenizer_path=f"{data_dir}/bpe_69.json")
    ran_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=n_gpus,
            rank=rank,
            shuffle=True,
        )
    dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur, num_buckets=num_buckets, max_size=20,
                                          max_tokens=max_duration)


    train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=num_workers, collate_fn=collate,
                                               batch_sampler=dynamic_sampler)

    return train_loader
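

# Illustrative sketch (not part of the original module): the bucket assignment
# and the padded-size check used by DynamicBatchSampler, restated as two
# standalone functions. The names `bucket_index` and `would_overflow` are
# hypothetical and exist only in this example.

```python
import math


def bucket_index(length, min_size, max_size, num_buckets):
    """Map a sample length in [min_size, max_size] to a bucket number."""
    # The "+ 1" in the denominator keeps the ratio strictly below 1, so a
    # sample with length == max_size still lands in the last bucket.
    return math.floor(
        (length - min_size) / (max_size - min_size + 1) * num_buckets
    )


def would_overflow(current_count, longest, max_tokens):
    """True if adding one more sample pushes the padded size past max_tokens."""
    # Every sample is padded to the longest one, so the cost is count * longest.
    return (current_count + 1) * longest > max_tokens


print(bucket_index(20, 0, 20, 10))  # 9
print(would_overflow(5, 20, 120))   # False: 6 * 20 == 120 is still allowed
print(would_overflow(6, 20, 120))   # True: 7 * 20 == 140 exceeds 120
```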


================================================
FILE: data/fbank.py
================================================
# Copyright      2023                          (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
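# NOTE: the lhotse imports below are commented out, yet FeatureExtractor,
# Seconds, EPSILON and compute_num_frames are still used in this module,
# so it cannot be imported as-is.
# from lhotse.features.base import FeatureExtractor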
# from lhotse.utils import EPSILON, Seconds, compute_num_frames
from librosa.filters import mel as librosa_mel_fn


@dataclass
class BigVGANFbankConfig:
    # Spectrogram-related part
    # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
    frame_length: Seconds = 1024 / 24000.0
    frame_shift: Seconds = 256 / 24000.0
    remove_dc_offset: bool = True
    round_to_power_of_two: bool = True

    # Fbank-related part
    low_freq: float = 0.0
    high_freq: float = 12000.0
    num_mel_bins: int = 100
    use_energy: bool = False

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
        return BigVGANFbankConfig(**data)


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


# https://github.com/NVIDIA/BigVGAN
# bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
class BigVGANFbank(FeatureExtractor):
    name = "fbank"
    config_type = BigVGANFbankConfig

    def __init__(self, config: Optional[Any] = None):
        super(BigVGANFbank, self).__init__(config)
        sampling_rate = 24000
        self.mel_basis = torch.from_numpy(
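            # Keyword arguments: librosa >= 0.10 no longer accepts these positionally.
            librosa_mel_fn(
                sr=sampling_rate,
                n_fft=1024,
                n_mels=self.config.num_mel_bins,
                fmin=self.config.low_freq,
                fmax=self.config.high_freq,
            ).astype(np.float32)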
        )
        self.hann_window = torch.hann_window(1024)

    def _feature_fn(self, samples, **kwargs):
        win_length, n_fft = 1024, 1024
        hop_size = 256
        if True:
            sampling_rate = 24000
            duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
            expected_num_frames = compute_num_frames(
                duration=duration,
                frame_shift=self.frame_shift,
                sampling_rate=sampling_rate,
            )
            pad_size = (
                (expected_num_frames - 1) * hop_size
                + win_length
                - samples.shape[-1]
            )
            assert pad_size >= 0

            y = torch.nn.functional.pad(
                samples,
                (0, pad_size),
                mode="constant",
            )
        else:
            y = torch.nn.functional.pad(
                samples,
                (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
                mode="reflect",
            )

        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_length,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        return spec.transpose(2, 1).squeeze(0)

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> np.ndarray:
        assert sampling_rate == 24000
        params = asdict(self.config)
        params.update({"sample_frequency": sampling_rate, "snip_edges": False})
        params["frame_shift"] *= 1000.0
        params["frame_length"] *= 1000.0
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
        # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
        if len(samples.shape) == 1:
            samples = samples.unsqueeze(0)
        features = self._feature_fn(samples, **params).to(torch.float32)
        return features.numpy()

    @property
    def frame_shift(self) -> Seconds:
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.num_mel_bins

    @staticmethod
    def mix(
        features_a: np.ndarray,
        features_b: np.ndarray,
        energy_scaling_factor_b: float,
    ) -> np.ndarray:
        return np.log(
            np.maximum(
                # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
                EPSILON,
                np.exp(features_a)
                + energy_scaling_factor_b * np.exp(features_b),
            )
        )

    @staticmethod
    def compute_energy(features: np.ndarray) -> float:
        return float(np.sum(np.exp(features)))


def get_fbank_extractor() -> BigVGANFbank:
    return BigVGANFbank(BigVGANFbankConfig())


if __name__ == "__main__":
    extractor = BigVGANFbank(BigVGANFbankConfig())

    samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
    samples = torch.clip(samples, -1.0, 1.0)
    fbank = extractor.extract(samples, 24000)
    print(f"fbank {fbank.shape}")

    from scipy.io.wavfile import read

    MAX_WAV_VALUE = 32768.0

    sampling_rate, samples = read(
        "egs/libritts/prompts/5639_40744_000000_000002.wav"
    )
    print(f"samples: [{samples.min()}, {samples.max()}]")
    fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
    print(f"fbank {fbank.shape}")

    import matplotlib.pyplot as plt

    _ = plt.figure(figsize=(18, 10))
    plt.imshow(
        X=fbank.transpose(1, 0),
        cmap=plt.get_cmap("jet"),
        aspect="auto",
        interpolation="nearest",
    )
    plt.gca().invert_yaxis()
    plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
    plt.close()

    print("fbank test PASS!")


================================================
FILE: data/input_strategies.py
================================================
import random
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Type

# from lhotse import CutSet
# from lhotse.dataset.collation import collate_features
# from lhotse.dataset.input_strategies import (
#     ExecutorType,
#     PrecomputedFeatures,
#     _get_executor,
# )
# from lhotse.utils import fastcopy


class PromptedFeatures:
    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
        return PromptedFeatures(
            self.prompts.to(device), self.features.to(device)
        )

    def sum(self):
        return self.features.sum()

    @property
    def ndim(self):
        return self.features.ndim

    @property
    def data(self):
        return (self.prompts, self.features)


# class PromptedPrecomputedFeatures(PrecomputedFeatures):
#     """
#     :class:`InputStrategy` that reads pre-computed features, whose manifests
#     are attached to cuts, from disk.
#
#     It automatically pads the feature matrices with pre or post feature.
#
#     .. automethod:: __call__
#     """
#
#     def __init__(
#         self,
#         dataset: str,
#         cuts: CutSet,
#         num_workers: int = 0,
#         executor_type: Type[ExecutorType] = ThreadPoolExecutor,
#     ) -> None:
#         super(PromptedPrecomputedFeatures, self).__init__(
#             num_workers, executor_type
#         )
#
#         self.utt2neighbors = defaultdict(lambda: [])
#
#         if dataset.lower() == "libritts":
#             # 909_131041_000013_000002
#             # 909_131041_000013_000003
#             speaker2utts = defaultdict(lambda: [])
#
#             utt2cut = {}
#             for cut in cuts:
#                 speaker = cut.supervisions[0].speaker
#                 speaker2utts[speaker].append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             for spk in speaker2utts:
#                 uttids = sorted(speaker2utts[spk])
#                 # Use the sorted order of the keys to find the previous utterance.
#                 # The keys have the structure speaker_book_x_y, e.g. 1089_134691_000004_000001
#                 if len(uttids) == 1:
#                     self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#                     continue
#
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2prevutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
#
#                 for utt in utt2postutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
#         elif dataset.lower() == "ljspeech":
#             utt2cut = {}
#             uttids = []
#             for cut in cuts:
#                 uttids.append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             if len(uttids) == 1:
#                 self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#             else:
#                 # Use the sorted order of the keys to find the previous utterance.
#                 # The keys have the structure: LJ001-0010
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2postutt:
#                     postutt = utt2postutt[utt]
#                     if utt[:5] == postutt[:5]:
#                         self.utt2neighbors[utt].append(utt2cut[postutt])
#
#                 for utt in utt2prevutt:
#                     prevutt = utt2prevutt[utt]
#                     if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
#                         self.utt2neighbors[utt].append(utt2cut[prevutt])
#         else:
#             raise ValueError
#
#     def __call__(
#         self, cuts: CutSet
#     ) -> Tuple[PromptedFeatures, PromptedFeatures]:
#         """
#         Reads the pre-computed features from disk/other storage.
#         The returned shape is ``(B, T, F) => (batch_size, num_frames, num_features)``.
#
#         :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
#         """
#         features, features_lens = collate_features(
#             cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         prompts_cuts = []
#         for k, cut in enumerate(cuts):
#             prompts_cut = random.choice(self.utt2neighbors[cut.id])
#             prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
#
#         mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
#         # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
#         #     max_duration=mini_duration,
#         #     offset_type="random",
#         #     preserve_id=True,
#         # )
#         prompts_cuts = CutSet(
#             cuts={k: cut for k, cut in enumerate(prompts_cuts)}
#         ).truncate(
#             max_duration=mini_duration,
#             offset_type="random",
#             preserve_id=False,
#         )
#
#         prompts, prompts_lens = collate_features(
#             prompts_cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         return PromptedFeatures(prompts, features), PromptedFeatures(
#             prompts_lens, features_lens
#         )


================================================
FILE: data/tokenizer.py
================================================
#!/usr/bin/env python3
# Copyright    2023                            (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio

try:
    from pypinyin import Style, pinyin
    from pypinyin.style._utils import get_finals, get_initials
except Exception:
    pass


def remove_encodec_weight_norm(model):
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)

    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)


class AudioTokenizer:
    """EnCodec audio tokenizer."""

    def __init__(
        self,
        device: Any = None,
    ) -> None:
        # Instantiate a pretrained EnCodec model
        model = EncodecModel.encodec_model_24khz()
        model.set_target_bandwidth(6.0)
        remove_encodec_weight_norm(model)

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")
            if torch.backends.mps.is_available():
                device = torch.device("mps")

        self._device = device

        self.codec = model.to(device)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

    @property
    def device(self):
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        return self.codec.encode(wav.to(self.device))

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        return self.codec.decode(frames)


def tokenize_audio(tokenizer: AudioTokenizer, audio):
    # Load and pre-process the audio waveform
    if isinstance(audio, str):
        wav, sr = torchaudio.load(audio)
    else:
        wav, sr = audio
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames


if __name__ == "__main__":
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)

    samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
        torch.float32
    )
    codes_raw = model.encode(samples)

    remove_encodec_weight_norm(model)
    codes_norm = model.encode(samples)

    assert torch.allclose(codes_raw[0][0], codes_norm[0][0])


================================================
FILE: descriptions.py
================================================
top_md = """
# VALL-E X  
VALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recording of 
an unseen speaker as an acoustic prompt, even in another language for a monolingual speaker.<br>
This implementation supports zero-shot, mono-lingual/cross-lingual text-to-speech in three languages (English, Chinese, Japanese).<br>  
See this [demo](https://plachtaa.github.io/) page for more details.
"""

infer_from_audio_md = """
Upload a 3~10 second speech clip as the audio prompt and type in the text you'd like to synthesize.<br>
The model will synthesize the given text in the voice of your audio prompt.<br>
The model also tends to preserve the emotion & acoustic environment of your given speech.<br>
For faster inference, please use **"Make prompt"** to get a `.npz` file as the encoded audio prompt, and use it with **"Infer from prompt"**.
"""

make_prompt_md = """
Upload a 3~10 second speech clip as the audio prompt.<br>
Get a `.npz` file as the encoded audio prompt. Use it with **"Infer from prompt"**.
"""

infer_from_prompt_md = """
Faster than **"Infer from audio"**.<br>
You need to **"Make prompt"** first, and upload the encoded prompt (a `.npz` file)
"""

long_text_md = """
Very long text is chunked into several sentences, and each sentence is synthesized separately.<br>
Please make a prompt or use a preset prompt to infer long text.
"""

long_text_example = "Just a few years ago, there were no legions of deep learning scientists developing intelligent products and services at major companies and startups. When we entered the field, machine learning did not command headlines in daily newspapers. Our parents had no idea what machine learning was, let alone why we might prefer it to a career in medicine or law. Machine learning was a blue skies academic discipline whose industrial significance was limited to a narrow set of real-world applications, including speech recognition and computer vision. Moreover, many of these applications required so much domain knowledge that they were often regarded as entirely separate areas for which machine learning was one small component. At that time, neural networks—the predecessors of the deep learning methods that we focus on in this book—were generally regarded as outmoded."

================================================
FILE: examples.py
================================================
infer_from_audio_examples = [
    ["This is how this machine has taken my voice.", 'English', 'no-accent', "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
    ["我喜欢抽电子烟,尤其是锐刻五代。", '中文', 'no-accent', "prompts/zh-1.wav", None, "今天我很荣幸,"],
    ["私の声を真似するのはそんなに面白いですか?", '日本語', 'no-accent', "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
    ["你可以听得出来我有多困。", '中文', 'no-accent', "prompts/en-1.wav", None, ""],
    ["この文は、クロスリンガル合成の例です。", '日本語', 'no-accent', "prompts/zh-2.wav", None, ""],
    ["Actually, I can't speak English, but this machine helped me do it.", 'English', 'no-accent', "prompts/ja-1.wav", None, ""],
]

make_npz_prompt_examples = [
    ["Gem-trader", "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
    ["Ding Zhen", "prompts/zh-1.wav", None, "今天我很荣幸,"],
    ["Yoshino", "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
    ["Sleepy-woman", "prompts/en-1.wav", None, ""],
    ["Yae", "prompts/zh-2.wav", None, ""],
    ["Cafe", "prompts/ja-1.wav", None, ""],
]

infer_from_prompt_examples = [
    ["A prompt contains voice, prosody and emotion information of a certain speaker.", "English", "no-accent", "vctk_1", None],
    ["This prompt is made with an audio of three seconds.", "English", "no-accent", "librispeech_1", None],
    ["This prompt is made with Chinese speech", "English", "no-accent", "seel", None],
]



================================================
FILE: launch-ui.py
================================================
# coding: utf-8
import argparse
import logging
import os
import pathlib
import time
import tempfile
import platform
import webbrowser
import sys
print(f"default encoding is {sys.getdefaultencoding()},file system encoding is {sys.getfilesystemencoding()}")
print(f"You are using Python version {platform.python_version()}")
if sys.version_info < (3, 7):
    print("The Python version is too low and may cause problems")

if platform.system().lower() == 'windows':
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
else:
    temp = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import langid
langid.set_languages(['en', 'zh', 'ja'])

import nltk
nltk.data.path = nltk.data.path + [os.path.join(os.getcwd(), "nltk_data")]

import torch
import torchaudio
import random

import numpy as np

from data.tokenizer import (
    AudioTokenizer,
    tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
from utils.g2p import PhonemeBpeTokenizer
from descriptions import *
from macros import *
from examples import *

import gradio as gr
import whisper
from vocos import Vocos
import multiprocessing

thread_count = multiprocessing.cpu_count()

print("Use",thread_count,"cpu cores for computing")

torch.set_num_threads(thread_count)
torch.set_num_interop_threads(thread_count)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)

text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
text_collater = get_text_token_collater()

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
if torch.backends.mps.is_available():
    device = torch.device("mps")
# VALL-E-X model
os.makedirs("./checkpoints/", exist_ok=True)
if not os.path.exists(os.path.join("./checkpoints/", "vallex-checkpoint.pt")):
    import wget
    try:
        logging.info("Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...")
        # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt
        wget.download("https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt",
                      out="./checkpoints/vallex-checkpoint.pt", bar=wget.bar_adaptive)
    except Exception as e:
        logging.info(e)
        raise Exception(
            "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'"
            "\n manually download model weights and put it to {} .".format(os.path.join(os.getcwd(), "checkpoints")))

model = VALLE(
        N_DIM,
        NUM_HEAD,
        NUM_LAYERS,
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
    )
checkpoint = torch.load("./checkpoints/vallex-checkpoint.pt", map_location='cpu')
missing_keys, unexpected_keys = model.load_state_dict(
    checkpoint["model"], strict=True
)
assert not missing_keys
model.eval()

# Encodec model
audio_tokenizer = AudioTokenizer(device)

# Vocos decoder
vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)

# ASR
os.makedirs("./whisper/", exist_ok=True)
try:
    whisper_model = whisper.load_model("medium",download_root=os.path.join(os.getcwd(), "whisper")).cpu()
except Exception as e:
    logging.info(e)
    raise Exception(
        "\n Whisper download failed or damaged, please go to "
        "'https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt'"
        "\n manually download model and put it to {} .".format(os.path.join(os.getcwd(), "whisper")))

# Voice Presets
preset_list = next(os.walk("./presets/"))[2]
preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]

def clear_prompts():
    try:
        path = tempfile.gettempdir()
        for eachfile in os.listdir(path):
            filename = os.path.join(path, eachfile)
            if os.path.isfile(filename) and filename.endswith(".npz"):
                lastmodifytime = os.stat(filename).st_mtime
                endfiletime = time.time() - 60
                if endfiletime > lastmodifytime:
                    os.remove(filename)
    except Exception:
        return

def transcribe_one(model, audio_path):
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)

    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # detect the spoken language
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    print(f"Detected language: {lang}")
    # decode the audio (fp16 is only enabled off-CPU)
    options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=device.type != "cpu", sample_len=150)
    result = whisper.decode(model, mel, options)

    # print the recognized text
    print(result.text)

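    text_pr = result.text
    # Guard against an empty transcription before inspecting the last character.
    stripped = text_pr.strip(" ")
    if stripped and stripped[-1] not in "?!.,。,?!。、":
        text_pr += "."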
    return lang, text_pr

def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
    global model, text_collater, text_tokenizer, audio_tokenizer
    clear_prompts()
    audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
    sr, wav_pr = audio_prompt
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    assert wav_pr.ndim == 2 and wav_pr.size(0) == 1

    if transcript_content == "":
        text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
    else:
        lang_pr = langid.classify(str(transcript_content))[0]
        lang_token = lang2token[lang_pr]
        text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
    # tokenize audio
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()

    # tokenize text
    phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
    text_tokens, enroll_x_lens = text_collater(
        [
            phonemes
        ]
    )

    message = f"Detected language: {lang_pr}\n Detected text: {text_pr}\n"

    # save as npz file
    np.savez(os.path.join(tempfile.gettempdir(), f"{name}.npz"),
             audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
    return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")


def make_prompt(name, wav, sr, save=True):
    global whisper_model
    whisper_model.to(device)
    if not isinstance(wav, torch.FloatTensor):
        wav = torch.tensor(wav)
    if wav.abs().max() > 1:
        wav /= wav.abs().max()
    if wav.size(-1) == 2:
        wav = wav.mean(-1, keepdim=False)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    assert wav.ndim == 2 and wav.size(0) == 1
    torchaudio.save(f"./prompts/{name}.wav", wav, sr)
    lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
    lang_token = lang2token[lang]
    text = lang_token + text + lang_token
    with open(f"./prompts/{name}.txt", 'w', encoding='utf-8') as f:
        f.write(text)
    if not save:
        os.remove(f"./prompts/{name}.wav")
        os.remove(f"./prompts/{name}.txt")

    whisper_model.cpu()
    torch.cuda.empty_cache()
    return text, lang

@torch.no_grad()
def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
    global model, text_collater, text_tokenizer, audio_tokenizer
    audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
    sr, wav_pr = audio_prompt
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    assert wav_pr.ndim == 2 and wav_pr.size(0) == 1

    if transcript_content == "":
        text_pr, lang_pr = make_prompt('dummy', wav_pr, sr, save=False)
    else:
        lang_pr = langid.classify(str(transcript_content))[0]
        lang_token = lang2token[lang_pr]
        text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"

    if language == 'auto-detect':
        lang_token = lang2token[langid.classify(text)[0]]
    else:
        lang_token = langdropdown2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token

    # onload model
    model.to(device)

    # tokenize audio
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)

    # tokenize text
    logging.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater(
        [
            phone_tokens
        ]
    )

    enroll_x_lens = None
    if text_pr:
        text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
        text_prompts, enroll_x_lens = text_collater(
            [
                text_prompts
            ]
        )
        # Only prepend the prompt tokens when a text prompt exists; otherwise
        # text_prompts is undefined and enroll_x_lens is None.
        text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
        text_tokens_lens += enroll_x_lens
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
        best_of=5,
    )
    # Decode with Vocos
    frames = encoded_frames.permute(2,0,1)
    features = vocos.codes_to_features(frames)
    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))

    # offload model
    model.to('cpu')
    torch.cuda.empty_cache()

    message = f"text prompt: {text_pr}\nsynthesized text: {text}"
    return message, (24000, samples.squeeze(0).cpu().numpy())

@torch.no_grad()
def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
    clear_prompts()
    model.to(device)
    # text to synthesize
    if language == 'auto-detect':
        lang_token = lang2token[langid.classify(text)[0]]
    else:
        lang_token = langdropdown2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token

    # load prompt
    if prompt_file is not None:
        prompt_data = np.load(prompt_file.name)
    else:
        prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz"))
    audio_prompts = prompt_data['audio_tokens']
    text_prompts = prompt_data['text_tokens']
    lang_pr = prompt_data['lang_code']
    lang_pr = code2lang[int(lang_pr)]

    # numpy to tensor
    audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
    text_prompts = torch.tensor(text_prompts).type(torch.int32)

    enroll_x_lens = text_prompts.shape[-1]
    logging.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater(
        [
            phone_tokens
        ]
    )
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens
    # accent control
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
        best_of=5,
    )
    # Decode with Vocos
    frames = encoded_frames.permute(2,0,1)
    features = vocos.codes_to_features(frames)
    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))

    model.to('cpu')
    torch.cuda.empty_cache()

    message = f"synthesized text: {text}"
    return message, (24000, samples.squeeze(0).cpu().numpy())


from utils.sentence_cutter import split_text_into_sentences
@torch.no_grad()
def infer_long_text(text, preset_prompt, prompt=None, language='auto-detect', accent='no-accent'):
    """
    For long audio generation, two modes are available.
    fixed-prompt: Keeps using the prompt the user provided and generates audio sentence by sentence.
    sliding-window: Uses the previously generated sentence as the prompt for the next one; speaker identity may drift over long texts.
    """
    mode = 'fixed-prompt'
    global model, audio_tokenizer, text_tokenizer, text_collater
    model.to(device)
    if (prompt is None or prompt == "") and (preset_prompt is None or preset_prompt == ""):
        mode = 'sliding-window'  # If no prompt is given, use sliding-window mode
    sentences = split_text_into_sentences(text)
    # detect language
    if language == "auto-detect":
        language = langid.classify(text)[0]
    else:
        language = token2lang[langdropdown2token[language]]

    # if initial prompt is given, encode it
    if prompt is not None and prompt != "":
        # load prompt
        prompt_data = np.load(prompt.name)
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = prompt_data['lang_code']
        lang_pr = code2lang[int(lang_pr)]

        # numpy to tensor
        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    elif preset_prompt is not None and preset_prompt != "":
        prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz"))
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = prompt_data['lang_code']
        lang_pr = code2lang[int(lang_pr)]

        # numpy to tensor
        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = language if language != 'mix' else 'en'
    if mode == 'fixed-prompt':
        complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
        for text in sentences:
            text = text.replace("\n", "").strip(" ")
            if text == "":
                continue
            lang_token = lang2token[language]
            lang = token2lang[lang_token]
            text = lang_token + text + lang_token

            enroll_x_lens = text_prompts.shape[-1]
            logging.info(f"synthesize text: {text}")
            phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
            text_tokens, text_tokens_lens = text_collater(
                [
                    phone_tokens
                ]
            )
            text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
            text_tokens_lens += enroll_x_lens
            # accent control
            lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
            encoded_frames = model.inference(
                text_tokens.to(device),
                text_tokens_lens.to(device),
                audio_prompts,
                enroll_x_lens=enroll_x_lens,
                top_k=-100,
                temperature=1,
                prompt_language=lang_pr,
                text_language=langs if accent == "no-accent" else lang,
                best_of=5,
            )
            complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
        # Decode with Vocos
        frames = complete_tokens.permute(1, 0, 2)
        features = vocos.codes_to_features(frames)
        samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))

        model.to('cpu')
        message = f"Cut into {len(sentences)} sentences"
        return message, (24000, samples.squeeze(0).cpu().numpy())
    elif mode == "sliding-window":
        complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
        original_audio_prompts = audio_prompts
        original_text_prompts = text_prompts
        for text in sentences:
            text = text.replace("\n", "").strip(" ")
            if text == "":
                continue
            lang_token = lang2token[language]
            lang = token2lang[lang_token]
            text = lang_token + text + lang_token

            enroll_x_lens = text_prompts.shape[-1]
            logging.info(f"synthesize text: {text}")
            phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
            text_tokens, text_tokens_lens = text_collater(
                [
                    phone_tokens
                ]
            )
            text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
            text_tokens_lens += enroll_x_lens
            # accent control
            lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
            encoded_frames = model.inference(
                text_tokens.to(device),
                text_tokens_lens.to(device),
                audio_prompts,
                enroll_x_lens=enroll_x_lens,
                top_k=-100,
                temperature=1,
                prompt_language=lang_pr,
                text_language=langs if accent == "no-accent" else lang,
                best_of=5,
            )
            complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
            if torch.rand(1) < 1.0:
                audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
                text_prompts = text_tokens[:, enroll_x_lens:]
            else:
                audio_prompts = original_audio_prompts
                text_prompts = original_text_prompts
        # Decode with Vocos
        frames = complete_tokens.permute(1, 0, 2)
        features = vocos.codes_to_features(frames)
        samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))

        model.to('cpu')
        message = f"Cut into {len(sentences)} sentences"
        return message, (24000, samples.squeeze(0).cpu().numpy())
    else:
        raise ValueError(f"No such mode {mode}")


def main():
    app = gr.Blocks(title="VALL-E X")
    with app:
        gr.Markdown(top_md)
        with gr.Tab("Infer from audio"):
            gr.Markdown(infer_from_audio_md)
            with gr.Row():
                with gr.Column():

                    textbox = gr.TextArea(label="Text",
                                          placeholder="Type your sentence here",
                                          value="Welcome back, Master. What can I do for you today?", elem_id=f"tts-input")
                    language_dropdown = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語'], value='auto-detect', label='language')
                    accent_dropdown = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent', label='accent')
                    textbox_transcript = gr.TextArea(label="Transcript",
                                          placeholder="Write transcript here. (leave empty to use whisper)",
                                          value="", elem_id=f"prompt-name")
                    upload_audio_prompt = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
                    record_audio_prompt = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
                with gr.Column():
                    text_output = gr.Textbox(label="Message")
                    audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
                    btn = gr.Button("Generate!")
                    btn.click(infer_from_audio,
                              inputs=[textbox, language_dropdown, accent_dropdown, upload_audio_prompt, record_audio_prompt, textbox_transcript],
                              outputs=[text_output, audio_output])
                    textbox_mp = gr.TextArea(label="Prompt name",
                                          placeholder="Name your prompt here",
                                          value="prompt_1", elem_id=f"prompt-name")
                    btn_mp = gr.Button("Make prompt!")
                    prompt_output = gr.File(interactive=False)
                    btn_mp.click(make_npz_prompt,
                                inputs=[textbox_mp, upload_audio_prompt, record_audio_prompt, textbox_transcript],
                                outputs=[text_output, prompt_output])
            gr.Examples(examples=infer_from_audio_examples,
                        inputs=[textbox, language_dropdown, accent_dropdown, upload_audio_prompt, record_audio_prompt, textbox_transcript],
                        outputs=[text_output, audio_output],
                        fn=infer_from_audio,
                        cache_examples=False,)
        with gr.Tab("Make prompt"):
            gr.Markdown(make_prompt_md)
            with gr.Row():
                with gr.Column():
                    textbox2 = gr.TextArea(label="Prompt name",
                                          placeholder="Name your prompt here",
                                          value="prompt_1", elem_id=f"prompt-name")
                    # Add the language selector and transcript input here
                    textbox_transcript2 = gr.TextArea(label="Transcript",
                                          placeholder="Write transcript here. (leave empty to use whisper)",
                                          value="", elem_id=f"prompt-name")
                    upload_audio_prompt_2 = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
                    record_audio_prompt_2 = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
                with gr.Column():
                    text_output_2 = gr.Textbox(label="Message")
                    prompt_output_2 = gr.File(interactive=False)
                    btn_2 = gr.Button("Make!")
                    btn_2.click(make_npz_prompt,
                              inputs=[textbox2, upload_audio_prompt_2, record_audio_prompt_2, textbox_transcript2],
                              outputs=[text_output_2, prompt_output_2])
            gr.Examples(examples=make_npz_prompt_examples,
                        inputs=[textbox2, upload_audio_prompt_2, record_audio_prompt_2, textbox_transcript2],
                        outputs=[text_output_2, prompt_output_2],
                        fn=make_npz_prompt,
                        cache_examples=False,)
        with gr.Tab("Infer from prompt"):
            gr.Markdown(infer_from_prompt_md)
            with gr.Row():
                with gr.Column():
                    textbox_3 = gr.TextArea(label="Text",
                                          placeholder="Type your sentence here",
                                          value="Welcome back, Master. What can I do for you today?", elem_id=f"tts-input")
                    language_dropdown_3 = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語', 'Mix'], value='auto-detect',
                                                    label='language')
                    accent_dropdown_3 = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent',
                                                  label='accent')
                    preset_dropdown_3 = gr.Dropdown(choices=preset_list, value=None, label='Voice preset')
                    prompt_file = gr.File(file_count='single', file_types=['.npz'], interactive=True)
                with gr.Column():
                    text_output_3 = gr.Textbox(label="Message")
                    audio_output_3 = gr.Audio(label="Output Audio", elem_id="tts-audio")
                    btn_3 = gr.Button("Generate!")
                    btn_3.click(infer_from_prompt,
                              inputs=[textbox_3, language_dropdown_3, accent_dropdown_3, preset_dropdown_3, prompt_file],
                              outputs=[text_output_3, audio_output_3])
            gr.Examples(examples=infer_from_prompt_examples,
                        inputs=[textbox_3, language_dropdown_3, accent_dropdown_3, preset_dropdown_3, prompt_file],
                        outputs=[text_output_3, audio_output_3],
                        fn=infer_from_prompt,
                        cache_examples=False,)
        with gr.Tab("Infer long text"):
            gr.Markdown("This is a long text generation demo. You can use this to generate long audio. ")
            with gr.Row():
                with gr.Column():
                    textbox_4 = gr.TextArea(label="Text",
                                          placeholder="Type your sentence here",
                                          value=long_text_example, elem_id=f"tts-input")
                    language_dropdown_4 = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語'], value='auto-detect',
                                                    label='language')
                    accent_dropdown_4 = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent',
                                                    label='accent')
                    preset_dropdown_4 = gr.Dropdown(choices=preset_list, value=None, label='Voice preset')
                    prompt_file_4 = gr.File(file_count='single', file_types=['.npz'], interactive=True)
                with gr.Column():
                    text_output_4 = gr.TextArea(label="Message")
                    audio_output_4 = gr.Audio(label="Output Audio", elem_id="tts-audio")
                    btn_4 = gr.Button("Generate!")
                    btn_4.click(infer_long_text,
                              inputs=[textbox_4, preset_dropdown_4, prompt_file_4, language_dropdown_4, accent_dropdown_4],
                              outputs=[text_output_4, audio_output_4])

    webbrowser.open("http://127.0.0.1:7860")
    app.launch()

if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()


================================================
FILE: macros.py
================================================
NUM_LAYERS = 12
NUM_HEAD = 16
N_DIM = 1024
PREFIX_MODE = 1
NUM_QUANTIZERS = 8
SAMPLE_RATE = 24000

lang2token = {
    'zh': "[ZH]",
    'ja': "[JA]",
    "en": "[EN]",
    'mix': "",
}

lang2code = {
    'zh': 0,
    'ja': 1,
    "en": 2,
}

token2lang = {
    '[ZH]': "zh",
    '[JA]': "ja",
    "[EN]": "en",
    "": "mix"
}

code2lang = {
    0: 'zh',
    1: 'ja',
    2: "en",
}

langdropdown2token = {
    'English': "[EN]",
    '中文': "[ZH]",
    '日本語': "[JA]",
    'Mix': "",
}

================================================
FILE: model-card.md
================================================
# Model Card: VALL-E X

**Author**: [Songting](https://github.com/Plachtaa).<br>
<br>
This is the official codebase for running open-sourced VALL-E X.

The following is additional information about the models released here.

## Model Details

VALL-E X is a series of two transformer models that turn text into audio.

### Phoneme to acoustic tokens
 - Input: IPA symbols converted from the input text by a rule-based G2P tool.
 - Output: tokens from the first codebook of the [EnCodec codec](https://github.com/facebookresearch/encodec) from Facebook Research

### Coarse to fine tokens
 - Input: IPA symbols converted from the input text by a rule-based G2P tool & the first codebook from EnCodec
 - Output: 8 codebooks from EnCodec
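
The two stages above can be sketched as a shape-only illustration. This is a minimal sketch, not the repository's inference code: random integers stand in for real model outputs, and `ar_stage` and `nar_stage` are hypothetical names introduced only for this example.

```python
import numpy as np

NUM_QUANTIZERS = 8    # codebooks per frame (matches NUM_QUANTIZERS in macros.py)
CODEBOOK_SIZE = 1024  # entries per EnCodec codebook (matches NUM_AUDIO_TOKENS)

rng = np.random.default_rng(0)


def ar_stage(num_frames: int) -> np.ndarray:
    """Hypothetical stand-in for the causal model: one first-codebook token per frame."""
    return rng.integers(0, CODEBOOK_SIZE, size=(num_frames,))


def nar_stage(coarse: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the non-causal model: adds codebooks 2..8 to codebook 1."""
    fine = rng.integers(0, CODEBOOK_SIZE, size=(coarse.shape[0], NUM_QUANTIZERS - 1))
    # Keep the coarse tokens as column 0 and append the seven predicted columns.
    return np.concatenate([coarse[:, None], fine], axis=1)


coarse = ar_stage(num_frames=150)
codes = nar_stage(coarse)
print(coarse.shape, codes.shape)
```

The first stage yields one token per frame, and the second stage keeps that column and fills in the remaining seven, giving a `(frames, 8)` code matrix that a codec decoder can consume.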

### Architecture
|          Model           | Parameters | Attention  | Output Vocab size |  
|:------------------------:|:----------:|------------|:-----------------:|
|         G2P tool         |     -      | -          |        69         |
| Phoneme to coarse tokens |   150 M    | Causal     |     1x 1,024      |
|  Coarse to fine tokens   |   150 M    | Non-causal |     7x 1,024      |
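
As a rough sanity check on the vocabulary sizes in the table, the arithmetic below estimates the token bitrate. The frame rate of 75 frames per second is an assumption taken from EnCodec's 24 kHz configuration; it is not stated in this card.

```python
import math

FRAME_RATE_HZ = 75    # assumption: EnCodec 24 kHz frame rate, not stated in this card
NUM_CODEBOOKS = 8     # 1 coarse + 7 fine codebooks, as in the table
CODEBOOK_SIZE = 1024  # output vocabulary size per codebook

# Each token selects one of 1,024 entries, so it carries log2(1024) bits.
bits_per_token = int(math.log2(CODEBOOK_SIZE))
bitrate_bps = FRAME_RATE_HZ * NUM_CODEBOOKS * bits_per_token
print(bits_per_token, bitrate_bps)
```

Under that assumption, eight 1,024-entry codebooks correspond to about 6 kbps of discrete audio tokens.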

### Release date
August 2023

## Broader Implications
We anticipate that this model's text-to-audio capabilities can be used to improve accessibility tools in a variety of languages. 
Straightforward improvements will allow models to run faster than realtime, rendering them useful for applications such as virtual assistants. 

================================================
FILE: models/__init__.py
================================================
import argparse

import torch.nn as nn
# from icefall.utils import AttributeDict, str2bool

from .macros import (
    NUM_AUDIO_TOKENS,
    NUM_MEL_BINS,
    NUM_SPEAKER_CLASSES,
    NUM_TEXT_TOKENS,
    SPEAKER_EMBEDDING_DIM,
)
from .transformer import Transformer
from .vallex import VALLE, VALLF
from .visualizer import visualize


def add_model_arguments(parser: argparse.ArgumentParser):
    def str2bool(value: str) -> bool:
        # argparse's `type=bool` treats every non-empty string (even "False")
        # as True, so boolean options are parsed explicitly instead.
        if value.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if value.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"Boolean value expected, got {value!r}"
        )

    parser.add_argument(
        "--model-name",
        type=str,
        default="VALL-E",
        help="VALL-E, VALL-F, Transformer.",
    )
    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=1024,
        help="Embedding dimension in the decoder model.",
    )
    parser.add_argument(
        "--nhead",
        type=int,
        default=16,
        help="Number of attention heads in the Decoder layers.",
    )
    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=12,
        help="Number of Decoder layers.",
    )
    parser.add_argument(
        "--scale-factor",
        type=float,
        default=1.0,
        help="Model scale factor which will be assigned different meanings in different models.",
    )
    parser.add_argument(
        "--norm-first",
        type=str2bool,
        default=True,
        help="Pre or Post Normalization.",
    )
    parser.add_argument(
        "--add-prenet",
        type=str2bool,
        default=False,
        help="Whether to add a PreNet after the inputs.",
    )

    # VALL-E & F
    parser.add_argument(
        "--prefix-mode",
        type=int,
        default=1,
        help="The mode for how to prefix VALL-E NAR Decoder, "
        "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
    )
    parser.add_argument(
        "--share-embedding",
        type=str2bool,
        default=True,
        help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
    )
    parser.add_argument(
        "--prepend-bos",
        type=str2bool,
        default=False,
        help="Whether to prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
    )
    parser.add_argument(
        "--num-quantizers",
        type=int,
        default=8,
        help="Number of Audio/Semantic quantization layers.",
    )

    # Transformer
    parser.add_argument(
        "--scaling-xformers",
        type=str2bool,
        default=False,
        help="Apply Reworked Conformer scaling on Transformers.",
    )


def get_model(params) -> nn.Module:
    if params.model_name.lower() in ["vall-f", "vallf"]:
        model = VALLF(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            prefix_mode=params.prefix_mode,
            share_embedding=params.share_embedding,
            nar_scale_factor=params.scale_factor,
            prepend_bos=params.prepend_bos,
            num_quantizers=params.num_quantizers,
        )
    elif params.model_name.lower() in ["vall-e", "valle"]:
        model = VALLE(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            prefix_mode=params.prefix_mode,
            share_embedding=params.share_embedding,
            nar_scale_factor=params.scale_factor,
            prepend_bos=params.prepend_bos,
            num_quantizers=params.num_quantizers,
        )
    else:
        assert params.model_name in ["Transformer"]
        model = Transformer(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            scaling_xformers=params.scaling_xformers,
        )

    return model


================================================
FILE: models/macros.py
================================================
# Text
NUM_TEXT_TOKENS = 2048

# Audio
NUM_AUDIO_TOKENS = 1024  # EnCodec RVQ bins
NUM_MEL_BINS = 100  # BigVGAN bigvgan_24khz_100band


# Speaker
NUM_SPEAKER_CLASSES = 4096
SPEAKER_EMBEDDING_DIM = 64


================================================
FILE: models/transformer.py
================================================
# Copyright    2023                             (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
# from icefall.utils import make_pad_mask
# from torchmetrics.classification import BinaryAccuracy

from models.vallex import Transpose
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
from modules.scaling import BalancedDoubleSwish, ScaledLinear
from modules.transformer import (
    BalancedBasicNorm,
    IdentityNorm,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
from .visualizer import visualize

IdentityNorm = IdentityNorm


class Transformer(nn.Module):
    """It implements seq2seq Transformer TTS for debugging (no StopPredictor and SpeakerEmbedding).
    Neural Speech Synthesis with Transformer Network
    https://arxiv.org/abs/1809.08895
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        scaling_xformers: bool = False,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
        """
        super().__init__()
        self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x

        if add_prenet:
            self.encoder_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(d_model, d_model),
            )

            self.decoder_prenet = nn.Sequential(
                nn.Linear(NUM_MEL_BINS, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, d_model),
            )

            assert scaling_xformers is False  # TODO: update this block
        else:
            self.encoder_prenet = nn.Identity()
            if scaling_xformers:
                self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
            else:
                self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)

        self.encoder_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
        )
        self.decoder_position = SinePositionalEmbedding(
            d_model, dropout=0.1, scale=False
        )

        if scaling_xformers:
            self.encoder = TransformerEncoder(
                TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.decoder = nn.TransformerDecoder(
                TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)
        else:
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )

            self.decoder = nn.TransformerDecoder(
                nn.TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )

            self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)

        self.stop_accuracy_metric = BinaryAccuracy(
            threshold=0.5, multidim_average="global"
        )

    #     self.apply(self._init_weights)

    # def _init_weights(self, module):
    #     if isinstance(module, (nn.Linear)):
    #         module.weight.data.normal_(mean=0.0, std=0.02)
    #         if isinstance(module, nn.Linear) and module.bias is not None:
    #             module.bias.data.zero_()
    #     elif isinstance(module, nn.LayerNorm):
    #         module.bias.data.zero_()
    #         module.weight.data.fill_(1.0)
    #     elif isinstance(module, nn.Embedding):
    #         module.weight.data.normal_(mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        y_lens: torch.Tensor,
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, S).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (N, T, NUM_MEL_BINS).
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `y`
            before padding.
          train_stage:
            Not used in this model.
        Returns:
          Return a tuple of ((encoder output, predicted mel frames), total loss, metrics).
        """
        del train_stage

        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y_lens.ndim == 1, y_lens.shape

        assert torch.all(x_lens > 0)

        # NOTE: x has been padded in TextTokenCollater
        x_mask = make_pad_mask(x_lens).to(x.device)

        x = self.text_embedding(x)
        x = self.encoder_prenet(x)
        x = self.encoder_position(x)
        x = self.encoder(x, src_key_padding_mask=x_mask)

        total_loss, metrics = 0.0, {}

        y_mask = make_pad_mask(y_lens).to(y.device)
        y_mask_float = y_mask.type(torch.float32)
        data_mask = 1.0 - y_mask_float.unsqueeze(-1)

        # Training
        # AR Decoder
        def pad_y(y):
            y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
            # inputs, targets
            return y[:, :-1], y[:, 1:]

        y, targets = pad_y(y * data_mask)  # mask padding as zeros

        y_emb = self.decoder_prenet(y)
        y_pos = self.decoder_position(y_emb)

        y_len = y_lens.max()
        tgt_mask = torch.triu(
            torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
            diagonal=1,
        )
        y_dec = self.decoder(
            y_pos,
            x,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=x_mask,
        )

        predict = self.predict_layer(y_dec)
        # loss
        total_loss = F.mse_loss(predict, targets, reduction=reduction)

        logits = self.stop_layer(y_dec).squeeze(-1)
        stop_loss = F.binary_cross_entropy_with_logits(
            logits,
            y_mask_float.detach(),
            weight=1.0 + y_mask_float.detach() * 4.0,
            reduction=reduction,
        )
        metrics["stop_loss"] = stop_loss.detach()

        stop_accuracy = self.stop_accuracy_metric(
            (torch.sigmoid(logits) >= 0.5).type(torch.int64),
            y_mask.type(torch.int64),
        )
        # icefall MetricsTracker.norm_items()
        metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
            torch.float32
        )

        return ((x, predict), total_loss + 100.0 * stop_loss, metrics)

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Any = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
        Returns:
          Return the predicted mel frames, a tensor of shape (1, T, NUM_MEL_BINS).
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        assert torch.all(x_lens > 0)

        x_mask = make_pad_mask(x_lens).to(x.device)

        x = self.text_embedding(x)
        x = self.encoder_prenet(x)
        x = self.encoder_position(x)
        x = self.encoder(x, src_key_padding_mask=x_mask)

        x_mask = make_pad_mask(x_lens).to(x.device)

        # AR Decoder
        # TODO: cache decoder steps to avoid repeated computation
        y = torch.zeros(
            [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
        )
        while True:
            y_emb = self.decoder_prenet(y)
            y_pos = self.decoder_position(y_emb)

            tgt_mask = torch.triu(
                torch.ones(
                    y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
                ),
                diagonal=1,
            )

            y_dec = self.decoder(
                y_pos,
                x,
                tgt_mask=tgt_mask,
                memory_mask=None,
                memory_key_padding_mask=x_mask,
            )
            predict = self.predict_layer(y_dec[:, -1:])

            logits = self.stop_layer(y_dec[:, -1:]) > 0  # sigmoid(0.0) = 0.5
            if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
                print(
                    f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
                )
                break

            y = torch.concat([y, predict], dim=1)

        return y[:, 1:]

    def visualize(
        self,
        predicts: Tuple[torch.Tensor],
        batch: Dict[str, Union[List, torch.Tensor]],
        output_dir: str,
        limit: int = 4,
    ) -> None:
        visualize(predicts, batch, output_dir, limit=limit)


================================================
FILE: models/vallex.py
================================================
# Copyright    2023                             (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import Dict, Iterator, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# from icefall.utils import make_pad_mask
# from torchmetrics.classification import MulticlassAccuracy

from data.input_strategies import PromptedFeatures
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
from modules.transformer import (
    AdaptiveLayerNorm,
    LayerNorm,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
from .visualizer import visualize


class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)


# NOTE: There are two ways to implement the model
#       1) [VALL-F] standard TransformerDecoder, use x as memory
#       2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
#          use x as the prefix of decoder inputs
class VALLF(nn.Module):
    """It implements https://arxiv.org/abs/2301.02111
    "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        decoder_cls: Union[
            nn.TransformerDecoder, nn.TransformerEncoder
        ] = nn.TransformerDecoder,
        decoder_layer_cls: Union[
            TransformerDecoderLayer, TransformerEncoderLayer
        ] = TransformerDecoderLayer,
        prefix_mode: int = 0,
        share_embedding: bool = True,
        nar_scale_factor: float = 1.0,
        prepend_bos: bool = True,
        num_quantizers: int = 8,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
        """
        super().__init__()
        nar_d_model = int(d_model * nar_scale_factor)

        self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x
        self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)

        # ID NUM_AUDIO_TOKENS     -> PAD
        # ID NUM_AUDIO_TOKENS + 1 -> BOS
        self.ar_audio_prepend_bos = prepend_bos
        self.ar_audio_embedding = TokenEmbedding(
            d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
        )

        # PreNet
        if add_prenet:
            self.ar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(d_model, d_model),
            )

            self.ar_audio_prenet = nn.Sequential(
                nn.Linear(d_model, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, d_model),
            )
        else:
            self.ar_text_prenet = nn.Identity()
            self.ar_audio_prenet = nn.Identity()

        self.ar_text_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_audio_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )

        self.ar_decoder = decoder_cls(
            decoder_layer_cls(
                d_model,
                nhead,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=num_layers,
            norm=LayerNorm(d_model) if norm_first else None,
        )
        self.ar_predict_layer = nn.Linear(
            d_model, NUM_AUDIO_TOKENS + 1, bias=False
        )

        self.rng = random.Random(0)
        self.num_heads = nhead
        self.prefix_mode = prefix_mode
        self.num_quantizers = num_quantizers

        assert num_quantizers >= 1
        if num_quantizers > 1:
            self.nar_audio_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
                + [
                    TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
                    for i in range(num_quantizers - 1)
                ]
            )  # W_a

            # PreNet
            if add_prenet:
                self.nar_text_prenet = nn.Sequential(
                    Transpose(),
                    nn.Conv1d(
                        nar_d_model, nar_d_model, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_d_model, nar_d_model, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_d_model, nar_d_model, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    Transpose(),
                    nn.Linear(nar_d_model, nar_d_model),
                )
                self.nar_audio_prenet = nn.Sequential(
                    nn.Linear(nar_d_model, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, nar_d_model),
                )
            else:
                self.nar_text_prenet = nn.Identity()
                self.nar_audio_prenet = nn.Identity()

            self.nar_text_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.0,
                scale=False,
                alpha=False,
            )
            self.nar_audio_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.1,
                scale=False,
                alpha=False,
            )

            self.nar_decoder = decoder_cls(
                decoder_layer_cls(
                    nar_d_model,
                    int(nhead * nar_scale_factor),
                    dim_feedforward=nar_d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_layers * nar_scale_factor),
                norm=AdaptiveLayerNorm(
                    nar_d_model, norm=nn.LayerNorm(nar_d_model)
                )
                if norm_first
                else None,
            )
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
                    for i in range(num_quantizers - 1)
                ]
            )
            self.nar_stage_embeddings = nn.ModuleList(
                [
                    TokenEmbedding(nar_d_model, 1)
                    for i in range(num_quantizers - 1)
                ]
            )

            if share_embedding:
                # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
                # NOTE(Feiteng): In the experiment, this undermines accuracy
                # self.ar_predict_layer.weight = self.ar_audio_embedding.weight

                # We also share the parameters of the acoustic embedding layer and the output prediction layer,
                # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
                for j in range(0, num_quantizers - 2):
                    self.nar_predict_layers[
                        j
                    ].weight = self.nar_audio_embeddings[j + 2].weight

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if name.startswith("ar_"):
                    print(f" AR parameter: {name}")
                    yield param

        if stage == 2:
            for name, param in self.named_parameters():
                if name.startswith("nar_"):
                    print(f"NAR parameter: {name}")
                    yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair

        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair

    def pad_y_eos(self, y, y_mask_int, eos_id):
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        # inputs, targets
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
                targets,
            )

        return targets[:, :-1], targets[:, 1:]

    def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
        # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
        # from the same utterance.
        # We implement this differently.
        if prefix_mode == 0:
            # no prefix
            prefix_len = 0
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, nar_stage):
                # Formula (4) (5)
                y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
        elif prefix_mode == 1:
            # prefix at beginning
            int_low = (0.25 * y_lens.min()).type(torch.int64).item()
            prefix_len = torch.randint(0, int_low * 2, size=()).item()
            prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames

            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](
                    codes[:, :prefix_len, j]
                )
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](
                        codes[:, prefix_len:, j]
                    )
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        elif prefix_mode in [2, 4]:
            if prefix_mode == 2:
                # random prefix
                prefix_len = min(225, int(0.25 * y_lens.min().item()))

                y_prompts_codes = []
                for b in range(codes.shape[0]):
                    start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
                    codes[
                        b, start : start + prefix_len, nar_stage
                    ] = NUM_AUDIO_TOKENS
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                prefix_len = y_prompts_codes.shape[1]

            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](
                    y_prompts_codes[..., j]
                )
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        else:
            raise ValueError(f"Unsupported prefix_mode: {prefix_mode}")

        return y_emb, prefix_len

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Union[torch.Tensor, PromptedFeatures],
        y_lens: Union[torch.Tensor, PromptedFeatures],
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        raise NotImplementedError

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: Union[torch.Tensor, None] = None,
        top_k: int = -100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        raise NotImplementedError

    def visualize(
        self,
        predicts: Tuple[torch.Tensor],
        batch: Dict[str, Union[List, torch.Tensor]],
        output_dir: str,
        limit: int = 4,
    ) -> None:
        raise NotImplementedError


class VALLE(VALLF):
    """It implements https://arxiv.org/abs/2301.02111
    "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        prefix_mode: int = 0,
        share_embedding: bool = True,
        nar_scale_factor: float = 1.0,
        **kwargs,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
        """
        super(VALLE, self).__init__(
            d_model,
            nhead,
            num_layers,
            norm_first=norm_first,
            add_prenet=add_prenet,
            decoder_cls=TransformerEncoder,
            decoder_layer_cls=TransformerEncoderLayer,
            prefix_mode=prefix_mode,
            share_embedding=share_embedding,
            nar_scale_factor=nar_scale_factor,
            **kwargs,
        )
        self.language_ID = {
            'en': 0,
            'zh': 1,
            'ja': 2,
        }
        self.ar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
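        # The NAR branch runs at nar_d_model = int(d_model * nar_scale_factor);
        # this equals d_model for the default nar_scale_factor of 1.0.
        self.nar_language_embedding = TokenEmbedding(
            int(d_model * nar_scale_factor), len(self.language_ID)
        )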

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Union[torch.Tensor, PromptedFeatures],
        y_lens: Union[torch.Tensor, PromptedFeatures],
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ):
        raise NotImplementedError

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
        prompt_language: str = None,
        text_language: str = None,
        best_of: int = 1,
        length_penalty: float = 1.0,
        return_worst: bool = False,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
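          enroll_x_lens:
            The number of leading tokens in `x` that belong to the prompt
            transcript; they receive the `prompt_language` embedding.
          top_k: (`optional`) int
            The number of highest-probability tokens to keep for top-k filtering.
            Values <= 0 disable the filter. Defaults to -100.
          temperature: (`optional`) float
            The value used to modulate the next-token probabilities. Must be
            strictly positive. Defaults to 1.0.
          prompt_language: str
            Language of the prompt; a key of `self.language_ID` ('en', 'zh', 'ja').
          text_language: str or list of str
            Language(s) of the text to synthesize; key(s) of `self.language_ID`.
          best_of: int
            The number of AR candidates sampled in parallel. The candidate with
            the highest length-normalized log-probability is kept.
          length_penalty: float
            Exponent applied to the candidate length when normalizing its
            summed log-probability.
          return_worst: bool
            If True, return the lowest-scoring candidate instead of the best.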
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)

        # NOTE: x has been padded in TextTokenCollater
        text = x
        x = self.ar_text_embedding(text)
        # Add language embedding
        prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
        if isinstance(text_language, str):
            text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
        elif isinstance(text_language, List):
            text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
        x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
        x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()
        prompts = y
        prefix_len = y.shape[1]

        # AR Decoder
        # TODO: Managing decoder steps avoid repetitive computation
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)

        x_len = x_lens.max()
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        kv_cache = None
        use_kv_caching = True

        sum_logprobs = torch.zeros(best_of, device=y.device)  # implement batch decoding here
        x = x.repeat(best_of, 1, 1)
        y = y.repeat(best_of, 1)
        while True:
            y_emb = self.ar_audio_embedding(y)
            y_emb = self.ar_audio_prenet(y_emb)
            y_pos = self.ar_audio_position(y_emb)
            xy_pos = torch.concat([x, y_pos], dim=1)

            y_len = y.shape[1]
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
                ),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat(
                [x_attn_mask_pad, y_attn_mask], dim=0
            ).to(y.device)


            if use_kv_caching and kv_cache is not None:
                # Only the newest position is fed; earlier ones are cached.
                xy_pos = xy_pos[:, [-1]]

            xy_dec, kv_cache = self.ar_decoder.infer(
                xy_pos,
                mask=xy_attn_mask,
                past_kv=kv_cache,
                use_cache=use_kv_caching,
            )
            # xy_dec, _ = self.ar_decoder(
            #     (xy_pos, None),
            #     mask=xy_attn_mask,
            # )

            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples, current_logprobs = topk_sampling(
                logits, top_k=top_k, top_p=1, temperature=temperature
            )
            sum_logprobs += current_logprobs * (y[:, -1] != NUM_AUDIO_TOKENS)
            samples[y[:, -1] == NUM_AUDIO_TOKENS] = NUM_AUDIO_TOKENS
            completed = (samples[:, -1] == NUM_AUDIO_TOKENS).all()
            if (
                completed
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
            ):
                if prompts.shape[1] == y.shape[1]:
                    raise RuntimeError(
                        "No audio tokens were generated; a well-trained "
                        "model should not reach this branch."
                    )
                lengths = torch.sum(y != NUM_AUDIO_TOKENS, dim=1)
                avg_logprobs = sum_logprobs / lengths ** length_penalty
                # choose the best beam according to the length-normalized log-probability
                best_beam = y[torch.argmax(avg_logprobs), :]
                worst_beam = y[torch.argmin(avg_logprobs), :]
                # strip all eos tokens
                best_beam = best_beam[best_beam != NUM_AUDIO_TOKENS]
                worst_beam = worst_beam[worst_beam != NUM_AUDIO_TOKENS]
                if return_worst:
                    y = worst_beam.unsqueeze(0)
                else:
                    y = best_beam.unsqueeze(0)
                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
                break

            y = torch.concat([y, samples], dim=1)

        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        # Non-AR Decoders
        y_emb = self.nar_audio_embeddings[0](
            y[:, int(self.ar_audio_prepend_bos) :]
        )

        if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
            enrolled_len = enroll_x_lens.max().item()
            # SOS + Synthesis Text + EOS
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1 :],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.nar_text_embedding(text)
        # Add language embedding
        prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
        if isinstance(text_language, str):
            text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
        elif isinstance(text_language, List):
            text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
        x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
        x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += embedding_layer(
                        prompts[..., i + 1]
                    )
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, self.num_quantizers):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                    prompts[..., j]
                )

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == self.num_quantizers
        return torch.stack(codes, dim=-1)

    def continual(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)
        assert self.num_quantizers == 8

        # NOTE: x has been padded in TextTokenCollater
        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()

        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

        # AR Decoder
        prompts = y[:, :prefix_len]

        codes = [y[:, prefix_len:, 0]]
        # Non-AR Decoders
        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        y_emb = self.nar_audio_embeddings[0](y[..., 0])

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < 6:
                    y_emb[:, :prefix_len] += embedding_layer(
                        prompts[..., i + 1]
                    )
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, 8):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                    prompts[..., j]
                )

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < 6:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == 8
        return torch.stack(codes, dim=-1)


# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (tokens marked 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits


def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    # temperature: (`optional`) float
    #     The value used to modulate the next-token probabilities. Must be strictly positive. Defaults to 1.0.
    # top_k: (`optional`) int
    #     The number of highest-probability vocabulary tokens to keep for top-k filtering. Values <= 0 disable the filter. Defaults to 10.
    # top_p: (`optional`) float
    #     Nucleus sampling threshold: keep the smallest set of highest-probability tokens whose cumulative probability reaches top_p. Must be between 0 and 1. Defaults to 1.0.

    # Temperature (higher temperature => more likely to sample low probability tokens)
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Sample
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
    return token, current_logprobs


================================================
FILE: models/visualizer.py
================================================
#!/usr/bin/env python3
# Copyright    2023                           (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch


def visualize(
    predicts: Tuple[torch.Tensor],
    batch: Dict[str, Union[List, torch.Tensor]],
    output_dir: str,
    limit: int = 4,
) -> None:
    text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
    text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
    audio_features = batch["audio_features"].to("cpu").detach().numpy()
    audio_features_lens = (
        batch["audio_features_lens"].to("cpu").detach().numpy()
    )
    assert text_tokens.ndim == 2

    utt_ids, texts = batch["utt_id"], batch["text"]

    encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
    decoder_outputs = predicts[1]
    if isinstance(decoder_outputs, list):
        decoder_outputs = decoder_outputs[-1]
    decoder_outputs = (
        decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
    )

    vmin, vmax = 0, 1024  # Encodec
    if decoder_outputs.dtype == np.float32:
        vmin, vmax = -6, 0  # Fbank

    num_figures = 3
    for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
        _ = plt.figure(figsize=(14, 8 * num_figures))

        S = text_tokens_lens[b]
        T = audio_features_lens[b]

        # encoder
        plt.subplot(num_figures, 1, 1)
        plt.title(f"Text: {text}")
        plt.imshow(
            X=np.transpose(encoder_outputs[b]),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=S - 0.4, linewidth=2, color="r")
        plt.xlabel("Encoder Output")
        plt.colorbar()

        # decoder
        plt.subplot(num_figures, 1, 2)
        plt.imshow(
            X=np.transpose(decoder_outputs[b]),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
            vmin=vmin,
            vmax=vmax,
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=T - 0.4, linewidth=2, color="r")
        plt.xlabel("Decoder Output")
        plt.colorbar()

        # target
        plt.subplot(num_figures, 1, 3)
        plt.imshow(
            X=np.transpose(audio_features[b]),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
            vmin=vmin,
            vmax=vmax,
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=T - 0.4, linewidth=2, color="r")
        plt.xlabel("Decoder Target")
        plt.colorbar()

        plt.savefig(f"{output_dir}/{utt_id}.png")
        plt.close()


================================================
FILE: modules/__init__.py
================================================


================================================
FILE: modules/activation.py
================================================
from typing import Optional, Tuple, List
import math

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
            same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            return F.linear(q, w, b).chunk(3, dim=-1)
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
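
# Illustrative sketch, not part of the original module: how a packed projection
# weight yields q, k and v in the self-attention branch above. Assumptions:
# plain Python lists stand in for tensors, there is no bias, and
# `_sketch_packed_projection` is a hypothetical helper name.
def _sketch_packed_projection(x, w_packed):
    """Project x (E floats) with w_packed (3 * E rows of E floats) into (q, k, v)."""
    embed_dim = len(x)
    assert len(w_packed) == 3 * embed_dim, "weights are packed along dimension 0"
    # One matrix-vector product over the packed weight, as F.linear(q, w, b) does.
    y = [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w_packed]
    # Equivalent of .chunk(3, dim=-1): the first E outputs are q, then k, then v.
    return y[:embed_dim], y[embed_dim:2 * embed_dim], y[2 * embed_dim:]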

def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.

    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.

    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
            and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
            shape :math:`(Nt, Ns)`.

        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
            have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    if attn_mask is not None:
        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
    else:
        attn = torch.bmm(q, k.transpose(-2, -1))

    attn = F.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn

def multi_head_attention_forward(
        x,
        ipw,
        ipb,
        opw,
        opb,
        n_head,
        attn_mask,
        past_kv=None,
        use_cache=False,
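):
    r"""
    Self-attention forward pass for inference, with an optional key/value cache.

    Args:
        x: input of shape (B, T, C), batch first.
        ipw, ipb: packed input-projection weight (3 * C, C) and bias (3 * C).
        opw, opb: output-projection weight (C, C) and bias (C).
        n_head: number of attention heads; C must be divisible by n_head.
        attn_mask: boolean mask, indexed here as [query position, key position].
            Entries that are True are filled with -inf before the softmax.
            It is sliced unconditionally, so it must not be None.
        past_kv: optional (key, value) pair from earlier steps, each of shape
            (B, n_head, T_past, C // n_head).
        use_cache: if True, also return the concatenated (key, value) pair.

    Returns:
        (y, present): y has shape (B, T, C); present is the (key, value) pair
        when use_cache is True, otherwise None.
    """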
    # x = x.transpose(1, 0)
    # tgt_len, bsz, embed_dim = x.shape
    # head_dim = embed_dim // n_head
    # q, k, v = _in_projection_packed(x, x, x, ipw, ipb)
    # q = q.contiguous().view(tgt_len, bsz * n_head, head_dim).transpose(0, 1)
    # k = k.contiguous().view(k.shape[0], bsz * n_head, head_dim).transpose(0, 1)
    # v = v.contiguous().view(v.shape[0], bsz * n_head, head_dim).transpose(0, 1)

    # new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
    # new_attn_mask.masked_fill_(attn_mask, float("-inf"))
    # attn_mask = new_attn_mask
    #
    # attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, 0.0)
    # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    # attn_output = torch._C._nn.linear(attn_output, opw, opb)
    # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    B, T, C = x.size()

    q, k, v = torch._C._nn.linear(x, ipw, ipb).chunk(3, dim=-1)
    k = k.view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, nh, T, hs)
    q = q.view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, nh, T, hs)
    v = v.view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, nh, T, hs)
    if past_kv is not None:
        past_key = past_kv[0]
        past_value = past_kv[1]
        k = torch.cat((past_key, k), dim=-2)
        v = torch.cat((past_value, v), dim=-2)

    FULL_T = k.shape[-2]

    if use_cache is True:
        present = (k, v)
    else:
        present = None

    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
    att = F.softmax(att, dim=-1)
    y = att @ v  # (B, nh, T, FULL_T) x (B, nh, FULL_T, hs) -> (B, nh, T, hs)
    y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
    y = torch._C._nn.linear(y, opw, opb)
    return (y, present)
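

# Illustrative sketch, not part of the original module: why decoding one
# position at a time with a key/value cache can match a full pass. Assumptions:
# a single head, plain Python lists instead of tensors, and a causal mask (each
# query sees keys up to its own position); the function above applies whatever
# mask it is given. `offset` plays the role of FULL_T - T. All `_sketch_*`
# names are hypothetical helpers.
import math


def _sketch_softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def _sketch_attend(q_rows, k_rows, v_rows, offset):
    """Causal attention; q_rows[i] sits at absolute position offset + i."""
    scale = 1.0 / math.sqrt(len(k_rows[0]))
    dim = len(v_rows[0])
    out = []
    for i, q in enumerate(q_rows):
        visible = offset + i + 1  # keys 0 .. offset + i are not masked
        scores = [
            scale * sum(a * b for a, b in zip(q, k)) for k in k_rows[:visible]
        ]
        weights = _sketch_softmax(scores)
        out.append(
            [
                sum(w * v_row[d] for w, v_row in zip(weights, v_rows[:visible]))
                for d in range(dim)
            ]
        )
    return out


def _sketch_cached_decode(q_rows, k_rows, v_rows):
    """Feed one position per step, growing a (keys, values) cache each time."""
    cache_k, cache_v, out = [], [], []
    for t in range(len(q_rows)):
        # Same idea as torch.cat((past_key, k), dim=-2): append the new entry.
        cache_k.append(k_rows[t])
        cache_v.append(v_rows[t])
        out.extend(_sketch_attend([q_rows[t]], cache_k, cache_v, offset=t))
    return out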


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``forward()`` will use a special optimized implementation if all of the following
    conditions are met:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
      restriction will be loosened in the future.)
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - dropout is 0
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed

    If the optimized implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    """
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
            self,
            embed_dim,
            num_heads,
            dropout=0.0,
            bias=True,
            add_bias_kv=False,
            add_zero_attn=False,
            kdim=None,
            vdim=None,
            batch_first=False,
            linear1_cls=Linear,
            linear2_cls=Linear,
            device=None,
            dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = (
                self.kdim == embed_dim and self.vdim == embed_dim
        )

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
                self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
            if self.bias_v is not None:
                xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None,
            average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
                See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
                See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                Binary and byte masks are supported.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
                Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
                corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
                the attention weight.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

            .. note::
                `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                    key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif (
                self.in_proj_bias is not None
                and query.dtype != self.in_proj_bias.dtype
        ):
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
                self.in_proj_weight is not None
                and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif attn_mask is not None:
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                    [
                        (x is None or x.is_cuda or "cpu" in str(x.device))
                        for x in tensor_args
                    ]
            ):
                why_not_fast_path = (
                    "some Tensor argument is neither CUDA nor CPU"
                )
            elif torch.is_grad_enabled() and any(
                    [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask
                    if key_padding_mask is not None
                    else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
                "MultiheadAttention does not support NestedTensor outside of its fast path. "
                + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [
                    x.transpose(1, 0) for x in (query, key, value)
                ]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
            )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

    def infer(self,
              x: Tensor,
              key_padding_mask: Optional[Tensor] = None,
              need_weights: bool = True,
              attn_mask: Optional[Tensor] = None,
              average_attn_weights: bool = True,
              past_kv = None,
              use_cache = False
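              ):
        """Inference-only self-attention with an optional key/value cache.

        ``x`` is batch first, with shape (B, T, C). ``key_padding_mask``,
        ``need_weights`` and ``average_attn_weights`` are accepted but not used.
        Returns ``(y, kv)``, where ``kv`` is the updated (key, value) cache when
        ``use_cache`` is True and None otherwise.
        """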
        # x = x.transpose(1, 0)
        y, kv = multi_head_attention_forward(
                x=x,
                ipw=self.in_proj_weight,
                ipb=self.in_proj_bias,
                opw=self.out_proj.weight,
                opb=self.out_proj.bias,
                n_head=self.num_heads,
                attn_mask=attn_mask,
                past_kv=past_kv,
                use_cache=use_cache,
        )
        return (y, kv)


================================================
FILE: modules/embedding.py
================================================
# Copyright    2023                             (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        X = self.word_embeddings(x)
        X = self.dropout(X)

        return X


class SinePositionalEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Build the positional encodings, reusing the cached table when it is long enough."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, x.size(1), dtype=torch.float32
            ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)
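

# Illustrative sketch, not part of the original module: the table that
# extend_pe() fills in for the non-reversed case, written with plain Python
# lists. Assumptions: dim_model is even, and `_sketch_sine_table` is a
# hypothetical helper name.
import math


def _sketch_sine_table(num_positions, dim_model):
    """Return table[pos][i], with sine in even columns and cosine in odd ones."""
    assert dim_model % 2 == 0, "this sketch assumes an even dim_model"
    table = []
    for pos in range(num_positions):
        row = [0.0] * dim_model
        for i in range(0, dim_model, 2):
            # Same frequency as div_term: exp(i * -(log(10000) / dim_model)).
            freq = math.exp(i * -(math.log(10000.0) / dim_model))
            row[i] = math.sin(pos * freq)
            row[i + 1] = math.cos(pos * freq)
        table.append(row)
    return table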


================================================
FILE: modules/optim.py
================================================
# Copyright      2022  Xiaomi Corp.        (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import logging
import random
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import torch
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.optim import Optimizer


class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
    as it reduces the number of kernels launched in the optimizer.

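    Args:
      params: The parameters or param_groups to optimize (like other Optimizer subclasses).
      defaults: A dict of default option values, passed on to class Optimizer.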
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    @contextlib.contextmanager
    def batched_params(self, param_group, group_params_names):
        """
        This function returns (technically, yields) a list of
        tuples (p, state, p_names), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape and dtype, and its gradient is also stacked;
        `state` is the state corresponding to this batch of parameters
        (it will be physically located in the "state" for one of the real
        parameters, the first one in the batch); and
        `p_names` is the list of names of the stacked parameters.

        This function is decorated as a context manager so that it can
        write parameters back to their "real" locations.

        The idea is, instead of doing:
        <code>
          for p in group["params"]:
             state = self.state[p]
             ...
        </code>
        you can do:
        <code>
          with self.batched_params(group["params"], group_params_names) as batches:
             for p, state, p_names in batches:
                 ...
        </code>

        Args:
          param_group: a list of parameters; should be the "params" entry
                of one of self.param_groups.
          group_params_names: name for each parameter in group,
                which is List[str].
        """
        batches = defaultdict(
            list
        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
        batches_names = defaultdict(
            list
        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str

        assert len(param_group) == len(group_params_names)
        for p, named_p in zip(param_group, group_params_names):
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)
            batches_names[key].append(named_p)

        batches_names_keys = list(batches_names.keys())
        sorted_idx = sorted(
            range(len(batches_names)), key=lambda i: batches_names_keys[i]
        )
        batches_names = [
            batches_names[batches_names_keys[idx]] for idx in sorted_idx
        ]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        stacked_params_dict = dict()

        # turn batches into a list, in deterministic order.
        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
        # one for each batch in `batches`.
        tuples = []

        for batch, batch_names in zip(batches, batches_names):
            p = batch[0]
            # we arbitrarily store the state in the
            # state corresponding to the 1st parameter in the
            # group.  class Optimizer will take care of saving/loading state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
            grad = torch.stack(
                [
                    torch.zeros_like(p) if p.grad is None else p.grad
                    for p in batch
                ]
            )
            p_stacked.grad = grad
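            # NOTE: `key` is left over from the grouping loop above, so every batch
            # writes to the same entry; stacked_params_dict is not read afterwards.
            stacked_params_dict[key] = p_stacked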
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code will do the actual optimization here!

        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])


class ScaledAdam(BatchedOptimizer):
    """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
     proportional to the norm of that parameter; and also learn the scale of the parameter,
     in log space, subject to upper and lower limits (as if we had factored each parameter as
     param = underlying_param * log_scale.exp())


     Args:
          params:  The parameters or param_groups to optimize (like other Optimizer subclasses)
              lr:  The learning rate.  We will typically use a learning rate schedule that starts
                   at 0.03 and decreases over time, i.e. much higher than other common
                   optimizers.
     clipping_scale: (e.g. 2.0)
                   A scale for gradient-clipping: if specified, the normalized gradients
                   over the whole model will be clipped to have 2-norm equal to
                   `clipping_scale` times the median 2-norm over the most recent period
                   of `clipping_update_period` minibatches.  By "normalized gradients",
                   we mean after multiplying by the rms parameter value for this tensor
                   [for non-scalars]; this is appropriate because our update is scaled
                   by this quantity.
            betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
                   Must satisfy 0 < beta1 <= beta2 < 1.
     scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
                   scale of each parameter tensor and scalar parameters of the model.
                   If each parameter were decomposed
                   as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
                   would be the scaling factor on the learning rate of p_scale.
              eps:  A general-purpose epsilon to prevent division by zero
    param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
                   learning the scale on the parameters (we'll constrain the rms of each non-scalar
                   parameter tensor to be >= this value)
    param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
                   learning the scale on the parameters (we'll constrain the rms of each non-scalar
                   parameter tensor to be <= this value)
       scalar_max: Maximum absolute value for scalar parameters (applicable if your
                   model has any parameters with numel() == 1).
    size_update_period: The periodicity, in steps, with which we update the size (scale)
                   of the parameter tensor.  This is provided to save a little time
                   in the update.
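     clipping_update_period: if clipping_scale is specified, this is the period,
                   in minibatches, over which the median gradient 2-norm used
                   for clipping is computed.
     parameters_names: A List[List[str]]; each List[str] is for a group and
                   each str is the name of a parameter in that group.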
    """

    def __init__(
        self,
        params,
        lr=3e-02,
        clipping_scale=None,
        betas=(0.9, 0.98),
        scalar_lr_scale=0.1,
        eps=1.0e-08,
        param_min_rms=1.0e-05,
        param_max_rms=3.0,
        scalar_max=10.0,
        size_update_period=4,
        clipping_update_period=100,
        parameters_names=None,
        show_dominant_parameters=True,
    ):

        assert parameters_names is not None, (
            "Please prepare parameters_names, "
            "which is a List[List[str]]. Each List[str] is for a group "
            "and each str is for a parameter"
        )
        defaults = dict(
            lr=lr,
            clipping_scale=clipping_scale,
            betas=betas,
            scalar_lr_scale=scalar_lr_scale,
            eps=eps,
            param_min_rms=param_min_rms,
            param_max_rms=param_max_rms,
            scalar_max=scalar_max,
            size_update_period=size_update_period,
            clipping_update_period=clipping_update_period,
        )

        super(ScaledAdam, self).__init__(params, defaults)
        assert len(self.param_groups) == len(parameters_names)
        self.parameters_names = parameters_names
        self.show_dominant_parameters = show_dominant_parameters

    def __setstate__(self, state):
        super(ScaledAdam, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        batch = True

        for group, group_params_names in zip(
            self.param_groups, self.parameters_names
        ):

            with self.batched_params(
                group["params"], group_params_names
            ) as batches:

                # batches is list of tuples (stacked_param, state, names).  stacked_param is like
                # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                # a stacking dim, it is not a real dim.

                if (
                    len(batches[0][1]) == 0
                ):  # if len(first state) == 0: not yet initialized
                    clipping_scale = 1
                else:
                    clipping_scale = self._get_clipping_scale(group, batches)

                for p, state, _ in batches:
                    # Perform optimization step.
                    # grad is not going to be None, we handled that when creating the batches.
                    grad = p.grad
                    if grad.is_sparse:
                        raise RuntimeError(
                            "ScaledAdam optimizer does not support sparse gradients"
                        )
                    # State initialization
                    if len(state) == 0:
                        self._init_state(group, p, state)

                    self._step_one_batch(group, p, state, clipping_scale)

        return loss

    def _init_state(self, group: dict, p: Tensor, state: dict):
        """
        Initializes state dict for parameter 'p'.  Assumes that dim 0 of tensor p
        is actually the batch dimension, corresponding to batched-together
        parameters of a given shape.


        Args:
           group:   Dict to look up configuration values.
               p: The parameter that we are initializing the state for
           state: Dict from string to whatever state we are initializing
        """
        size_update_period = group["size_update_period"]

        state["step"] = 0

        kwargs = {"device": p.device, "dtype": p.dtype}

        # 'delta' implements conventional momentum.  There are
        # several different kinds of update going on, so rather than
        # compute "exp_avg" like in Adam, we store and decay a
        # parameter-change "delta", which combines all forms of
        # update.  this is equivalent to how it's done in Adam,
        # except for the first few steps.
        state["delta"] = torch.zeros_like(
            p, memory_format=torch.preserve_format
        )

        batch_size = p.shape[0]
        # number of elements per real parameter (dim 0 is the stacking dim)
        numel = p.numel() // batch_size

        if numel > 1:
            # "param_rms" just periodically records the scalar root-mean-square value of
            # the parameter tensor.
            # it has a shape like (batch_size, 1, 1, 1, 1)
            param_rms = (
                (p ** 2).mean(dim=list(range(1, p.ndim))
SYMBOL INDEX (341 symbols across 28 files)

FILE: data/collation.py
  class TextTokenCollater (line 10) | class TextTokenCollater:
    method __init__ (line 29) | def __init__(
    method index (line 56) | def index(
    method __call__ (line 87) | def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tens...
  function get_text_token_collater (line 116) | def get_text_token_collater() -> TextTokenCollater:

FILE: data/datamodule.py
  class _SeedWorkers (line 47) | class _SeedWorkers:
    method __init__ (line 48) | def __init__(self, seed: int):
    method __call__ (line 51) | def __call__(self, worker_id: int):
  function _get_input_strategy (line 55) | def _get_input_strategy(input_strategy, dataset, cuts):
  class TtsDataModule (line 62) | class TtsDataModule:
    method __init__ (line 78) | def __init__(self, args: argparse.Namespace):
    method add_arguments (line 82) | def add_arguments(cls, parser: argparse.ArgumentParser):
    method train_dataloaders (line 222) | def train_dataloaders(
    method valid_dataloaders (line 347) | def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
    method test_dataloaders (line 379) | def test_dataloaders(self, cuts: CutSet) -> DataLoader:
    method train_cuts (line 405) | def train_cuts(self) -> CutSet:
    method dev_cuts (line 412) | def dev_cuts(self) -> CutSet:
    method test_cuts (line 417) | def test_cuts(self) -> CutSet:

FILE: data/dataset.py
  function seq2phone (line 39) | def seq2phone(tokens: Union[List, np.ndarray]):
  class DynamicBatchSampler (line 48) | class DynamicBatchSampler(torch.utils.data.Sampler):
    method __init__ (line 49) | def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0...
    method set_epoch (line 75) | def set_epoch(self, epoch):
    method is_batch_full (line 77) | def is_batch_full(self, num_tokens, batch):
    method __iter__ (line 86) | def __iter__(self):
    method __len__ (line 126) | def __len__(self):
  class AudioDataset (line 131) | class AudioDataset(torch.utils.data.Dataset):
    method __init__ (line 132) | def __init__(self, h5_path, ann_path, tokenizer_path):
    method __len__ (line 146) | def __len__(self):
    method get_dur (line 149) | def get_dur(self, idx):
    method archive (line 153) | def archive(self):
    method __getitem__ (line 157) | def __getitem__(self, idx):
  function collate (line 186) | def collate(batch):
  function create_dataloader (line 225) | def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, ...

FILE: data/fbank.py
  class BigVGANFbankConfig (line 29) | class BigVGANFbankConfig:
    method to_dict (line 43) | def to_dict(self) -> Dict[str, Any]:
    method from_dict (line 47) | def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
  function dynamic_range_compression_torch (line 51) | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
  function spectral_normalize_torch (line 55) | def spectral_normalize_torch(magnitudes):
  class BigVGANFbank (line 62) | class BigVGANFbank(FeatureExtractor):
    method __init__ (line 66) | def __init__(self, config: Optional[Any] = None):
    method _feature_fn (line 80) | def _feature_fn(self, samples, **kwargs):
    method extract (line 133) | def extract(
    method frame_shift (line 150) | def frame_shift(self) -> Seconds:
    method feature_dim (line 153) | def feature_dim(self, sampling_rate: int) -> int:
    method mix (line 157) | def mix(
    method compute_energy (line 172) | def compute_energy(features: np.ndarray) -> float:
  function get_fbank_extractor (line 176) | def get_fbank_extractor() -> BigVGANFbank:

FILE: data/input_strategies.py
  class PromptedFeatures (line 16) | class PromptedFeatures:
    method __init__ (line 17) | def __init__(self, prompts, features):
    method to (line 21) | def to(self, device):
    method sum (line 26) | def sum(self):
    method ndim (line 30) | def ndim(self):
    method data (line 34) | def data(self):

FILE: data/tokenizer.py
  function remove_encodec_weight_norm (line 33) | def remove_encodec_weight_norm(model):
  class AudioTokenizer (line 63) | class AudioTokenizer:
    method __init__ (line 66) | def __init__(
    method device (line 89) | def device(self):
    method encode (line 92) | def encode(self, wav: torch.Tensor) -> torch.Tensor:
    method decode (line 95) | def decode(self, frames: torch.Tensor) -> torch.Tensor:
  function tokenize_audio (line 99) | def tokenize_audio(tokenizer: AudioTokenizer, audio):

FILE: launch-ui.py
  function clear_prompts (line 125) | def clear_prompts():
  function transcribe_one (line 138) | def transcribe_one(model, audio_path):
  function make_npz_prompt (line 162) | def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_con...
  function make_prompt (line 203) | def make_prompt(name, wav, sr, save=True):
  function infer_from_audio (line 230) | def infer_from_audio(text, language, accent, audio_prompt, record_audio_...
  function infer_from_prompt (line 309) | def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
  function infer_long_text (line 371) | def infer_long_text(text, preset_prompt, prompt=None, language='auto', a...
  function main (line 511) | def main():

FILE: models/__init__.py
  function add_model_arguments (line 18) | def add_model_arguments(parser: argparse.ArgumentParser):
  function get_model (line 98) | def get_model(params) -> nn.Module:

FILE: models/transformer.py
  class Transformer (line 41) | class Transformer(nn.Module):
    method __init__ (line 47) | def __init__(
    method forward (line 222) | def forward(
    method inference (line 320) | def inference(
    method visualize (line 387) | def visualize(

FILE: models/vallex.py
  class Transpose (line 39) | class Transpose(nn.Identity):
    method forward (line 42) | def forward(self, input: torch.Tensor) -> torch.Tensor:
  class VALLF (line 50) | class VALLF(nn.Module):
    method __init__ (line 55) | def __init__(
    method stage_parameters (line 266) | def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
    method stage_named_parameters (line 280) | def stage_named_parameters(
    method pad_y_eos (line 294) | def pad_y_eos(self, y, y_mask_int, eos_id):
    method _prepare_prompts (line 307) | def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_code...
    method forward (line 367) | def forward(
    method inference (line 379) | def inference(
    method visualize (line 390) | def visualize(
  class VALLE (line 400) | class VALLE(VALLF):
    method __init__ (line 405) | def __init__(
    method forward (line 447) | def forward(
    method inference (line 458) | def inference(
    method continual (line 688) | def continual(
  function top_k_top_p_filtering (line 791) | def top_k_top_p_filtering(
  function topk_sampling (line 836) | def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):

FILE: models/visualizer.py
  function visualize (line 26) | def visualize(

FILE: modules/activation.py
  function _in_projection_packed (line 12) | def _in_projection_packed(
  function _scaled_dot_product_attention (line 67) | def _scaled_dot_product_attention(
  function multi_head_attention_forward (line 114) | def multi_head_attention_forward(
  class MultiheadAttention (line 170) | class MultiheadAttention(Module):
    method __init__ (line 230) | def __init__(
    method _reset_parameters (line 333) | def _reset_parameters(self):
    method __setstate__ (line 350) | def __setstate__(self, state):
    method forward (line 357) | def forward(
    method infer (line 591) | def infer(self,

FILE: modules/embedding.py
  class TokenEmbedding (line 21) | class TokenEmbedding(nn.Module):
    method __init__ (line 22) | def __init__(
    method weight (line 37) | def weight(self) -> torch.Tensor:
    method embedding (line 40) | def embedding(self, index: int) -> torch.Tensor:
    method forward (line 43) | def forward(self, x: torch.Tensor):
  class SinePositionalEmbedding (line 50) | class SinePositionalEmbedding(nn.Module):
    method __init__ (line 51) | def __init__(
    method extend_pe (line 68) | def extend_pe(self, x):
    method forward (line 93) | def forward(self, x: torch.Tensor) -> torch.Tensor:

FILE: modules/optim.py
  class BatchedOptimizer (line 29) | class BatchedOptimizer(Optimizer):
    method __init__ (line 40) | def __init__(self, params, defaults):
    method batched_params (line 44) | def batched_params(self, param_group, group_params_names):
  class ScaledAdam (line 129) | class ScaledAdam(BatchedOptimizer):
    method __init__ (line 172) | def __init__(
    method __setstate__ (line 212) | def __setstate__(self, state):
    method step (line 216) | def step(self, closure=None):
    method _init_state (line 265) | def _init_state(self, group: dict, p: Tensor, state: dict):
    method _get_clipping_scale (line 316) | def _get_clipping_scale(
    method _show_gradient_dominating_parameter (line 414) | def _show_gradient_dominating_parameter(
    method _step_one_batch (line 479) | def _step_one_batch(
    method _size_update (line 531) | def _size_update(
    method _step (line 598) | def _step(self, group: dict, p: Tensor, state: dict):
    method _step_scalar (line 639) | def _step_scalar(self, group: dict, p: Tensor, state: dict):
  class LRScheduler (line 664) | class LRScheduler(object):
    method __init__ (line 670) | def __init__(self, optimizer: Optimizer, verbose: bool = False):
    method state_dict (line 687) | def state_dict(self):
    method load_state_dict (line 699) | def load_state_dict(self, state_dict):
    method get_last_lr (line 708) | def get_last_lr(self) -> List[float]:
    method get_lr (line 712) | def get_lr(self):
    method step_batch (line 718) | def step_batch(self, batch: Optional[int] = None) -> None:
    method step_epoch (line 730) | def step_epoch(self, epoch: Optional[int] = None):
    method _set_lrs (line 740) | def _set_lrs(self):
    method print_lr (line 750) | def print_lr(self, is_verbose, group, lr):
  class Eden (line 759) | class Eden(LRScheduler):
    method __init__ (line 781) | def __init__(
    method get_lr (line 794) | def get_lr(self):
  function _test_eden (line 810) | def _test_eden():
  class Eve (line 836) | class Eve(Optimizer):
    method __init__ (line 872) | def __init__(
    method __setstate__ (line 908) | def __setstate__(self, state):
    method step (line 912) | def step(self, closure=None):
  function _test_scaled_adam (line 988) | def _test_scaled_adam(hidden_dim: int):
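
The `batched_params` method indexed above buckets parameters by a `(dtype_as_str, *shape)` key and then visits the buckets in sorted key order, so that stacking is deterministic. The following is a minimal sketch of that grouping step only, under stated assumptions: `ParamInfo` and `group_for_stacking` are hypothetical names, and plain tuples stand in for real tensors, so nothing is actually stacked.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple, Tuple


class ParamInfo(NamedTuple):
    """Hypothetical stand-in for a parameter: only the fields used for grouping."""

    name: str
    dtype: str
    shape: Tuple[int, ...]


def group_for_stacking(params: List[ParamInfo]) -> List[List[str]]:
    """Return parameter names grouped by (dtype, *shape), in sorted key order."""
    buckets: Dict[tuple, List[str]] = defaultdict(list)
    for p in params:
        # Same key layout as in batched_params: dtype string, then each dim.
        key = (p.dtype, *p.shape)
        buckets[key].append(p.name)
    # Sorting the keys makes the batch order independent of insertion order.
    return [buckets[key] for key in sorted(buckets)]


params = [
    ParamInfo("layer0.weight", "torch.float32", (4, 4)),
    ParamInfo("layer0.bias", "torch.float32", (4,)),
    ParamInfo("layer1.weight", "torch.float32", (4, 4)),
    ParamInfo("layer1.bias", "torch.float32", (4,)),
]
groups = group_for_stacking(params)
```

In the real optimizer each group would then be passed to `torch.stack`, giving one tensor with an extra leading dimension per distinct shape and dtype; here the two biases form one group and the two weights another.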

FILE: modules/scaling.py
  class ActivationBalancerFunction (line 35) | class ActivationBalancerFunction(torch.autograd.Function):
    method forward (line 37) | def forward(
    method backward (line 55) | def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
  function _compute_scale_factor (line 76) | def _compute_scale_factor(
  function _compute_sign_factor (line 105) | def _compute_sign_factor(
  class ActivationScaleBalancerFunction (line 141) | class ActivationScaleBalancerFunction(torch.autograd.Function):
    method forward (line 149) | def forward(
    method backward (line 164) | def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
  class RandomClampFunction (line 180) | class RandomClampFunction(torch.autograd.Function):
    method forward (line 182) | def forward(
    method backward (line 201) | def backward(
  function random_clamp (line 212) | def random_clamp(
  function random_cast_to_half (line 222) | def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
  class RandomGradFunction (line 237) | class RandomGradFunction(torch.autograd.Function):
    method forward (line 244) | def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
    method backward (line 249) | def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
  class RandomGrad (line 261) | class RandomGrad(torch.nn.Module):
    method __init__ (line 267) | def __init__(self, min_abs: float = 5.0e-06):
    method forward (line 271) | def forward(self, x: Tensor):
  class SoftmaxFunction (line 282) | class SoftmaxFunction(torch.autograd.Function):
    method forward (line 289) | def forward(ctx, x: Tensor, dim: int):
    method backward (line 302) | def backward(ctx, ans_grad: Tensor):
  function softmax (line 312) | def softmax(x: Tensor, dim: int):
  class MaxEigLimiterFunction (line 319) | class MaxEigLimiterFunction(torch.autograd.Function):
    method forward (line 321) | def forward(
    method backward (line 335) | def backward(ctx, x_grad, *args):
  class BasicNorm (line 360) | class BasicNorm(torch.nn.Module):
    method __init__ (line 390) | def __init__(
    method forward (line 409) | def forward(self, x: Tensor) -> Tensor:
  function ScaledLinear (line 427) | def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
  function ScaledConv1d (line 452) | def ScaledConv1d(
  function TransposeScaledConv1d (line 483) | def TransposeScaledConv1d(
  function ScaledConv1dTranspose (line 505) | def ScaledConv1dTranspose(
  function TransposeConv1d (line 527) | def TransposeConv1d(
  function Conv1dTranspose (line 539) | def Conv1dTranspose(
  class SRLinear (line 551) | class SRLinear(nn.Linear):
    method __init__ (line 556) | def __init__(self, in_features, out_features, bias=True, **kwargs):
    method get_sigma (line 566) | def get_sigma(self):
    method get_weight (line 576) | def get_weight(self):
    method forward (line 583) | def forward(self, x):
  class SRConv1d (line 587) | class SRConv1d(SRLinear):
    method __init__ (line 588) | def __init__(
    method forward (line 605) | def forward(self, x):
  function TransposeSRConv1d (line 615) | def TransposeSRConv1d(
  function SRConv1dTranspose (line 627) | def SRConv1dTranspose(
  class ActivationBalancer (line 639) | class ActivationBalancer(torch.nn.Module):
    method __init__ (line 679) | def __init__(
    method forward (line 710) | def forward(self, x: Tensor) -> Tensor:
  function penalize_abs_values_gt (line 764) | def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> T...
  function _diag (line 792) | def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
  function _whitening_metric (line 803) | def _whitening_metric(x: Tensor, num_groups: int):
  class WhiteningPenaltyFunction (line 841) | class WhiteningPenaltyFunction(torch.autograd.Function):
    method forward (line 843) | def forward(
    method backward (line 857) | def backward(ctx, x_grad: Tensor):
  class Whiten (line 882) | class Whiten(nn.Module):
    method __init__ (line 883) | def __init__(
    method forward (line 924) | def forward(self, x: Tensor) -> Tensor:
  class WithLoss (line 965) | class WithLoss(torch.autograd.Function):
    method forward (line 967) | def forward(ctx, x: Tensor, y: Tensor):
    method backward (line 972) | def backward(ctx, ans_grad: Tensor):
  function with_loss (line 978) | def with_loss(x, y):
  function _no_op (line 985) | def _no_op(x: Tensor) -> Tensor:
  class Identity (line 994) | class Identity(torch.nn.Module):
    method __init__ (line 995) | def __init__(self):
    method forward (line 998) | def forward(self, x):
  class MaxEig (line 1002) | class MaxEig(torch.nn.Module):
    method __init__ (line 1023) | def __init__(
    method forward (line 1053) | def forward(self, x: Tensor) -> Tensor:
    method _set_direction (line 1111) | def _set_direction(self, direction: Tensor):
    method _find_direction_coeffs (line 1126) | def _find_direction_coeffs(
  class DoubleSwishFunction (line 1156) | class DoubleSwishFunction(torch.autograd.Function):
    method forward (line 1173) | def forward(ctx, x: Tensor) -> Tensor:
    method backward (line 1206) | def backward(ctx, y_grad: Tensor) -> Tensor:
  class DoubleSwish (line 1215) | class DoubleSwish(torch.nn.Module):
    method forward (line 1216) | def forward(self, x: Tensor) -> Tensor:
  function BalancedDoubleSwish (line 1225) | def BalancedDoubleSwish(
  function _test_max_eig (line 1240) | def _test_max_eig():
  function _test_whiten (line 1267) | def _test_whiten():
  function _test_activation_balancer_sign (line 1294) | def _test_activation_balancer_sign():
  function _test_activation_balancer_magnitude (line 1320) | def _test_activation_balancer_magnitude():
  function _test_basic_norm (line 1348) | def _test_basic_norm():
  function _test_double_swish_deriv (line 1365) | def _test_double_swish_deriv():
  function _test_softmax (line 1379) | def _test_softmax():

FILE: modules/scheduler.py
  function calc_lr (line 24) | def calc_lr(step, dim_embed, warmup_steps):
  class NoamScheduler (line 30) | class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
    method __init__ (line 31) | def __init__(
    method get_lr (line 48) | def get_lr(self) -> float:
    method set_step (line 54) | def set_step(self, step: int):
  function get_scheduler (line 58) | def get_scheduler(params, optimizer):

FILE: modules/transformer.py
  class LayerNorm (line 17) | class LayerNorm(nn.Module):
    method __init__ (line 23) | def __init__(
    method reset_parameters (line 52) | def reset_parameters(self) -> None:
    method forward (line 57) | def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
    method extra_repr (line 76) | def extra_repr(self) -> str:
  class AdaptiveLayerNorm (line 83) | class AdaptiveLayerNorm(nn.Module):
    method __init__ (line 86) | def __init__(self, d_model, norm) -> None:
    method forward (line 93) | def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
  class BasicNorm (line 111) | class BasicNorm(_BasicNorm):
    method __init__ (line 112) | def __init__(
    method forward (line 121) | def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
  class BalancedBasicNorm (line 133) | class BalancedBasicNorm(nn.Module):
    method __init__ (line 134) | def __init__(
    method forward (line 151) | def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
  class IdentityNorm (line 160) | class IdentityNorm(nn.Module):
    method __init__ (line 161) | def __init__(
    method forward (line 170) | def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
  class TransformerEncoderLayer (line 178) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 181) | def __init__(
    method __setstate__ (line 260) | def __setstate__(self, state):
    method forward (line 265) | def forward(
    method infer (line 314) | def infer(
    method _sa_block (line 354) | def _sa_block(
    method _ff_block (line 371) | def _ff_block(self, x: Tensor) -> Tensor:
  class TransformerEncoder (line 376) | class TransformerEncoder(nn.Module):
    method __init__ (line 396) | def __init__(self, encoder_layer, num_layers, norm=None):
    method forward (line 402) | def forward(
    method infer (line 447) | def infer(
  class TransformerDecoderLayer (line 476) | class TransformerDecoderLayer(nn.Module):
    method __init__ (line 479) | def __init__(
    method forward (line 572) | def forward(
    method _sa_block (line 631) | def _sa_block(
    method _mha_block (line 648) | def _mha_block(
    method _ff_block (line 666) | def _ff_block(self, x: Tensor) -> Tensor:
  function _get_clones (line 671) | def _get_clones(module, N):
  function _get_activation_fn (line 675) | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:

FILE: utils/__init__.py
  class Transpose (line 11) | class Transpose(nn.Identity):
    method forward (line 14) | def forward(self, input: torch.Tensor) -> torch.Tensor:

FILE: utils/download.py
  function download_file_from_google_drive (line 5) | def download_file_from_google_drive(id, destination):
  function get_confirm_token (line 20) | def get_confirm_token(response):
  function save_response_content (line 28) | def save_response_content(response, destination):
  function main (line 37) | def main():

FILE: utils/g2p/__init__.py
  class PhonemeBpeTokenizer (line 11) | class PhonemeBpeTokenizer:
    method __init__ (line 12) | def __init__(self, tokenizer_path = "./utils/g2p/bpe_1024.json"):
    method tokenize (line 15) | def tokenize(self, text):
  function text_to_sequence (line 27) | def text_to_sequence(text, cleaner_names):
  function cleaned_text_to_sequence (line 46) | def cleaned_text_to_sequence(cleaned_text):
  function sequence_to_text (line 57) | def sequence_to_text(sequence):
  function _clean_text (line 66) | def _clean_text(text, cleaner_names):

FILE: utils/g2p/cleaners.py
  function japanese_cleaners (line 6) | def japanese_cleaners(text):
  function japanese_cleaners2 (line 11) | def japanese_cleaners2(text):
  function chinese_cleaners (line 14) | def chinese_cleaners(text):
  function cje_cleaners (line 22) | def cje_cleaners(text):
  function clean_one (line 49) | def clean_one(text):

FILE: utils/g2p/english.py
  function expand_abbreviations (line 87) | def expand_abbreviations(text):
  function collapse_whitespace (line 93) | def collapse_whitespace(text):
  function _remove_commas (line 97) | def _remove_commas(m):
  function _expand_decimal_point (line 101) | def _expand_decimal_point(m):
  function _expand_dollars (line 105) | def _expand_dollars(m):
  function _expand_ordinal (line 126) | def _expand_ordinal(m):
  function _expand_number (line 130) | def _expand_number(m):
  function normalize_numbers (line 145) | def normalize_numbers(text):
  function mark_dark_l (line 155) | def mark_dark_l(text):
  function english_to_ipa (line 159) | def english_to_ipa(text):
  function english_to_lazy_ipa (line 169) | def english_to_lazy_ipa(text):
  function english_to_ipa2 (line 176) | def english_to_ipa2(text):
  function english_to_lazy_ipa2 (line 184) | def english_to_lazy_ipa2(text):

FILE: utils/g2p/japanese.py
  function symbols_to_japanese (line 68) | def symbols_to_japanese(text):
  function japanese_to_romaji_with_accent (line 74) | def japanese_to_romaji_with_accent(text):
  function get_real_sokuon (line 116) | def get_real_sokuon(text):
  function get_real_hatsuon (line 122) | def get_real_hatsuon(text):
  function japanese_to_ipa (line 128) | def japanese_to_ipa(text):
  function japanese_to_ipa2 (line 139) | def japanese_to_ipa2(text):
  function japanese_to_ipa3 (line 148) | def japanese_to_ipa3(text):

FILE: utils/g2p/mandarin.py
  function number_to_chinese (line 235) | def number_to_chinese(text):
  function chinese_to_bopomofo (line 242) | def chinese_to_bopomofo(text):
  function latin_to_bopomofo (line 260) | def latin_to_bopomofo(text):
  function bopomofo_to_romaji (line 266) | def bopomofo_to_romaji(text):
  function bopomofo_to_ipa (line 272) | def bopomofo_to_ipa(text):
  function bopomofo_to_ipa2 (line 278) | def bopomofo_to_ipa2(text):
  function chinese_to_romaji (line 284) | def chinese_to_romaji(text):
  function chinese_to_lazy_ipa (line 297) | def chinese_to_lazy_ipa(text):
  function chinese_to_ipa (line 304) | def chinese_to_ipa(text):
  function chinese_to_ipa2 (line 317) | def chinese_to_ipa2(text):

FILE: utils/generation.py
  function preload_models (line 50) | def preload_models():
  function generate_audio (line 92) | def generate_audio(text, prompt=None, language='auto', accent='no-accent'):
  function generate_audio_from_long_text (line 155) | def generate_audio_from_long_text(text, prompt=None, language='auto', ac...

FILE: utils/prompt_making.py
  function transcribe_one (line 33) | def transcribe_one(model, audio_path):
  function make_prompt (line 57) | def make_prompt(name, audio_prompt_path, transcript=None):
  function make_transcript (line 87) | def make_transcript(name, wav, sr, transcript=None):

FILE: utils/sentence_cutter.py
  function split_text_into_sentences (line 7) | def split_text_into_sentences(text):

FILE: utils/symbol_table.py
  class SymbolTable (line 31) | class SymbolTable(Generic[Symbol]):
    method __post_init__ (line 57) | def __post_init__(self):
    method from_str (line 76) | def from_str(s: str) -> 'SymbolTable':
    method from_file (line 109) | def from_file(filename: str) -> 'SymbolTable':
    method to_str (line 133) | def to_str(self) -> str:
    method to_file (line 144) | def to_file(self, filename: str):
    method add (line 165) | def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
    method get (line 197) | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
    method merge (line 214) | def merge(self, other: 'SymbolTable') -> 'SymbolTable':
    method _check_compatible (line 233) | def _check_compatible(self, other: 'SymbolTable') -> None:
    method __getitem__ (line 250) | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
    method __contains__ (line 253) | def __contains__(self, item: Union[int, Symbol]) -> bool:
    method __len__ (line 259) | def __len__(self) -> int:
    method __eq__ (line 262) | def __eq__(self, other: 'SymbolTable') -> bool:
    method ids (line 273) | def ids(self) -> List[int]:
    method symbols (line 281) | def symbols(self) -> List[Symbol]:
Condensed preview — 126 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (22,560K chars).
[
  {
    "path": "LICENSE",
    "chars": 1065,
    "preview": "MIT License\n\nCopyright (c) 2023 Songting\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
  },
  {
    "path": "README-ZH.md",
    "chars": 9728,
    "preview": "# VALL-E X: 多语言文本到语音合成与语音克隆 🔊\n[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=di"
  },
  {
    "path": "README.md",
    "chars": 14238,
    "preview": "# VALL-E X: Multilingual Text-to-Speech Synthesis and Voice Cloning 🔊\n[![Discord](https://img.shields.io/badge/Discord-%"
  },
  {
    "path": "customs/ph.txt",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "data/__init__.py",
    "chars": 80,
    "preview": "# from .datamodule import *\n# from .tokenizer import *\nfrom .collation import *\n"
  },
  {
    "path": "data/collation.py",
    "chars": 3482,
    "preview": "from pathlib import Path\nfrom typing import List, Tuple\n\nimport numpy as np\nimport torch\n\nfrom utils import SymbolTable\n"
  },
  {
    "path": "data/datamodule.py",
    "chars": 14895,
    "preview": "# Copyright      2023                          (authors: Feiteng Li)\n#\n# See ../../../../LICENSE for clarification regar"
  },
  {
    "path": "data/dataset.py",
    "chars": 9310,
    "preview": "# Copyright      2023                           (authors: Feiteng Li)\n#\n# See ../../../../LICENSE for clarification rega"
  },
  {
    "path": "data/fbank.py",
    "chars": 6817,
    "preview": "# Copyright      2023                          (authors: Feiteng Li)\n#\n# See ../../../../LICENSE for clarification regar"
  },
  {
    "path": "data/input_strategies.py",
    "chars": 5560,
    "preview": "import random\nfrom collections import defaultdict\nfrom concurrent.futures import ThreadPoolExecutor\nfrom typing import T"
  },
  {
    "path": "data/tokenizer.py",
    "chars": 4306,
    "preview": "#!/usr/bin/env python3\n# Copyright    2023                            (authors: Feiteng Li)\n#\n# Licensed under the Apach"
  },
  {
    "path": "descriptions.py",
    "chars": 2304,
    "preview": "top_md = \"\"\"\n# VALL-E X  \nVALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recordi"
  },
  {
    "path": "examples.py",
    "chars": 1391,
    "preview": "infer_from_audio_examples = [\n    [\"This is how this machine has taken my voice.\", 'English', 'no-accent', \"prompts/en-2"
  },
  {
    "path": "launch-ui.py",
    "chars": 27081,
    "preview": "# coding: utf-8\nimport argparse\nimport logging\nimport os\nimport pathlib\nimport time\nimport tempfile\nimport platform\nimpo"
  },
  {
    "path": "macros.py",
    "chars": 483,
    "preview": "NUM_LAYERS = 12\nNUM_HEAD = 16\nN_DIM = 1024\nPREFIX_MODE = 1\nNUM_QUANTIZERS = 8\nSAMPLE_RATE = 24000\n\nlang2token = {\n    'z"
  },
  {
    "path": "model-card.md",
    "chars": 1425,
    "preview": "# Model Card: VALL-E X\n\n**Author**: [Songting](https://github.com/Plachtaa).<br>\n<br>\nThis is the official codebase for "
  },
  {
    "path": "models/__init__.py",
    "chars": 3898,
    "preview": "import argparse\n\nimport torch.nn as nn\n# from icefall.utils import AttributeDict, str2bool\n\nfrom .macros import (\n    NU"
  },
  {
    "path": "models/macros.py",
    "chars": 201,
    "preview": "# Text\nNUM_TEXT_TOKENS = 2048\n\n# Audio\nNUM_AUDIO_TOKENS = 1024  # EnCodec RVQ bins\nNUM_MEL_BINS = 100  # BigVGAN bigvgan"
  },
  {
    "path": "models/transformer.py",
    "chars": 13248,
    "preview": "# Copyright    2023                             (authors: Feiteng Li)\n#\n# Licensed under the Apache License, Version 2.0"
  },
  {
    "path": "models/vallex.py",
    "chars": 31575,
    "preview": "# Copyright    2023                             (authors: Feiteng Li)\n#\n# Licensed under the Apache License, Version 2.0"
  },
  {
    "path": "models/visualizer.py",
    "chars": 3368,
    "preview": "#!/usr/bin/env python3\n# Copyright    2023                           (authors: Feiteng Li)\n#\n# See ../../../../LICENSE f"
  },
  {
    "path": "modules/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "modules/activation.py",
    "chars": 26979,
    "preview": "from typing import Optional, Tuple, List\nimport math\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear,"
  },
  {
    "path": "modules/embedding.py",
    "chars": 3198,
    "preview": "# Copyright    2023                             (authors: Feiteng Li)\n#\n# Licensed under the Apache License, Version 2.0"
  },
  {
    "path": "modules/optim.py",
    "chars": 42932,
    "preview": "# Copyright      2022  Xiaomi Corp.        (authors: Daniel Povey)\n#\n# See ../LICENSE for clarification regarding multip"
  },
  {
    "path": "modules/scaling.py",
    "chars": 49568,
    "preview": "# Copyright    2022  Xiaomi Corp.        (authors: Daniel Povey)\n#\n# See ../../../../LICENSE for clarification regarding"
  },
  {
    "path": "modules/scheduler.py",
    "chars": 2449,
    "preview": "#!/usr/bin/env python3\n# Copyright    2023                           (authors: Feiteng Li)\n#\n# See ../../../../LICENSE f"
  },
  {
    "path": "modules/transformer.py",
    "chars": 22441,
    "preview": "import copy\nimport numbers\nfrom functools import partial\nfrom typing import Any, Callable, List, Optional, Tuple, Union\n"
  },
  {
    "path": "nltk_data/tokenizers/punkt/PY3/README",
    "chars": 8567,
    "preview": "Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)\n\nMost models wer"
  },
  {
    "path": "nltk_data/tokenizers/punkt/README",
    "chars": 8567,
    "preview": "Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)\n\nMost models wer"
  },
  {
    "path": "nltk_data/tokenizers/punkt/czech.pickle",
    "chars": 1232394,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/danish.pickle",
    "chars": 1249511,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/dutch.pickle",
    "chars": 741931,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/english.pickle",
    "chars": 433305,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/estonian.pickle",
    "chars": 1575990,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/finnish.pickle",
    "chars": 1917958,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/french.pickle",
    "chars": 573905,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/german.pickle",
    "chars": 1514390,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/greek.pickle",
    "chars": 1953106,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/italian.pickle",
    "chars": 655663,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/norwegian.pickle",
    "chars": 1250709,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/polish.pickle",
    "chars": 2038149,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/portuguese.pickle",
    "chars": 642254,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/slovene.pickle",
    "chars": 832859,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/spanish.pickle",
    "chars": 592363,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/swedish.pickle",
    "chars": 1016247,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "nltk_data/tokenizers/punkt/turkish.pickle",
    "chars": 1211551,
    "preview": "ccopy_reg\n_reconstructor\np0\n(cnltk.tokenize.punkt\nPunktSentenceTokenizer\np1\nc__builtin__\nobject\np2\nNtp3\nRp4\n(dp5\nS'_Toke"
  },
  {
    "path": "prompts/ph.txt",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "requirements.txt",
    "chars": 226,
    "preview": "soundfile\nnumpy\ntorch\ntorchvision\ntorchaudio\ntokenizers\nencodec\nlangid\nwget\nunidecode\npyopenjtalk-prebuilt\npypinyin\ninfl"
  },
  {
    "path": "utils/__init__.py",
    "chars": 338,
    "preview": "import torch\nimport torch.nn as nn\n# from icefall.utils import make_pad_mask\n\nfrom .symbol_table import SymbolTable\n\n# m"
  },
  {
    "path": "utils/download.py",
    "chars": 1275,
    "preview": "import sys\nimport requests\n\n\ndef download_file_from_google_drive(id, destination):\n    URL = \"https://docs.google.com/uc"
  },
  {
    "path": "utils/g2p/__init__.py",
    "chars": 2375,
    "preview": "\"\"\" from https://github.com/keithito/tacotron \"\"\"\nimport utils.g2p.cleaners\nfrom utils.g2p.symbols import symbols\nfrom t"
  },
  {
    "path": "utils/g2p/bpe_1024.json",
    "chars": 33402,
    "preview": "{\n  \"version\": \"1.0\",\n  \"truncation\": null,\n  \"padding\": null,\n  \"added_tokens\": [\n    {\n      \"id\": 0,\n      \"content\":"
  },
  {
    "path": "utils/g2p/bpe_69.json",
    "chars": 2401,
    "preview": "{\n  \"version\": \"1.0\",\n  \"truncation\": null,\n  \"padding\": null,\n  \"added_tokens\": [\n    {\n      \"id\": 0,\n      \"content\":"
  },
  {
    "path": "utils/g2p/cleaners.py",
    "chars": 2301,
    "preview": "import re\nfrom utils.g2p.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_"
  },
  {
    "path": "utils/g2p/english.py",
    "chars": 5333,
    "preview": "\"\"\" from https://github.com/keithito/tacotron \"\"\"\n\n'''\nCleaners are transformations that run over the input text at both"
  },
  {
    "path": "utils/g2p/japanese.py",
    "chars": 4936,
    "preview": "import re\nfrom unidecode import unidecode\n\n\n\n# Regular expression matching Japanese without punctuation marks:\n_japanese"
  },
  {
    "path": "utils/g2p/mandarin.py",
    "chars": 6941,
    "preview": "import os\nimport sys\nimport re\nimport jieba\nimport cn2an\nimport logging\n\n\n# List of (Latin alphabet, bopomofo) pairs:\n_l"
  },
  {
    "path": "utils/g2p/symbols.py",
    "chars": 1676,
    "preview": "'''\nDefines the set of symbols used in text input to the model.\n'''\n\n# japanese_cleaners\n# _pad        = '_'\n# _punctuat"
  },
  {
    "path": "utils/generation.py",
    "chars": 11089,
    "preview": "# coding: utf-8\nimport os\nimport torch\nfrom vocos import Vocos\nimport logging\nimport langid\nlangid.set_languages(['en', "
  },
  {
    "path": "utils/prompt_making.py",
    "chars": 3927,
    "preview": "import os\nimport torch\nimport torchaudio\nimport logging\nimport langid\nimport whisper\nlangid.set_languages(['en', 'zh', '"
  },
  {
    "path": "utils/sentence_cutter.py",
    "chars": 1922,
    "preview": "import nltk\nimport jieba\nimport sudachipy\nimport langid\nlangid.set_languages(['en', 'zh', 'ja'])\n\ndef split_text_into_se"
  },
  {
    "path": "utils/symbol_table.py",
    "chars": 9350,
    "preview": "# Copyright      2020  Mobvoi Inc.        (authors: Fangjun Kuang)\n#\n# See ../../../LICENSE for clarification regarding "
  }
]

// ... and 63 more files (download for full content)

About this extraction

This page contains the full source code of the Plachtaa/VALL-E-X GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 126 files (18.9 MB), approximately 5.0M tokens, and a symbol index with 341 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
