Repository: billwuhao/ComfyUI_MegaTTS3
Branch: main
Commit: bad7ccb58afc
Files: 40
Total size: 284.9 KB
Directory structure:
gitextract_pb8s4_gh/
├── .github/
│ └── workflows/
│ └── publish_action.yml
├── .gitignore
├── LICENSE
├── README-CN.md
├── README.md
├── __init__.py
├── megatts3node.py
├── pyproject.toml
├── requirements.txt
├── tts/
│ ├── frontend_function.py
│ ├── modules/
│ │ ├── aligner/
│ │ │ └── whisper_small.py
│ │ ├── ar_dur/
│ │ │ ├── ar_dur_predictor.py
│ │ │ └── commons/
│ │ │ ├── layers.py
│ │ │ ├── nar_tts_modules.py
│ │ │ ├── rel_transformer.py
│ │ │ ├── rot_transformer.py
│ │ │ ├── seq_utils.py
│ │ │ └── transformer.py
│ │ ├── llm_dit/
│ │ │ ├── cfm.py
│ │ │ ├── dit.py
│ │ │ ├── time_embedding.py
│ │ │ └── transformer.py
│ │ └── wavvae/
│ │ ├── decoder/
│ │ │ ├── diag_gaussian.py
│ │ │ ├── hifigan_modules.py
│ │ │ ├── seanet_encoder.py
│ │ │ └── wavvae_v3.py
│ │ └── encoder/
│ │ └── common_modules/
│ │ ├── conv.py
│ │ ├── lstm.py
│ │ └── seanet.py
│ └── utils/
│ ├── audio_utils/
│ │ ├── align.py
│ │ ├── io.py
│ │ └── plot.py
│ ├── commons/
│ │ ├── ckpt_utils.py
│ │ └── hparams.py
│ └── text_utils/
│ ├── dict.json
│ ├── ph_tone_convert.py
│ ├── split_text.py
│ └── text_encoder.py
└── workflow-examples/
├── 单人语音.json
└── 双人会话.json
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/publish_action.yml
================================================
name: Publish to Comfy registry
on:
workflow_dispatch:
push:
branches:
- master
- main
paths:
- "pyproject.toml"
jobs:
publish-node:
name: Publish Custom Node to registry
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Publish Custom Node
uses: Comfy-Org/publish-node-action@main
with:
## Add your own personal access token to your Github Repository secrets and reference it here.
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 ByteDance
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README-CN.md
================================================
[中文](README-CN.md) | [English](README.md)
# ComfyUI 的 MegaTTS3 声音克隆节点
声音克隆质量非常高, 支持中英文, 并可跨语言克隆. **支持自定义音色!!! 超长文本!!! 双人对话!!! Windows 正常安装 pynini, 不再是阉割版 TTS!!!**.
## 📣 更新
[2025-06-07]⚒️: v2.0.0. **支持自定义音色, 支持超长文本, 支持双人对话, Windows 正常安装 pynini, 不再是阉割版 TTS!**.
```
[S1] MegaTTS 真开源版本来了,效果666
[S2] 晕 xuan4 是一种 gan3 觉
[S1] 我爱你!I love you!“我爱你”的英语是“I love you”
[S2] 2.5平方电线,共465篇,约315万字
[S1] 2002年的第一场雪,下在了2003年
```
https://github.com/user-attachments/assets/b734e6bd-9303-4311-b3a4-618241ca6535
[2025-04-28]⚒️: 新增预览音色节点, 先预览音色, 满意再进行克隆. 感谢 @chenpipi0807 的 idea😍. 可在 `speakers` 文件夹下分门别类建更多文件夹.
[2025-04-06]⚒️: 发布 v1.0.0.
## 使用
- 单人克隆(超长文本用空行隔开):

- 双人对话:

## 安装
- **Windows 先安装以下依赖**:
[pynini-windows-wheels](https://github.com/billwuhao/pynini-windows-wheels/releases/tag/v2.1.6.post1) 下载相应 python 版本的 pynini 轮子.
示例:
```
D:\AIGC\python\py310\python.exe -m pip install pynini-2.1.6.post1-cp3xx-cp3xx-win_amd64.whl
D:\AIGC\python\py310\python.exe -m pip install importlib_resources
D:\AIGC\python\py310\python.exe -m pip install WeTextProcessing>=1.0.4 --no-deps
```
- **然后正常进行下列安装**:
```
cd ComfyUI/custom_nodes
git clone https://github.com/billwuhao/ComfyUI_MegaTTS3.git
cd ComfyUI_MegaTTS3
pip install -r requirements.txt
# python_embeded (便携版, 在 ComfyUI_MegaTTS3 目录下执行)
..\..\..\python_embeded\python.exe -m pip install -r requirements.txt
```
## 模型下载
- 模型和音色需要手动下载放到 `ComfyUI\models\TTS` 路径下:
[MegaTTS3](https://huggingface.co/ByteDance/MegaTTS3/tree/main) 整个文件夹全部下载放到 `TTS` 文件夹下.
- **VAE 编码模型, 加微信公众号获取, 放到 `TTS\MegaTTS3\wavvae` 文件夹下, 即可自定义音色而无需 `.npy` 文件**:
- [Google 云盘](https://drive.google.com/drive/folders/1p9GNdNJqeK_94lIJW8lew_G3EazU-9Wx?usp=sharing)
- 请将音频放到 `TTS\speakers` 目录下. 我将会把所有 TTS 节点的说话者音频全部统一放到 `ComfyUI\models\TTS\speakers` 路径下, 这些节点包括 `IndexTTS, CSM, Dia, KokoroTTS, MegaTTS, OuteTTS, SparkTTS, StepAudioTTS` 等.
结构如下:
```
.
│ .gitattributes
│ config.json
│ README.md
│
├─aligner_lm
│ config.yaml
│ model_only_last.ckpt
│
├─diffusion_transformer
│ config.yaml
│ model_only_last.ckpt
│
├─duration_lm
│ config.yaml
│ model_only_last.ckpt
│
├─g2p
│ added_tokens.json
│ config.json
│ generation_config.json
│ latest
│ merges.txt
│ model.safetensors
│ special_tokens_map.json
│ tokenizer.json
│ tokenizer_config.json
│ trainer_state.json
│ vocab.json
│
└─wavvae
config.yaml
decoder.ckpt
model_only_last.ckpt
```
## 鸣谢
- [MegaTTS3](https://github.com/bytedance/MegaTTS3)
## 打赏
您的赞赏是我最大的动力! 感谢您支持我一杯咖啡!
================================================
FILE: README.md
================================================
[中文](README-CN.md) | [English](README.md)
# MegaTTS3 Voice Cloning Nodes for ComfyUI
High-quality voice cloning, supporting both Chinese and English, with cross-lingual cloning capabilities. **Supports custom voice cloning!!! Extra-long text!!! Two-person dialogue!!! Full pynini installation on Windows, no more stripped-down TTS!!!**.
## 📣 Updates
[2025-06-07]⚒️: v2.0.0. **Supports custom voice cloning, extra-long text, two-person dialogue, and full pynini installation on Windows, no more stripped-down TTS!**.
```
[S1] MegaTTS 真开源版本来了,效果666
[S2] 晕 xuan4 是一种 gan3 觉
[S1] 我爱你!I love you!“我爱你”的英语是“I love you”
[S2] 2.5平方电线,共465篇,约315万字
[S1] 2002年的第一场雪,下在了2003年
```
https://github.com/user-attachments/assets/b734e6bd-9303-4311-b3a4-618241ca6535
[2025-04-28]⚒️: Added a voice preview node. Preview the voice first, then clone if you're satisfied. Thanks to @chenpipi0807 for the idea😍. You can create categorized subfolders within the `speakers` folder.
[2025-04-06]⚒️: Released v1.0.0.
## Usage
- Single-speaker cloning (separate paragraphs of extra-long text with blank lines):

- Two-person dialogue:

## Installation
- **For Windows, install the following dependencies first**:
Download the pynini wheel matching your Python version from [pynini-windows-wheels](https://github.com/billwuhao/pynini-windows-wheels/releases/tag/v2.1.6.post1).
Example:
```
D:\AIGC\python\py310\python.exe -m pip install pynini-2.1.6.post1-cp3xx-cp3xx-win_amd64.whl
D:\AIGC\python\py310\python.exe -m pip install importlib_resources
D:\AIGC\python\py310\python.exe -m pip install "WeTextProcessing>=1.0.4" --no-deps
```
- **Then, proceed with the normal installation**:
```
cd ComfyUI/custom_nodes
git clone https://github.com/billwuhao/ComfyUI_MegaTTS3.git
cd ComfyUI_MegaTTS3
pip install -r requirements.txt
# For the portable build (python_embeded), run from this ComfyUI_MegaTTS3 folder:
..\..\..\python_embeded\python.exe -m pip install -r requirements.txt
```
## Model Download
- Models and voices need to be downloaded manually and placed in the `ComfyUI\models\TTS` directory:
[MegaTTS3](https://huggingface.co/ByteDance/MegaTTS3/tree/main) Download the entire folder and place it in the `TTS` directory.
- **For the VAE encoder model, which enables custom voice cloning without `.npy` files, please follow our WeChat Official Account to obtain it. Place it in the `TTS\MegaTTS3\wavvae` folder**:
- [Google Drive](https://drive.google.com/drive/folders/1p9GNdNJqeK_94lIJW8lew_G3EazU-9Wx?usp=sharing)
- Please place speaker audio files in the `TTS\speakers` directory. Speaker audio for all of my TTS nodes, including `IndexTTS, CSM, Dia, KokoroTTS, MegaTTS, OuteTTS, SparkTTS, StepAudioTTS`, is being unified under the `ComfyUI\models\TTS\speakers` path.
The structure is as follows:
```
.
│ .gitattributes
│ config.json
│ README.md
│
├─aligner_lm
│ config.yaml
│ model_only_last.ckpt
│
├─diffusion_transformer
│ config.yaml
│ model_only_last.ckpt
│
├─duration_lm
│ config.yaml
│ model_only_last.ckpt
│
├─g2p
│ added_tokens.json
│ config.json
│ generation_config.json
│ latest
│ merges.txt
│ model.safetensors
│ special_tokens_map.json
│ tokenizer.json
│ tokenizer_config.json
│ trainer_state.json
│ vocab.json
│
└─wavvae
config.yaml
decoder.ckpt
model_only_last.ckpt
```
## Credits
- [MegaTTS3](https://github.com/bytedance/MegaTTS3)
## Donation
Your appreciation is my greatest motivation! Thank you for supporting me with a cup of coffee!
================================================
FILE: __init__.py
================================================
from .megatts3node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
================================================
FILE: megatts3node.py
================================================
import json
import os
import librosa
import numpy as np
import torch
import torchaudio
from typing import List, Union, Optional
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
from langdetect import detect as classify_language
import pyloudnorm as pyln
import folder_paths
import gc
import re
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
sys.path.append(current_dir)
from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator
from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit
from tts.utils.audio_utils.io import convert_to_wav_bytes, combine_audio_segments
from tts.utils.commons.ckpt_utils import load_ckpt
from tts.utils.commons.hparams import set_hparams, hparams
from tts.utils.text_utils.text_encoder import TokenTextEncoder
from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english, chunk_text_chinesev2
from tts.utils.commons.hparams import hparams, set_hparams
# ComfyUI's model root. TTS checkpoints are expected under <models>/TTS and
# speaker prompt audio under <models>/TTS/speakers.
models_dir = folder_paths.models_dir
model_path = os.path.join(models_dir, "TTS")
speakers_dir = os.path.join(models_dir, "TTS", "speakers")
# ComfyUI's temporary directory.
cache_dir = folder_paths.get_temp_directory()
def get_all_files(
    root_dir: str,
    return_type: str = "list",
    extensions: Optional[List[str]] = None,
    exclude_dirs: Optional[List[str]] = None,
    relative_path: bool = False
) -> Union[List[str], dict]:
    """Walk ``root_dir`` recursively and collect the files found.

    Args:
        root_dir: directory to walk; it is normalized with ``os.path.normpath``.
        return_type: ``"dict"`` groups files by the directory containing them;
            any other value returns a single flat list.
        extensions: optional suffix filter matched case-insensitively
            (e.g. ``['.py', '.txt']``); None or empty disables filtering.
        exclude_dirs: directory names that are not descended into
            (e.g. ``['__pycache__', '.git']``).
        relative_path: when True, file paths (and dict keys) are relative to
            ``root_dir``; otherwise they are joined onto the normalized root.

    Returns:
        A flat list of file paths, or a dict mapping each directory to the
        non-empty list of matching file paths it directly contains.
    """
    base = os.path.normpath(root_dir)
    as_dict = return_type == "dict"
    suffixes = tuple(ext.lower() for ext in extensions) if extensions else None

    flat_list: List[str] = []
    by_folder: dict = {}
    for folder, subfolders, names in os.walk(base):
        if exclude_dirs:
            # Prune in place so os.walk does not descend into excluded dirs.
            subfolders[:] = [sub for sub in subfolders if sub not in exclude_dirs]

        found = []
        for name in names:
            if suffixes is not None and not name.lower().endswith(suffixes):
                continue
            path = os.path.join(folder, name)
            found.append(os.path.relpath(path, base) if relative_path else path)

        if not as_dict:
            flat_list.extend(found)
        elif found:
            by_folder[os.path.relpath(folder, base) if relative_path else folder] = found

    return by_folder if as_dict else flat_list
def get_speakers():
    """List speaker audio files under ``speakers_dir`` as paths relative to it.

    If ``speakers_dir`` does not exist yet it is created and an empty list is
    returned.
    """
    if os.path.exists(speakers_dir):
        audio_exts = [".wav", ".mp3", ".flac", ".mp4", ".WAV", ".MP3", ".FLAC", ".MP4"]
        return get_all_files(speakers_dir, extensions=audio_exts, relative_path=True)
    os.makedirs(speakers_dir, exist_ok=True)
    return []
class MegaTTS3DiTInfer():
def __init__(
    self,
    device=None,
    ckpt_root=os.path.join(model_path, "MegaTTS3"),
    dit_exp_name='diffusion_transformer',
    frontend_exp_name='aligner_lm',
    wavvae_exp_name='wavvae',
    dur_ckpt_path='duration_lm',
    g2p_exp_name='g2p',
    precision=torch.float16,
    **kwargs
):
    """Load all MegaTTS3 sub-models from ``ckpt_root`` and set up helpers.

    Args:
        device: torch device string; when None, 'cuda' is used if available,
            otherwise 'cpu'.
        ckpt_root: directory holding the sub-model folders.
        dit_exp_name, frontend_exp_name, wavvae_exp_name, dur_ckpt_path,
        g2p_exp_name: folder names (relative to ``ckpt_root``) of the diffusion
            transformer, aligner LM, wave VAE, duration LM and G2P model.
        precision: dtype used for autocast during DiT inference.
        **kwargs: accepted and ignored.
    """
    # Sample rate used to load prompt audio and for loudness metering.
    self.sr = 24000
    # NOTE(review): not referenced by the code visible in this file — confirm use.
    self.fm = 8
    self.device = device if device is not None else ('cuda' if torch.cuda.is_available() else 'cpu')
    self.precision = precision

    # Full paths of each sub-model directory; build_model() reads these.
    self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name)
    self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name)
    self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name)
    self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path)
    self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name)
    self.build_model(self.device)

    # Text normalizers for Chinese and English input.
    self.zh_normalizer = ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False)
    self.en_normalizer = EnNormalizer(overwrite_cache=False)
    # Loudness meter used to match generated audio to the prompt level.
    self.loudness_meter = pyln.Meter(self.sr)

    # Prompt state filled in by preprocess(); None until a prompt is processed.
    for attr_name in ('ph_ref', 'tone_ref', 'mel2ph_ref', 'vae_latent',
                      'ctx_dur_tokens', 'incremental_state_dur_prompt', 'audio_bytes'):
        setattr(self, attr_name, None)
def clean(self):
    """Drop references to all loaded models and cached prompt state, then free memory.

    After this call the model attributes are None, so the instance cannot
    synthesize again unless the models are rebuilt.
    """
    import gc
    # Model objects (hold the weights, possibly on the GPU).
    self.dur_model = None
    self.dit = None
    self.g2p_model = None
    self.g2p_tokenizer = None
    self.wavvae_en = None
    self.wavvae_de = None
    self.aligner_lm = None
    self.length_regulator = None
    # Aligner scratch state. preprocess() normally resets these after align(),
    # but clear them here too in case a previous call failed part-way.
    self.kv_cache = None
    self.hooks = None
    # Cached prompt state produced by preprocess().
    self.audio_bytes = None
    self.ph_ref = None
    self.tone_ref = None
    self.mel2ph_ref = None
    self.vae_latent = None
    self.ctx_dur_tokens = None
    self.incremental_state_dur_prompt = None
    gc.collect()
    # No-op when CUDA has not been initialized.
    torch.cuda.empty_cache()
def build_model(self, device):
    """Load every MegaTTS3 sub-model and move it to ``device``.

    Loads, in order:
      1. the phone/tone token dictionaries,
      2. the duration LM,
      3. the diffusion transformer,
      4. the Whisper-based aligner,
      5. the G2P causal LM,
      6. the wave VAE.

    ``set_hparams(exp_name=self.dit_exp_name, ...)`` is called first, so the
    module-level ``hparams`` reflect the DiT config; ``frames_multiple`` is
    read from them below.

    Args:
        device: torch device (e.g. 'cuda' or 'cpu') the models are moved to.
    """
    set_hparams(exp_name=self.dit_exp_name, print_hparams=False)

    # --- Phone / tone dictionaries ---
    current_dir = os.path.dirname(os.path.abspath(__file__))
    # Context manager so the dictionary file handle is closed promptly.
    with open(f"{current_dir}/tts/utils/text_utils/dict.json", encoding='utf-8-sig') as dict_file:
        ling_dict = json.load(dict_file)
    self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='') for k in ['phone', 'tone']}
    self.token_encoder = token_encoder = self.ling_dict['phone']
    ph_dict_size = len(token_encoder)

    # --- Duration LM ---
    from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor
    hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False)
    # Copy frames_multiple from the global (DiT) hparams into the duration config.
    hp_dur_model['frames_multiple'] = hparams['frames_multiple']
    self.dur_model = ARDurPredictor(
        hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'],
        hp_dur_model['dur_model_layers'], ph_dict_size,
        hp_dur_model['dur_code_size'],
        use_rot_embed=hp_dur_model.get('use_rot_embed', False))
    self.length_regulator = LengthRegulator()
    load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model')
    self.dur_model.eval()
    self.dur_model.to(device)

    # --- Diffusion transformer ---
    from tts.modules.llm_dit.dit import Diffusion
    self.dit = Diffusion()
    load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False)
    self.dit.eval()
    self.dit.to(device)
    # NOTE(review): 302 / 32 are presumably the phone / tone vocabulary sizes,
    # making these the last ids (used as CFG mask tokens per the names) — confirm.
    self.cfg_mask_token_phone = 302 - 1
    self.cfg_mask_token_tone = 32 - 1

    # --- Aligner LM (Whisper-based) ---
    from tts.modules.aligner.whisper_small import Whisper
    self.aligner_lm = Whisper()
    load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model')
    self.aligner_lm.eval()
    self.aligner_lm.to(device)
    self.kv_cache = None
    self.hooks = None

    # --- G2P LM ---
    from transformers import AutoTokenizer, AutoModelForCausalLM
    g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right")
    g2p_tokenizer.padding_side = "right"
    self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device)
    self.g2p_tokenizer = g2p_tokenizer
    # Id of the reserved token upstream MegaTTS3 uses here. Encoding '' typically
    # yields an empty id list (no special tokens added), so '[0]' raised IndexError.
    self.speech_start_idx = g2p_tokenizer.encode('<Reserved_TTS_0>')[0]

    # --- Wave VAE ---
    self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False)
    from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3
    self.wavvae_en = WavVAE_V3(hparams=hp_wavvae)
    self.wavvae_de = WavVAE_V3(hparams=hp_wavvae)
    # With the full checkpoint (model_only_last.ckpt) prompts can be encoded
    # directly; otherwise only decoder weights are loaded and preprocess()
    # requires a .npy latent. Only the instance that receives weights is put
    # in eval mode and moved to the device.
    if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'):
        load_ckpt(self.wavvae_en, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True)
        self.has_vae_encoder = True
        self.wavvae_en.eval()
        self.wavvae_en.to(device)
    else:
        load_ckpt(self.wavvae_de, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False)
        self.has_vae_encoder = False
        self.wavvae_de.eval()
        self.wavvae_de.to(device)
    self.vae_stride = hp_wavvae.get('vae_stride', 4)
    self.hop_size = hp_wavvae.get('hop_size', 4)
def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs):
    """Compute and cache the speaker-prompt state consumed by ``forward``.

    From the prompt audio this derives:
      - phone/tone references and their frame alignment (via ``align``),
      - the prompt loudness,
      - the wave-VAE prompt latent,
      - the duration-LM context (via ``make_dur_prompt``).

    Results are stored on ``self``. A call with the same ``audio_bytes`` and
    ``latent_file`` as the last successful call skips the recomputation, but
    ``topk_dur`` is always applied.

    Args:
        audio_bytes: prompt audio as accepted by ``convert_to_wav_bytes``;
            compared with ``==`` against the cached prompt.
        latent_file: optional path to a ``.npy`` VAE latent. When given it is
            used instead of the VAE encoder; it is required when no encoder
            checkpoint was loaded.
        topk_dur: top-k for duration sampling; values <= 1 store ``None`` in
            ``dur_model.hparams["infer_top_k"]``.
        **kwargs: accepted and ignored.

    Raises:
        AssertionError: if no VAE encoder is loaded and ``latent_file`` is None.
    """
    # Applied on every call so that a changed topk_dur takes effect even when
    # the prompt below is served from cache.
    self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None

    # Cache key = (audio, latent file); the latent file replaces the encoder
    # latent when provided, so it must be part of the key.
    if self.audio_bytes == audio_bytes and getattr(self, "_prompt_latent_file", None) == latent_file:
        return

    wav_bytes = convert_to_wav_bytes(audio_bytes)
    # Load the prompt at the model sample rate.
    wav, _ = librosa.core.load(wav_bytes, sr=self.sr)
    # Pad so len(wav) % win_size == win_size - 1, then append 12000 zero
    # samples (0.5 s at 24 kHz) of trailing silence.
    ws = hparams['win_size']
    if len(wav) % ws < ws - 1:
        wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32)
    wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32)
    self.loudness_prompt = self.loudness_meter.integrated_loudness(wav.astype(float))

    # Phone/tone references and their frame alignment from the aligner LM.
    ph_ref, tone_ref, mel2ph_ref = align(self, wav)
    # Reset aligner scratch state (presumably populated by align()).
    self.kv_cache = None
    self.hooks = None

    with torch.inference_mode():
        # Prompt latent: encode the audio when an encoder is available and no
        # latent file was supplied; otherwise load the supplied .npy latent.
        if self.has_vae_encoder and latent_file is None:
            wav_tensor = torch.FloatTensor(wav)[None].to(self.device)
            vae_latent = self.wavvae_en.encode_latent(wav_tensor)
        else:
            assert latent_file is not None, "WaveVAE encode model does not exist, an npy file must be provided!!!"
            vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
        # Truncate every latent source to mel2ph_ref.size(1) // 4 frames (the
        # encoder path previously skipped this, unlike the .npy paths).
        vae_latent = vae_latent[:, :mel2ph_ref.size(1) // 4]

        # Duration prompting: duration-LM context built from the reference alignment.
        incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref)

        self.ph_ref = ph_ref.to(self.device)
        self.tone_ref = tone_ref.to(self.device)
        self.mel2ph_ref = mel2ph_ref.to(self.device)
        self.vae_latent = vae_latent.to(self.device)
        self.ctx_dur_tokens = ctx_dur_tokens.to(self.device)
        self.incremental_state_dur_prompt = incremental_state_dur_prompt

    # Record the cache key only after everything succeeded, so a failed call
    # is retried instead of being treated as cached.
    self.audio_bytes = audio_bytes
    self._prompt_latent_file = latent_file
def forward(self, texts, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs):
    """Synthesize every string in ``texts`` with the prompt prepared by preprocess().

    Each text is language-normalised, chunked, and every chunk goes through
    G2P -> duration prediction -> DiT -> WaveVAE decode. The decoded audio is
    then loudness-matched to the prompt and peak-limited. Chunks are merged
    with ``combine_audio_segments`` and all texts are concatenated.

    Args:
        texts: list of strings to speak.
        time_step: number of DiT inference timesteps.
        p_w, t_w: the two classifier-free-guidance weights passed as ``seq_cfg_w``.
        dur_disturb: maximum relative random perturbation of phone durations.
        dur_alpha: global duration (speed) factor.
        **kwargs: accepted and ignored.

    Returns:
        (waveform, sample_rate): CPU float32 tensor of all texts concatenated
        along dim 0 (presumably 1-D; MegaTTS3Run.clone adds two leading dims),
        and ``self.sr``.
    """
    with torch.inference_mode():
        waveforms = []
        for input_text in texts:
            segment_wavs = []
            # Normalise and chunk according to the detected language.
            if classify_language(input_text) == 'en':
                input_text = self.en_normalizer.normalize(input_text)
                text_segs = chunk_text_english(input_text, max_chars=130)
            else:
                input_text = self.zh_normalizer.normalize(input_text)
                text_segs = chunk_text_chinesev2(input_text, limit=60)

            for seg_i, text in enumerate(text_segs):
                ph_pred, tone_pred = g2p(self, text)
                mel2ph_pred = dur_pred(self, self.ctx_dur_tokens, self.incremental_state_dur_prompt,
                                       ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha,
                                       is_first=seg_i == 0, is_final=seg_i == len(text_segs) - 1)
                inputs = prepare_inputs_for_dit(self, self.mel2ph_ref, mel2ph_pred, self.ph_ref,
                                                self.tone_ref, ph_pred, tone_pred, self.vae_latent)
                # Speech DiT inference.
                with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
                    x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float()

                # Put the real prompt latent back in front before WaveVAE decoding.
                x[:, :self.vae_latent.size(1)] = self.vae_latent
                if self.has_vae_encoder:
                    wav_pred = self.wavvae_en.decode(x)[0, 0].to(torch.float32)
                else:
                    wav_pred = self.wavvae_de.decode(x)[0, 0].to(torch.float32)

                # Drop the samples that belong to the prompt.
                prompt_samples = self.vae_latent.size(1) * self.vae_stride * self.hop_size
                wav_pred = wav_pred[prompt_samples:].cpu().numpy()
                # Match the prompt's integrated loudness, then keep the peak below 1.
                loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float))
                wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, self.loudness_prompt)
                peak = np.abs(wav_pred).max()
                if peak >= 1:
                    wav_pred = wav_pred / peak * 0.95
                segment_wavs.append(wav_pred)
                gc.collect()
                torch.cuda.empty_cache()

            merged = combine_audio_segments(segment_wavs, sr=self.sr).astype(np.float32)
            waveforms.append(torch.tensor(merged).cpu())
        return torch.cat(waveforms, dim=0), self.sr
class MegaTTS3SpeakersPreview:
    """ComfyUI node that loads a speaker reference clip from ``speakers_dir``.

    Outputs the clip as an AUDIO dict. It also outputs the path of the ``.npy``
    prompt latent sitting next to the clip, or "" when no such file exists.
    """

    @classmethod
    def INPUT_TYPES(s):
        speakers = get_speakers()
        return {
            "required": {"speaker": (speakers,),},}

    RETURN_TYPES = ("AUDIO", "STRING", )
    RETURN_NAMES = ("audio", "npy_file", )
    FUNCTION = "preview"
    CATEGORY = "🎤MW/MW-MegaTTS3"

    def preview(self, speaker):
        """Load ``speaker`` (a path relative to ``speakers_dir``); returns (audio, npy_path_or_empty)."""
        wav_path = os.path.join(speakers_dir, speaker)
        # splitext strips only a real file extension; rsplit('.') would cut at a
        # dot inside a directory name when the file itself has no extension.
        latent_file = os.path.splitext(wav_path)[0] + '.npy'
        if not os.path.exists(latent_file):
            latent_file = ""
        waveform, sample_rate = torchaudio.load(wav_path)
        # Add a leading batch dimension to the loaded (channels, samples) tensor.
        output_audio = {
            "waveform": waveform.unsqueeze(0),
            "sample_rate": sample_rate
        }
        return (output_audio, latent_file)
def cache_audio_tensor(
    cache_dir,
    audio_tensor: torch.Tensor,
    sample_rate: int,
    filename_prefix: str = "cached_audio_",
    audio_format: Optional[str] = ".wav"
) -> str:
    """Write ``audio_tensor`` to a new, uniquely named file inside ``cache_dir``.

    Args:
        cache_dir: directory in which the file is created.
        audio_tensor: waveform passed unchanged to ``torchaudio.save``.
        sample_rate: sample rate passed to ``torchaudio.save``.
        filename_prefix: prefix of the generated file name.
        audio_format: file-name suffix (e.g. ".wav").

    Returns:
        Path of the written file. The caller owns it and must delete it.

    Raises:
        Exception: wraps, and is chained to, any error raised while creating or
            writing the file. A partially written file is removed first.
    """
    import os
    import tempfile
    temp_filepath = None
    try:
        # Only reserve a unique name here and close the handle before writing:
        # on some platforms (notably Windows) the path cannot be reopened while
        # the temporary file's handle is still open.
        with tempfile.NamedTemporaryFile(
            prefix=filename_prefix,
            suffix=audio_format,
            dir=cache_dir,
            delete=False
        ) as tmp_file:
            temp_filepath = tmp_file.name
        torchaudio.save(temp_filepath, audio_tensor, sample_rate)
        return temp_filepath
    except Exception as e:
        if temp_filepath is not None:
            try:
                os.remove(temp_filepath)
            except OSError:
                pass  # best effort; the original error below is what matters
        raise Exception(f"Error caching audio tensor: {e}") from e
def statistical_compare(tensor1, tensor2):
    """Quickly decide whether two tensors very likely hold the same data.

    This is a heuristic, not an exact equality test. Shapes must match exactly.
    Mean, std, max and min must then agree within ``torch.allclose`` with
    ``rtol=1e-3``.

    Returns:
        bool: True when the tensors look identical.
    """
    # Different shapes cannot be the same clip even when the summary statistics
    # agree (e.g. the same samples laid out as (1, 4) vs (2, 2)).
    if tensor1.shape != tensor2.shape:
        return False

    def _summary(t):
        return {'mean': t.mean(), 'std': t.std(), 'max': t.max(), 'min': t.min()}

    stats1 = _summary(tensor1)
    stats2 = _summary(tensor2)
    return all(torch.allclose(stats1[k], stats2[k], rtol=1e-3) for k in stats1)
# Module-level MegaTTS3DiTInfer instance shared by all MegaTTS3Run nodes.
# MegaTTS3Run.clone creates it lazily on first use and resets it to None when
# asked to unload the model.
INFER_INS_CACHE = None
class MegaTTS3Run:
    """ComfyUI node that clones one voice (or two, for dialogue) and speaks ``text``.

    Single-speaker mode speaks every blank-line-separated paragraph of ``text``
    with ``audio`` as the voice prompt.

    Dialogue mode is active when ``dialogue_audio_s2`` is connected. ``text`` is
    split on ``[S1]`` / ``[S2]`` tags and each segment is spoken with the
    matching prompt. Segments are concatenated back in their original order.
    """

    def __init__(self):
        # Reset to None whenever the model is unloaded; not otherwise read here.
        self.resource_context = None
        # Last single-speaker prompt (waveform and sample rate) and the temp file
        # it was written to, so an unchanged prompt is not re-written every run.
        self.audio_tensor = None
        self.audio_sample_rate = None
        self.audio_prompt = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "audio": ("AUDIO",),
                "text": ("STRING", {"forceInput": True}),
                "time_step": ("INT", {"default": 32, "min": 1,}),
                "p_w": ("FLOAT", {"default":1.6, "min": 0.1,}),
                "t_w": ("FLOAT", {"default": 2.5, "min": 0.1,}),
                "unload_model": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "dialogue_audio_s2":("AUDIO",),
                "audio_npy_file": ("STRING", {"forceInput": True, "tooltip": "No `npy_file` will use VAE to encode audio. 不提供 .npy 文件, 将使用 WaveVAE 编码音频"}),
                "audio_s2_npy_file": ("STRING", {"forceInput": True, "tooltip": "No `npy_file` will use VAE to encode audio. 不提供 .npy 文件, 将使用 WaveVAE 编码音频"}),
            }
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "clone"
    CATEGORY = "🎤MW/MW-MegaTTS3"

    def clone(self, audio, text, time_step, p_w, t_w, unload_model, audio_npy_file=None, dialogue_audio_s2=None, audio_s2_npy_file=None):
        """Node entry point; returns ``({"waveform": (1, 1, T) tensor, "sample_rate": sr},)``.

        Args:
            audio: AUDIO dict of the (first) speaker's reference clip.
            text: text to speak; in dialogue mode it must contain [S1]/[S2] tags.
            time_step, p_w, t_w: forwarded to MegaTTS3DiTInfer.forward.
            unload_model: release the shared model after this call (also on error).
            audio_npy_file / audio_s2_npy_file: optional .npy prompt latents.
            dialogue_audio_s2: AUDIO dict of the second speaker; enables dialogue mode.

        Raises:
            ValueError: no text to speak, or dialogue text without speaker tags.
        """
        global INFER_INS_CACHE
        if not os.path.exists(os.path.join(model_path, "MegaTTS3", 'wavvae', 'model_only_last.ckpt')):
            print("WaveVAE encode model does not exist, an npy file must be provided!!!")
        waveform = audio["waveform"].squeeze(0)
        if INFER_INS_CACHE is None:
            INFER_INS_CACHE = MegaTTS3DiTInfer()
        latent_file = audio_npy_file if audio_npy_file else None
        synth_kwargs = dict(time_step=time_step, p_w=p_w, t_w=t_w)
        try:
            if dialogue_audio_s2 is None:
                waveform, sr = self._clone_single(
                    waveform, audio["sample_rate"], text, latent_file, synth_kwargs)
            else:
                latent_file_2 = audio_s2_npy_file if audio_s2_npy_file else None
                waveform, sr = self._clone_dialogue(
                    waveform, audio["sample_rate"], dialogue_audio_s2, text,
                    latent_file, latent_file_2, synth_kwargs)
        finally:
            # Runs on success and on failure, so the model is always released when requested.
            if unload_model:
                self._unload_model()
        return ({"waveform": waveform.unsqueeze(0).unsqueeze(0), "sample_rate": sr},)

    def _clone_single(self, waveform, sample_rate, text, latent_file, synth_kwargs):
        """Speak every paragraph of ``text`` with one voice prompt; returns (waveform, sr)."""
        import gc
        texts = self._split_paragraphs(text)
        if not texts:
            raise ValueError("No text to synthesize.")
        # Re-write the prompt file only when the reference audio actually changed
        # (or the previously written file has disappeared).
        if (self.audio_tensor is None or self.audio_prompt is None
                or self.audio_sample_rate != sample_rate
                or not os.path.exists(self.audio_prompt)
                or not statistical_compare(self.audio_tensor, waveform)):
            previous_prompt = self.audio_prompt
            self.audio_prompt = cache_audio_tensor(cache_dir, waveform, sample_rate)
            # Remember the tensor only after the file was written successfully.
            self.audio_tensor = waveform
            self.audio_sample_rate = sample_rate
            self._remove_file_quietly(previous_prompt)
        with open(self.audio_prompt, 'rb') as file:
            file_content = file.read()
        INFER_INS_CACHE.preprocess(file_content, latent_file=latent_file)
        del file_content
        gc.collect()
        torch.cuda.empty_cache()
        result, sr = INFER_INS_CACHE.forward(texts=texts, **synth_kwargs)
        gc.collect()
        torch.cuda.empty_cache()
        return result, sr

    def _clone_dialogue(self, waveform, sample_rate, dialogue_audio_s2, text, latent_file, latent_file_2, synth_kwargs):
        """Speak [S1]/[S2]-tagged ``text`` with two voice prompts; returns (waveform, sr)."""
        import gc
        audio_1 = audio_2 = None
        try:
            audio_1 = cache_audio_tensor(cache_dir, waveform, sample_rate)
            audio_2 = cache_audio_tensor(cache_dir, dialogue_audio_s2["waveform"].squeeze(0), dialogue_audio_s2["sample_rate"])
            with open(audio_1, 'rb') as file:
                file_content_1 = file.read()
            with open(audio_2, 'rb') as file:
                file_content_2 = file.read()
        finally:
            # The prompt bytes are now in memory, so the temp files can go; their
            # paths are still used below, but only as speaker identifiers.
            self._remove_file_quietly(audio_1)
            self._remove_file_quietly(audio_2)
        gc.collect()
        torch.cuda.empty_cache()

        ress = []
        sr = None
        for t, a, n in self.get_speaker_text_audio(text, audio_1, audio_2):
            texts = self._split_paragraphs(t)
            if not texts:
                # A tag with nothing after it (e.g. "[S1][S2] ..."): nothing to speak.
                continue
            if a == audio_1:
                INFER_INS_CACHE.preprocess(file_content_1, latent_file=latent_file)
            else:
                INFER_INS_CACHE.preprocess(file_content_2, latent_file=latent_file_2)
            res_sub, sr = INFER_INS_CACHE.forward(texts=texts, **synth_kwargs)
            ress.append((res_sub, n))
        del file_content_1, file_content_2
        gc.collect()
        torch.cuda.empty_cache()
        if not ress:
            raise ValueError("No text found after the speaker tags: [S1]... [S2]...")
        # Restore the original segment order before concatenating.
        ress.sort(key=lambda item: item[1])
        return torch.cat([res for res, _ in ress], dim=0), sr

    def _unload_model(self):
        """Release the shared MegaTTS3DiTInfer instance and free cached GPU memory."""
        global INFER_INS_CACHE
        import gc
        if INFER_INS_CACHE is not None:
            INFER_INS_CACHE.clean()
        INFER_INS_CACHE = None
        self.resource_context = None
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _split_paragraphs(text):
        """Split ``text`` on blank lines, stripping paragraphs and dropping empty ones."""
        return [i.strip() for i in re.split(r'\n\s*\n', text.strip()) if i.strip()]

    @staticmethod
    def _remove_file_quietly(path):
        """Best-effort delete of a temp file; a missing path or OS error is ignored."""
        if not path:
            return
        try:
            os.remove(path)
        except OSError:
            pass

    def get_speaker_text_audio(self, text, audio_1, audio_2):
        """Split ``text`` on speaker tags and pair each segment with its prompt.

        Returns:
            List of ``(content, audio, index)`` tuples sorted by ``audio``, so
            segments of the same speaker are adjacent. ``index`` is the
            segment's original position and is used to restore the order.

        Raises:
            ValueError: no speaker tag is found.
        """
        pattern = r'(\[s?S?1\]|\[s?S?2\])\s*([\s\S]*?)(?=\[s?S?[12]\]|$)'
        matches = re.findall(pattern, text)
        if len(matches) == 0:
            raise ValueError("No speaker tags found in the text: [S1]... [S2]...")
        # The speaker number is the character just before ']'. The regex also
        # accepts tags such as "[1]" or "[sS1]", which a lowercase comparison
        # with "[s1]" would wrongly send to speaker 2.
        segments = [
            (content, audio_1 if label[-2] == '1' else audio_2, index)
            for index, (label, content) in enumerate(matches)
        ]
        return sorted(segments, key=lambda x: x[1])
class MultiLinePromptMG:
    """ComfyUI node exposing a multi-line text box; outputs its contents trimmed."""

    CATEGORY = "🎤MW/MW-MegaTTS3"
    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("text",)
    FUNCTION = "promptgen"

    @classmethod
    def INPUT_TYPES(cls):
        text_box = ("STRING", {"multiline": True, "default": ""})
        return {"required": {"multi_line_prompt": text_box}}

    def promptgen(self, multi_line_prompt: str):
        """Return the prompt with leading/trailing whitespace removed, as a 1-tuple."""
        trimmed = multi_line_prompt.strip()
        return (trimmed,)
# Node registry read by ComfyUI: node id -> node class.
NODE_CLASS_MAPPINGS = dict(
    MegaTTS3SpeakersPreview=MegaTTS3SpeakersPreview,
    MegaTTS3Run=MegaTTS3Run,
    MultiLinePromptMG=MultiLinePromptMG,
)
# Display title for each node id.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    MegaTTS3SpeakersPreview="MegaTTS3 Speakers Preview",
    MegaTTS3Run="MegaTTS3 Run",
    MultiLinePromptMG="Multi Line Text",
)
================================================
FILE: pyproject.toml
================================================
[project]
name = "megatts3-mw"
description = "Lightweight and Efficient, 🎧Ultra High-Quality Voice Cloning, Chinese and English."
version = "2.0.0"
license = {file = "LICENSE"}
dependencies = ["setproctitle", "attrdict", "librosa", "pydub", "pyloudnorm", "x-transformers", "torchdiffeq", "openai-whisper>=20240930"]
[project.urls]
Repository = "https://github.com/billwuhao/ComfyUI_MegaTTS3"
# Used by Comfy Registry https://comfyregistry.org
[tool.comfy]
PublisherId = "mw"
DisplayName = "MW-ComfyUI_MegaTTS3"
Icon = ""
================================================
FILE: requirements.txt
================================================
setproctitle
attrdict
librosa
pyloudnorm
x-transformers
torchdiffeq
openai-whisper>=20240930
langdetect
pynini==2.1.6; platform_system!="Windows"
WeTextProcessing>=1.0.3; platform_system!="Windows"
================================================
FILE: tts/frontend_function.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import whisper
import librosa
from copy import deepcopy
from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
from tts.utils.audio_utils.align import mel2token_to_dur
# Grapheme-to-phoneme (G2P) conversion.
def g2p(self, text_inp):
    """Convert ``text_inp`` to phone and tone id tensors using ``self.g2p_model``.

    Returns:
        (ph_pred, tone_pred): as produced by ``split_ph``, each with a leading
        batch dimension of 1 and moved to ``self.device``.
    """
    # NOTE(review): the two '' literals look like lost special tokens (upstream
    # MegaTTS3 wraps the text as '<BOT>' + text + '<BOS>'); confirm in the repo.
    prompt_ids = self.g2p_tokenizer('' + text_inp + '')['input_ids']
    # Append phone 145 (listed as 'sil' in dur_pred's comments), shifted into
    # the speech-token range by speech_start_idx.
    start_phone = 145 + self.speech_start_idx
    input_ids = torch.LongTensor([prompt_ids + [start_phone]]).to(self.device)

    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
        outputs = self.g2p_model.generate(
            input_ids,
            max_new_tokens=256,
            do_sample=True,
            top_k=1,
            eos_token_id=800 + 1 + self.speech_start_idx,
        )

    # Drop the text prompt and the last generated token (presumably EOS), then
    # undo the speech-token offset.
    ph_tokens = outputs[:, len(prompt_ids):-1] - self.speech_start_idx
    ph_pred, tone_pred = split_ph(ph_tokens[0])
    return ph_pred[None, :].to(self.device), tone_pred[None, :].to(self.device)
# Phone-to-mel alignment of the prompt speech.
def align(self, wav):
    """Align the prompt waveform ``wav`` to phones/tones with ``self.aligner_lm``.

    ``wav`` is resampled from ``self.sr`` to 16 kHz and its Whisper log-mel
    spectrogram is trimmed to a multiple of ``self.fm`` frames. The aligner is
    then decoded greedily from token 798 until it emits token 799, for at most
    768 steps. The phone durations are adjusted so they sum exactly to the
    trimmed frame count.

    Returns:
        (ph_ref, tone_ref, mel2ph_ref): phone ids and tone ids (each with a
        leading batch dim of 1), and the frame-to-phone map from
        ``self.length_regulator`` trimmed to a multiple of ``self.fm``.
    """
    # Everything runs in inference mode so the in-place duration edits below
    # are valid whether or not split_ph_timestamp returns inference tensors.
    with torch.inference_mode():
        whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
        mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1, 2)
        prompt_max_frame = mel.size(2) // self.fm * self.fm
        mel = mel[:, :, :prompt_max_frame]

        token = torch.LongTensor([[798]]).to(self.device)
        audio_features = self.aligner_lm.embed_audio(mel)
        for _ in range(768):
            with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
                logits = self.aligner_lm.logits(token, audio_features, None)
                token_pred = torch.argmax(F.softmax(logits[:, -1], dim=-1), 1)[None]
                token = torch.cat([token, token_pred], dim=1)
            if token_pred[0] == 799:
                break
        alignment_tokens = token

        # Strip the start token and the final token before splitting.
        ph_ref, tone_ref, dur_ref, _ = split_ph_timestamp(deepcopy(alignment_tokens)[0, 1:-1])
        ph_ref = torch.Tensor(ph_ref)[None].to(self.device)
        tone_ref = torch.Tensor(tone_ref)[None].to(self.device)

        # Make the durations sum to exactly prompt_max_frame.
        total = dur_ref.sum()
        if total < prompt_max_frame:
            dur_ref[-1] += prompt_max_frame - total
        elif total > prompt_max_frame:
            len_diff = total - prompt_max_frame
            # Remove one frame at a time, round-robin, but only from phones that
            # still have frames; the previous version also decremented
            # zero-length phones and produced negative durations. This always
            # terminates because the positive durations sum to more than len_diff.
            while len_diff > 0:
                for i in range(len(dur_ref)):
                    if len_diff == 0:
                        break
                    if dur_ref[i] > 0:
                        dur_ref[i] -= 1
                        len_diff -= 1

        mel2ph_ref = self.length_regulator(dur_ref[None]).to(self.device)
        mel2ph_ref = mel2ph_ref[:, :mel2ph_ref.size(1) // self.fm * self.fm]
    return ph_ref, tone_ref, mel2ph_ref
# Duration prompting.
def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref):
    """Prime the duration model with the prompt's phone durations.

    The per-phone durations from ``mel2ph_ref`` are clamped to
    ``dur_code_size - 1`` and shifted by +1 to form duration tokens. Tokens of
    phones with id 0 are dropped. The prompt is then run through
    ``self.dur_model.infer`` with ``return_state=True``.

    Returns:
        (incremental_state_dur_prompt, ctx_dur_tokens): the duration model's
        incremental state after the prompt, and the prompt tokens with shape
        (1, n_tokens).
    """
    max_code = self.hp_dur_model['dur_code_size'] - 1
    dur_codes = mel2token_to_dur(mel2ph_ref, ph_ref.shape[1]).clamp(max=max_code) + 1
    flat_codes = dur_codes.clone().flatten(0, 1).to(self.device)
    keep = ph_ref.flatten(0, 1) > 0
    ctx_dur_tokens = flat_codes[keep][None]

    n_ctx = ctx_dur_tokens.shape[1]
    spk_pos_ids = torch.arange(0, n_ctx, dtype=torch.long, device=self.device)[None]
    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
        _, incremental_state_dur_prompt = self.dur_model.infer(
            ph_ref, {'tone': tone_ref}, None, None, None,
            ctx_vqcodes=ctx_dur_tokens, spk_pos_ids_flat=spk_pos_ids, return_state=True)
    return incremental_state_dur_prompt, ctx_dur_tokens
# Duration prediction for one text segment, conditioned on the prompt.
def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first, is_final):
    """Predict per-phone durations for ``ph_pred`` and expand them to a frame map.

    Args:
        ctx_dur_tokens: prompt duration tokens from make_dur_prompt, shape (1, n_prompt).
        incremental_state_dur_prompt: duration-model state after the prompt;
            deep-copied here so the caller's copy is not mutated.
        ph_pred, tone_pred: phone / tone ids of the segment.
        seg_i: segment index (unused; kept for interface compatibility).
        dur_disturb: maximum relative random stretch/shrink of each duration
            (0 disables the perturbation).
        dur_alpha: global duration factor applied after the perturbation.
        is_first: first segment of the text (first phone is forced to 8 frames).
        is_final: last segment of the text.

    Returns:
        The result of ``self.length_regulator`` on the final durations, on
        ``self.device``. The durations sum to a multiple of ``self.fm``.
    """
    last_dur_token = ctx_dur_tokens[:, -1:]
    n_prompt = ctx_dur_tokens.shape[1]
    incremental_state_dur = deepcopy(incremental_state_dur_prompt)
    txt_len = ph_pred.shape[1]
    # Position ids continue directly after the prompt's positions.
    dur_spk_pos_ids_flat = torch.LongTensor([range(n_prompt, n_prompt + txt_len)]).to(self.device)
    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
        dur = self.dur_model.infer(
            ph_pred, {'tone': tone_pred}, None, None, None,
            incremental_state=incremental_state_dur,
            first_decoder_inp=last_dur_token,
            spk_pos_ids_flat=dur_spk_pos_ids_flat,
        )
    # Duration tokens are offset by +1 (see make_dur_prompt); undo that and
    # clamp to the valid code range.
    dur = (dur - 1).clamp(0, self.hp_dur_model['dur_code_size'] - 1)
    if not is_final:
        # Extend the last phone by 32 frames to leave room for the crossfade
        # between consecutive segments.
        dur[:, -1] = dur[:, -1] + 32
    else:
        dur[:, -1] = dur[:, -1].clamp(64, 128)

    # Random duration disturbance: each phone is either stretched or shrunk
    # (50/50) by a factor drawn from [1, 1 + dur_disturb).
    dur_disturb_choice = (torch.rand_like(dur.float()) > 0.5).float()
    dur_disturb_r = 1 + torch.rand_like(dur.float()) * dur_disturb
    dur = dur * dur_disturb_r * dur_disturb_choice + \
        dur / dur_disturb_r * (1 - dur_disturb_choice)
    dur = torch.round(dur * dur_alpha).clamp(0, 127)

    # Minimum pause lengths: 64 frames for ['。', '!', '?', 'sil'] ...
    for sil_token in [148, 153, 166, 145]:
        dur[ph_pred == sil_token] = dur[ph_pred == sil_token].clamp_min(64)
    # ... and 32 frames for [',', ';'].
    for sil_token in [163, 165]:
        dur[ph_pred == sil_token] = dur[ph_pred == sil_token].clamp_min(32)
    if is_first:
        dur[:, 0] = 8

    # Pad the last phone so the total is a multiple of self.fm.
    npad = self.fm - dur.sum() % self.fm
    if npad < self.fm:
        dur[:, -1] += npad
    mel2ph_pred = self.length_regulator(dur).to(self.device)
    return mel2ph_pred
def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent):
    """Assemble the 3-row classifier-free-guidance batch fed to the speech DiT.

    Rows 0 and 1 carry the real phones/tones and row 2 carries the CFG mask
    tokens. Only row 0 keeps the prompt latent; the other rows' latents are
    zeroed.

    Returns:
        dict with 'phone', 'tone', 'lat_ctx', 'ctx_mask' and 'dur'.
    """
    # Frame-to-phone map: prediction indices are shifted past the prompt's
    # phones, trimmed to a multiple of self.fm and replicated for the 3 rows.
    n_prompt_ph = ph_ref.size(1)
    dur = torch.cat((mel2ph_ref, mel2ph_pred + n_prompt_ph), dim=1)
    usable = dur.size(1) // self.fm * self.fm
    dur = dur[:, :usable].repeat(3, 1)

    phones = torch.cat((ph_ref, ph_pred), dim=1)
    tones = torch.cat((tone_ref, tone_pred), dim=1)
    # Disable English tones: any tone outside {0, 4, 11..15} becomes 3.
    keep = (tones == 0) | (tones == 4) | ((tones >= 11) & (tones <= 15))
    tones[~keep] = 3

    phone_mask_row = torch.full(phones.size(), self.cfg_mask_token_phone, device=self.device)
    tone_mask_row = torch.full(tones.size(), self.cfg_mask_token_tone, device=self.device)
    ph_seq = torch.cat([phones, phones, phone_mask_row], 0)
    tone_seq = torch.cat([tones, tones, tone_mask_row], 0)

    # Latent context: the prompt latent padded with zeros up to the target
    # length, with a mask marking the prompt frames.
    target_size = dur.size(1) // self.vae_stride
    extra = target_size - vae_latent.size(1)
    latent = vae_latent.repeat(3, 1, 1)
    ctx_mask = torch.ones_like(latent[:, :, 0:1])
    latent = F.pad(latent, (0, 0, 0, extra), mode='constant', value=0)
    latent[1:] = 0.0
    ctx_mask = F.pad(ctx_mask, (0, 0, 0, extra), mode='constant', value=0)

    return {
        'phone': ph_seq,
        'tone': tone_seq,
        "lat_ctx": latent * ctx_mask,
        "ctx_mask": ctx_mask,
        "dur": dur,
    }
================================================
FILE: tts/modules/aligner/whisper_small.py
================================================
# MIT License
# Copyright (c) 2022 OpenAI
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2022] [OpenAI]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text # available at https://github.com/openai/whisper/blob/v20240930/LICENSE.
# This modified file is released under the same license.
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.functional import scaled_dot_product_attention
# Always True here: scaled_dot_product_attention is imported unconditionally
# above, so a torch without it fails at import time instead. The flag is not
# referenced elsewhere in this module.
SDPA_AVAILABLE = True
class LayerNorm(nn.LayerNorm):
    """LayerNorm computed in float32 and cast back to the input's dtype."""

    def forward(self, x: Tensor) -> Tensor:
        out_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(out_dtype)
class Linear(nn.Linear):
    """nn.Linear whose weight and bias are cast to the input's dtype on every call."""

    def forward(self, x: Tensor) -> Tensor:
        dtype = x.dtype
        bias = self.bias.to(dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(dtype), bias)
class Conv1d(nn.Conv1d):
    """nn.Conv1d whose weight and bias are cast to the input's dtype before convolving."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        target = x.dtype
        if bias is not None:
            bias = bias.to(target)
        return super()._conv_forward(x, weight.to(target), bias)
def sinusoids(length, channels, max_timescale=10000):
    """Sinusoidal positional table of shape (length, channels).

    The first half of the columns holds sines and the second half cosines,
    over geometrically spaced timescales from 1 to ``max_timescale``.
    ``channels`` must be even.
    """
    assert channels % 2 == 0
    half = channels // 2
    step = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-step * torch.arange(half))
    positions = torch.arange(length).unsqueeze(1)
    scaled_time = positions * inv_timescales.unsqueeze(0)
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
@contextmanager
def disable_sdpa():
    """Set the class-level ``MultiHeadAttention.use_sdpa`` flag to False for the
    duration of the ``with`` block, restoring the previous value on exit
    (including when the block raises)."""
    saved = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = saved
class MultiHeadAttention(nn.Module):
    """Multi-head attention used for both self- and cross-attention.

    ``use_sdpa`` is a class-level switch. When True (the default), attention
    runs through ``torch.nn.functional.scaled_dot_product_attention``. When
    False, an explicit softmax(QK^T)V implementation with the same scaling is
    used instead; ``disable_sdpa`` toggles this flag.
    """
    use_sdpa = True

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        casual: Optional[bool] = None
    ):
        """Attend from ``x`` to ``x`` (self) or to ``xa`` (cross).

        Args:
            x: query input, (batch, n_ctx, n_state).
            xa: optional key/value source for cross-attention.
            mask: optional key mask, (batch, n_keys); True marks keys that may
                be attended to. A float mask is treated as additive.
            kv_cache: optional dict keyed by the key/value projection modules;
                for cross-attention, cached projections of ``xa`` are reused.
            casual: when truthy, apply a causal mask (the spelling is kept for
                caller compatibility).
        """
        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        wv = self.qkv_attention(q, k, v, mask, casual)
        return self.out(wv)

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, casual: Optional[bool] = None
    ) -> Tensor:
        """Split heads, attend, and merge heads back: returns (batch, n_ctx, n_state)."""
        n_batch, n_ctx, n_state = q.shape
        n_keys = k.shape[1]
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        # bool() so that casual=None is passed on as False rather than None.
        is_causal = bool(casual) and n_ctx > 1
        attn_mask = mask[:, None, None, :] if mask is not None else None
        if is_causal and attn_mask is not None:
            # SDPA raises when attn_mask and is_causal=True are both set, so
            # fold the (top-left aligned) causal constraint into the mask.
            causal = torch.ones(n_ctx, n_keys, dtype=torch.bool, device=q.device).tril()
            if attn_mask.dtype == torch.bool:
                attn_mask = attn_mask & causal
            else:
                causal_bias = torch.zeros(n_ctx, n_keys, dtype=attn_mask.dtype, device=q.device)
                attn_mask = attn_mask + causal_bias.masked_fill(~causal, float('-inf'))
            is_causal = False

        if MultiHeadAttention.use_sdpa:
            a = scaled_dot_product_attention(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
        else:
            a = self._manual_attention(q, k, v, scale, attn_mask, is_causal)
        out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
        return out

    @staticmethod
    def _manual_attention(q, k, v, scale, attn_mask, is_causal):
        """Explicit attention with the same semantics as the SDPA call above."""
        # Splitting d**-0.5 as d**-0.25 on both q and k gives the same scores as SDPA.
        scores = (q * scale) @ (k * scale).transpose(-1, -2)
        n_ctx, n_keys = scores.shape[-2:]
        if is_causal:
            causal = torch.ones(n_ctx, n_keys, dtype=torch.bool, device=q.device).tril()
            scores = scores.masked_fill(~causal, float('-inf'))
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                scores = scores.masked_fill(~attn_mask, float('-inf'))
            else:
                scores = scores + attn_mask
        weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
        return weights @ v
class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Self-attention, optional cross-attention and a 4x-wide GELU MLP are each
    added to the residual stream.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)
        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None
        hidden = 4 * n_state
        self.mlp = nn.Sequential(Linear(n_state, hidden), nn.GELU(), Linear(hidden, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        casual: Optional[bool] = None,
    ):
        """Apply the block. ``mask`` is used only by self-attention (cross-attention is unmasked)."""
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, casual=casual)
        if self.cross_attn is not None:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, casual=False)
        return x + self.mlp(self.mlp_ln(x))
class AudioEncoder(nn.Module):
    """Whisper-style audio encoder: two GELU convolutions, sinusoidal positions,
    then non-causal transformer blocks and a final LayerNorm.

    ``conv2`` has stride 2, so the output has about half as many frames as the
    input mel.
    """

    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor, attn_mask: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_frames)
            the mel spectrogram of the audio
        attn_mask : optional key mask passed to every block (None for no masking)
        """
        for conv in (self.conv1, self.conv2):
            x = F.gelu(conv(x))
        x = x.transpose(1, 2)  # (batch, frames, n_state)
        n_frames = x.size(1)
        x = (x + self.positional_embedding[:n_frames]).to(x.dtype)
        for block in self.blocks:
            x = block(x, mask=attn_mask, casual=False)
        return self.ln_post(x)
class TextDecoder(nn.Module):
    """Whisper-style text decoder.

    Token plus learned positional embeddings go through causal self-attention
    blocks with cross-attention to the audio features. A final LayerNorm and
    ``out_proj`` produce the logits.
    """

    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            ResidualAttentionBlock(n_state, n_head, cross_attention=True)
            for _ in range(n_layer)
        )
        self.ln = LayerNorm(n_state)
        self.out_proj = nn.Linear(n_state, n_vocab)

    def forward(self, x: Tensor, attn_mask: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        kv_cache : optional cache; when non-empty, positions start after the
            number of steps already held by its first entry
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        positions = self.positional_embedding[offset : offset + x.shape[-1]]
        h = (self.token_embedding(x) + positions).to(xa.dtype)
        for block in self.blocks:
            h = block(h, xa, mask=attn_mask, kv_cache=kv_cache, casual=True)
        # Logits come from a separate projection, not from the tied token embedding.
        return self.out_proj(self.ln(h))
class Whisper(nn.Module):
    """Whisper-style encoder-decoder used here as the speech aligner.

    All dimensions are constructor arguments. The defaults (80 mels, 3000 audio
    positions, 6800-token vocabulary, 2048 text positions, 512-wide, 8 heads,
    6 layers per side) build the same network ``Whisper()`` always built.
    Keeping them in one place also keeps ``n_text_ctx``, which
    ``install_kv_cache_hooks`` reads, in sync with the decoder.
    """

    def __init__(self, n_mels: int = 80, n_audio_ctx: int = 3000, n_audio_state: int = 512,
                 n_audio_head: int = 8, n_audio_layer: int = 6, n_vocab: int = 6800,
                 n_text_ctx: int = 2048, n_text_state: int = 512, n_text_head: int = 8,
                 n_text_layer: int = 6):
        super().__init__()
        self.n_vocab = n_vocab
        self.n_text_layer = n_text_layer
        self.n_text_head = n_text_head
        self.n_text_ctx = n_text_ctx
        self.encoder = AudioEncoder(
            n_mels=n_mels, n_ctx=n_audio_ctx, n_state=n_audio_state,
            n_head=n_audio_head, n_layer=n_audio_layer,
        )
        self.decoder = TextDecoder(
            n_vocab=n_vocab, n_ctx=n_text_ctx, n_state=n_text_state,
            n_head=n_text_head, n_layer=n_text_layer,
        )

    def embed_audio(self, mel: torch.Tensor):
        """Encode a (batch, n_mels, frames) mel spectrogram without a mask."""
        return self.encoder(mel, None)

    def logits(self, tokens, audio_features, kv_cache=None):
        """Decoder logits for ``tokens`` attending to ``audio_features`` (no padding mask)."""
        return self.decoder(tokens, None, audio_features, kv_cache=kv_cache)

    def forward(
        self, mel, mel_len, token, token_len
    ) -> Tensor:
        """Teacher-forced pass returning decoder logits.

        ``mel_len`` is halved for the encoder mask because the encoder's second
        convolution has stride 2.
        """
        attn_mask_enc = self.sequence_mask(mel_len // 2, device=mel.device) > 0
        attn_mask_dec = self.sequence_mask(token_len, device=mel.device) > 0
        return self.decoder(token, attn_mask_dec, self.encoder(mel, attn_mask_enc))

    @property
    def device(self):
        return next(self.parameters()).device

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    def sequence_mask(self, seq_lens, max_len=None, device='cpu'):
        """Float (batch, max_len) mask: 1.0 where position < seq_lens[b], else 0.0.

        ``max_len`` defaults to ``seq_lens.max()``.
        """
        if max_len is None:
            max_len = seq_lens.max()
        positions = torch.arange(max_len, device=device).unsqueeze(0)  # [1, t]
        mask = positions < seq_lens.unsqueeze(1)  # [1, t] vs [b, 1] -> [b, t]
        return mask.float()
================================================
FILE: tts/modules/ar_dur/ar_dur_predictor.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from copy import deepcopy
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear
from tqdm import tqdm
from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer
from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
def _build_rel_fft_encoder(hp, dict_size):
    """Build a RelTransformerEncoder for a vocabulary of ``dict_size`` from hparams ``hp``."""
    return RelTransformerEncoder(
        dict_size, hp['hidden_size'], hp['hidden_size'],
        hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
        hp['enc_ffn_kernel_size'], hp['dropout'],
        prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln'])


# Encoder factories keyed by hparams['encoder_type']; each takes (hparams, dict_size).
FS_ENCODERS = {
    'rel_fft': _build_rel_fft_encoder,
}
def fill_with_neg_inf2(t):
    """Fill with -1e8 (computed in float32) and return the result in ``t``'s dtype.

    In float16, -1e8 is out of range and becomes -inf. For a float32 input,
    ``t.float()`` is ``t`` itself, so ``t`` is filled in place and returned.
    """
    buf = t.float()
    buf.fill_(-1e8)
    return buf.type_as(t)
def expand_states(h, mel2token):
    """Gather token-level states ``h`` (B, T_tok, H) to frame level via ``mel2token`` (B, T).

    Index 0 in ``mel2token`` selects an all-zero state (a zero row is
    prepended), so valid token indices are 1-based. Returns (B, T, H).
    """
    padded = F.pad(h, [0, 0, 1, 0])
    hidden = padded.shape[-1]
    index = mel2token.unsqueeze(-1).repeat(1, 1, hidden)
    return torch.gather(padded, 1, index)
class CodePredictor(nn.Module):
def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size):
super().__init__()
self.hparams = deepcopy(hparams)
self.hparams['hidden_size'] = hidden_size
self.hidden_size = hidden_size
char_dict_size = hparams.get('char_dict_size', 4000)
if not hparams.get('lm_use_enc'):
self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0)
if hparams.get('mega_use_char', True):
self.char_encoder = nn.Embedding(char_dict_size,
self.hidden_size, padding_idx=0)
else:
self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size)
if hparams.get('mega_use_char', True):
self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size)
if hparams['use_ph_pos_embed']:
self.ph_pos_embed = PosEmb(self.hidden_size)
self.char_empty_embed = nn.Embedding(1, self.hidden_size)
if hparams.get('use_bert_input'):
self.bert_input_proj = nn.Linear(768, self.hidden_size)
self.ling_label_embed_layers = nn.ModuleDict()
for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']):
self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0)
self.dec_hidden_size = dec_hidden_size
self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size)
self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0)
self.use_pos_embed = hparams.get('use_pos_embed', False)
if self.use_pos_embed:
self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024)
self.use_post_ln = hparams.get('use_post_ln', False)
self.layers = None
if not self.use_post_ln:
self.layer_norm = LayerNorm(dec_hidden_size)
self.code_size = code_size
self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True)
def forward_ling_encoder(
self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre):
ph_tokens = txt_tokens
hparams = self.hparams
ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1]
x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre)
# enc_ph
if not hparams.get('lm_use_enc'):
x_ph = self.encoder(ph_tokens)
x_ph = x_ph + sum(
[self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
if len(hparams['ling_labels']) > 0 else 0
x_ph = x_ph + x_spk
else:
# enc_ph
ph_enc_oembed = sum(
[self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
if len(hparams['ling_labels']) > 0 else 0
ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
ph_enc_oembed = ph_enc_oembed + x_spk
ph_enc_oembed = ph_enc_oembed * ph_nonpadding
x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed)
# enc_char
if char_tokens is not None and ph2char is not None:
char_nonpadding = (char_tokens > 0).float()[:, :, None]
x_char = self.char_encoder(char_tokens)
empty_char = (ph2char > 100000).long()
ph2char = ph2char * (1 - empty_char)
x_char_phlevel = \
expand_states(x_char * char_nonpadding, ph2char) \
* (1 - empty_char)[..., None] + \
self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None]
else:
x_char_phlevel = 0
# x_ling
x_ling = x_ph + x_char_phlevel
x_ling = x_ling * ph_nonpadding
x_ling = self.enc_proj(x_ling)
return x_ling
def sample_one_step(self, vq_pred):
    """Choose the next code from the logits at the last decoded position.

    If ``hparams['infer_top_k']`` is truthy, the logits are divided by
    ``hparams['infer_temperature']`` (default 1), restricted to the top-k values
    and sampled with ``torch.multinomial``. Otherwise the argmax is taken.

    Args:
        vq_pred: logits read as ``vq_pred[:, -1]`` (presumably [B, T, V]). Not modified.

    Returns:
        LongTensor of shape [B] in both modes, so the result can be written straight
        into a ``decoded[:, step]`` column. The sampling branch used to return [B, 1],
        which fails that assignment for B > 1.
    """
    hparams = self.hparams
    top_k = hparams.get('infer_top_k')
    last_logits = vq_pred[:, -1]
    if not top_k:
        return torch.argmax(F.softmax(last_logits, dim=-1), 1)
    temperature = hparams.get('infer_temperature', 1)
    # Division creates a new tensor, so the masking below never touches the caller's logits.
    logits = last_logits / temperature
    kth_best, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    logits[logits < kth_best[:, [-1]]] = -float('Inf')
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
    """Sum every speaker/style conditioning signal enabled in ``self.hparams``.

    Each enabled source is projected by its module and given a singleton time
    axis (``[:, None, :]``) so it broadcasts over sequence positions.
    Returns the integer 0 when no source is enabled.
    """
    conditioning = (
        ('use_spk_embed', 'spk_embed_proj', spk_embed),
        ('use_spk_id', 'spk_id_proj', spk_id),
        ('use_spk_enc', 'spk_enc', mel_ref),
    )
    total = 0
    for flag_key, module_attr, source in conditioning:
        if not self.hparams[flag_key]:
            continue
        # The projection module is looked up only when its flag is on, since it may not exist otherwise.
        projected = getattr(self, module_attr)(source)
        total = total + projected[:, None, :]
    return total
def buffered_future_mask(self, tensor):
    """Return a cached causal mask of size [T, T], where T = ``tensor.size(0)``.

    Entries strictly above the diagonal come from ``fill_with_neg_inf2``; the
    diagonal and everything below are 0 (``torch.triu(..., 1)``). The cache is
    rebuilt when it is missing, lives on another device, or is smaller than T.
    Otherwise its leading [T, T] corner is returned.
    """
    length = tensor.size(0)
    cached = getattr(self, '_future_mask', None)
    stale = cached is None or cached.device != tensor.device or cached.size(0) < length
    if stale:
        scratch = tensor.new(length, length)
        self._future_mask = torch.triu(fill_with_neg_inf2(scratch), 1)
    return self._future_mask[:length, :length]
class ARDurPredictor(CodePredictor):
    """Autoregressive phoneme-duration predictor built on top of ``CodePredictor``.

    Each decoding step embeds the previously decoded codes with ``code_emb``, adds the
    phoneme-level linguistic features from ``forward_ling_encoder`` and runs a stack of
    ``RotTransformerDecoderLayer`` blocks. When ``hparams['dur_model_type'] == 'ar_mse'``
    the output head regresses one non-negative value per position (Linear + Softplus).
    Otherwise it produces ``code_size + 1`` logits, which are decoded by ``sample_one_step``.
    """
    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True,
                 op_version=1):
        super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size)
        self.use_rot_embed = use_rot_embed
        bias = hparams.get('lm_bias', True)
        if self.use_rot_embed:
            # Build the decoder stack from rotary-position decoder layers.
            # NOTE(review): with use_rot_embed=False, self.layers keeps whatever the parent
            # constructor assigned (the visible CodePredictor code sets it to None).
            self.layers = nn.ModuleList([])
            self.layers.extend([
                RotTransformerDecoderLayer(
                    dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4,
                    post_ln=self.use_post_ln, op_version=op_version, bias=bias)
                for _ in range(lm_num_layers)
            ])
        # Output head: a scalar regression for 'ar_mse', otherwise logits over code_size + 1 classes.
        if hparams['dur_model_type'] == 'ar_mse':
            self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus())
        else:
            self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1)

    def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None,
                incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None,
                prompt_length=None, cache_size=20, streaming=False):
        """Run the decoder over ``prev_code`` and return per-position predictions.

        Args:
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed,
                mels_timbre: forwarded to ``forward_ling_encoder`` when ``x_ling`` is None.
            prev_code: previously decoded codes, embedded with ``code_emb``.
            mel2ph: accepted but not used here.
            incremental_state: per-layer cache dict. When not None, only the last position
                of ``prev_code`` is processed (the layers read/extend the cache).
            x_ling: precomputed linguistic features, already reduced to non-padding tokens
                with a leading batch dim of 1 (as prepared by ``infer``).
            attn_mask: optional mask (1 = keep) merged into the causal mask; only used
                when ``incremental_state`` is None.
            spk_pos_ids_flat: position ids passed to every decoder layer.
            prompt_length, cache_size, streaming: in streaming mode, query positions are
                capped at ``prompt_length + cache_size``. Each cached key/value keeps its
                first ``prompt_length`` entries plus the most recent ``cache_size`` ones.

        Returns:
            ``project_out_dim`` output, batch-first.
        """
        x = self.code_emb(prev_code)
        if x_ling is None:
            x_ling = self.forward_ling_encoder(
                txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre)
            # Drop padding tokens and pack everything into a single row.
            x_ling = x_ling.flatten(0, 1)
            txt_tokens = txt_tokens.flatten(0, 1)
            x_ling = x_ling[txt_tokens > 0][None]
        # run decoder
        self_attn_padding_mask = None
        if self.use_pos_embed:
            positions = self.embed_positions(
                prev_code,
                incremental_state=incremental_state
            )
        if incremental_state is not None:
            # Cached decoding: keep only the newest position (x.shape[1] is still the full length here).
            x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]]
            if spk_pos_ids_flat is not None:
                spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]]
            x = x[:, -1:]
            if self.use_pos_embed:
                positions = positions[:, -1:]
            if streaming:
                # Shift Pos: query pos is min(prompt_length + cache_size, pos), matching the trimmed cache.
                spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device),
                                             spk_pos_ids_flat)
        # Add positional embeddings and linguistic features, then B x T x C -> T x B x C.
        if self.use_pos_embed:
            x = x + positions
        x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous()
        T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1])
        x_ling = x_ling.reshape(-1, T, x_ling.shape[-1])
        x = x + x_ling
        x = x.transpose(0, 1)
        for idx, layer in enumerate(self.layers):
            if incremental_state is None:
                # Full-sequence pass: causal mask, optionally combined with attn_mask (0 -> -1e8).
                self_attn_mask = self.buffered_future_mask(x)
                if attn_mask is not None:
                    self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8
                self_attn_mask = self_attn_mask.clamp_min(-1e8)
            else:
                self_attn_mask = None
            x, attn_weights = layer(
                x,
                incremental_state=incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                spk_pos_ids_flat=spk_pos_ids_flat
            )
            if streaming and incremental_state != {}:
                # Bound the cache: keep the prompt part plus the latest cache_size positions.
                for k, v in incremental_state.items():
                    if 'attn_state' in k:
                        prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value']
                        cur_length = prev_key.shape[2]
                        if cur_length - prompt_length > cache_size:
                            prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2)
                            prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]),
                                                   dim=2)
                        incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value
        if not self.use_post_ln:
            x = self.layer_norm(x)
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        x = self.project_out_dim(x)
        return x

    def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
              spk_id=None, spk_embed=None, mels_timbre=None,
              incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
              first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs):
        """Autoregressively decode one duration code per non-padding token.

        Padding tokens (``txt_tokens == 0``) are removed. Decoding starts from the BOS
        code ``code_size + 1``, and the results are scattered back into a tensor shaped
        like ``txt_tokens`` (0 at padding positions).

        Args:
            incremental_state: optional existing cache. When it is non-empty, either
                ``ctx_vqcodes`` (required if ``first_decoder_inp`` is None) is written into
                the decoder-input prefix, or ``first_decoder_inp`` replaces the BOS code.
            ctx_vqcodes: known codes. If not consumed as the prefix above, they are forced
                as the outputs of the first ``ctx_vqcodes.shape[1]`` steps.
            first_step_min: lower clamp for a predicted (non-forced) first step. Later
                predicted steps are clamped to >= 1.
            return_state: return ``(codes, incremental_state)``. Takes precedence over return_probs.
            return_probs: return ``(codes, per-step raw predictions concatenated on dim 1)``.
            first_decoder_inp: replacement for the BOS code when resuming from a cache.
            dur_disturb: for 'ar_mse', from the second step on, randomly multiply or divide
                each prediction by ``1 + random.random() * dur_disturb``.
            **kwargs: forwarded to ``forward``.

        Returns:
            Decoded codes shaped like ``txt_tokens`` (see return_state/return_probs).
        """
        if incremental_state is None:
            incremental_state = {}
        x_ling = self.forward_ling_encoder(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        x_ling = x_ling.flatten(0, 1)
        txt_tokens_ori = txt_tokens
        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
        x_ling = x_ling[txt_tokens > 0][None]
        txt_tokens = txt_tokens[txt_tokens > 0][None]
        # Decoder input: BOS (code_size + 1) followed by one slot per token.
        decoded = torch.zeros_like(txt_tokens)
        decoded = F.pad(decoded, [1, 0], value=self.code_size + 1)
        if incremental_state != {}:
            if first_decoder_inp is None:
                # NOTE(review): this overwrites position 0 too, so ctx_vqcodes presumably starts with BOS — confirm.
                assert ctx_vqcodes is not None
                decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
                ctx_vqcodes = None
            else:
                decoded[:, :1] = first_decoder_inp
        probs = []
        for step in range(decoded.shape[1] - 1):
            vq_pred = self(txt_tokens, None, None, None, None,
                           decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, **kwargs)
            probs.append(vq_pred.cpu())
            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
                if self.hparams['dur_model_type'] == 'ar_mse':
                    d = vq_pred[:, -1, 0]
                    if dur_disturb > 0 and step >= 1:
                        # Random duration perturbation (scale up or down with equal chance).
                        if random.random() > 0.5:
                            d = d * (1 + random.random() * dur_disturb)
                        else:
                            d = d / (1 + random.random() * dur_disturb)
                        d = torch.clamp_max(d, self.code_size - 1)
                    vq_pred = torch.round(d).long()
                else:
                    vq_pred = self.sample_one_step(vq_pred)
                decoded[:, step + 1] = torch.clamp_min(vq_pred, 1)
                if step == 0:
                    # The first step uses its own lower bound (overrides the >= 1 clamp above).
                    decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min)
            else:
                # Teacher-force the known context codes.
                decoded[:, step + 1] = ctx_vqcodes[:, step]
        decoded = decoded[:, 1:]
        # Scatter back to the padded layout; assumes flatten() returns a view (contiguous decoded_2d).
        decoded_2d = torch.zeros_like(txt_tokens_ori)
        decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded
        if return_state:
            return decoded_2d, incremental_state
        if return_probs:
            return decoded_2d, torch.cat(probs, 1)
        return decoded_2d

    def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                        spk_id=None, spk_embed=None, mels_timbre=None,
                        incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
                        **kwargs):
        """Like ``infer`` but with a bounded cache (see ``forward``'s streaming mode).

        The prompt length is read from the first cache entry's ``prev_key``, so this
        needs a non-empty ``incremental_state`` (an empty one raises IndexError below).
        Unlike ``infer``, predicted values are not clamped.

        Returns:
            Decoded codes shaped like ``txt_tokens``, plus ``incremental_state`` if ``return_state``.
        """
        if incremental_state is None:
            incremental_state = {}
        x_ling = self.forward_ling_encoder(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        x_ling = x_ling.flatten(0, 1)
        txt_tokens_ori = txt_tokens
        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
        x_ling = x_ling[txt_tokens > 0][None]
        txt_tokens = txt_tokens[txt_tokens > 0][None]
        vq_decoded = torch.zeros_like(txt_tokens)
        vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1)
        if incremental_state != {}:
            assert ctx_vqcodes is not None
            vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
            ctx_vqcodes = None
        # Cached length before decoding = the part that streaming never evicts.
        prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2]
        for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'):
            vq_pred = self(txt_tokens, None, None, None, None,
                           vq_decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs)
            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
                if self.hparams['dur_model_type'] == 'ar_mse':
                    vq_pred = torch.round(vq_pred[:, -1, 0]).long()
                else:
                    vq_pred = self.sample_one_step(vq_pred)
                vq_decoded[:, step + 1] = vq_pred
            else:
                vq_decoded[:, step + 1] = ctx_vqcodes[:, step]
        vq_decoded = vq_decoded[:, 1:]
        vq_decoded_2d = torch.zeros_like(txt_tokens_ori)
        vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded
        if return_state:
            return vq_decoded_2d, incremental_state
        return vq_decoded_2d
================================================
FILE: tts/modules/ar_dur/commons/layers.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
class LayerNorm(torch.nn.LayerNorm):
    """``torch.nn.LayerNorm`` that can also normalize a non-last axis.

    :param int nout: size of the normalized axis
    :param int dim: -1 normalizes the last axis; any other value swaps axis 1 with
        the last axis around the normalization (i.e. normalizes axis 1)
    :param float eps: numerical-stability term
    """

    def __init__(self, nout, dim=-1, eps=1e-5):
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Return ``x`` layer-normalized along the configured axis."""
        if self.dim != -1:
            channels_last = x.transpose(1, -1)
            return super().forward(channels_last).transpose(1, -1)
        return super().forward(x)
class Reshape(nn.Module):
    """Module wrapper around ``Tensor.view`` with a target shape fixed at construction."""

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        target = self.shape
        return x.view(*target)
class Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` with a dimension order fixed at construction."""

    def __init__(self, *args):
        super().__init__()
        self.args = args

    def forward(self, x):
        order = self.args
        return x.permute(*order)
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    """Create an ``nn.Embedding`` with N(0, embedding_dim ** -0.5) weights.

    When ``padding_idx`` is given, that row is zeroed after initialization.
    """
    table = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    std = embedding_dim ** -0.5
    nn.init.normal_(table.weight, mean=0, std=std)
    if padding_idx is None:
        return table
    nn.init.constant_(table.weight[padding_idx], 0)
    return table
================================================
FILE: tts/modules/ar_dur/commons/nar_tts_modules.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
import torch.nn.functional as F
class LengthRegulator(torch.nn.Module):
    """Expands per-token durations into a frame-level token index map (``mel2ph``)."""

    def __init__(self, pad_value=0.0):
        super(LengthRegulator, self).__init__()
        # Stored for API compatibility; forward() does not use it.
        self.pad_value = pad_value

    def forward(self, dur, dur_padding=None, alpha=1.0):
        """
        Example (no batch dim version):
            1. dur = [2,2,3]
            2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
            3. token_mask = [[1,1,0,0,0,0,0],
                             [0,0,1,1,0,0,0],
                             [0,0,0,0,1,1,1]]
            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
                                         [0,0,2,2,0,0,0],
                                         [0,0,0,0,3,3,3]]
            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]

        :param dur: Batch of durations of each token (B, T_txt)
        :param dur_padding: Batch of padding flags of each token (B, T_txt); entries equal
            to 1 (or True) get duration 0
        :param alpha: duration rescale coefficient, must be > 0
        :return:
            mel2ph (B, T_speech): 1-based token index of every frame. Frames past a
            sequence's total duration (when it is shorter than the batch maximum) are 0.
        :raises ValueError: if ``alpha`` is not > 0
        """
        # This precondition previously sat inside the docstring, where it never executed.
        if not alpha > 0:
            raise ValueError(f"alpha must be > 0, got {alpha}")
        dur = torch.round(dur.float() * alpha).long()
        if dur_padding is not None:
            dur = dur * (1 - dur_padding.long())
        token_idx = torch.arange(1, dur.shape[1] + 1, device=dur.device)[None, :, None]
        dur_cumsum = torch.cumsum(dur, 1)
        # Exclusive cumulative sum, i.e. the first frame of each token.
        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
        pos_idx = torch.arange(dur.sum(-1).max(), device=dur.device)[None, None]
        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        mel2token = (token_idx * token_mask.long()).sum(1)
        return mel2token
class PosEmb(nn.Module):
    """Sinusoidal embedding of (possibly real-valued) positions.

    Output layout: the first ``dim // 2`` channels are sines, the rest cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half = self.dim // 2
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half) * -log_step)
        self.emb = freqs  # TODO: plain attribute (not a buffer); moved to x's device in forward

    def forward(self, x):
        """Map x (indexed as ``x[:, :, None]``, presumably [B, T]) to [B, T, 2 * (dim // 2)]."""
        freqs = self.emb[None, None, :].to(x.device)
        angles = x[:, :, None] * freqs
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
================================================
FILE: tts/modules/ar_dur/commons/rel_transformer.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
from torch.nn import functional as F
from tts.modules.ar_dur.commons.layers import Embedding
def convert_pad_shape(pad_shape):
    """Flatten ``[[before, after], ...]`` (listed first dim first) into the flat,
    last-dim-first list expected by ``F.pad``."""
    return [amount for pair in reversed(pad_shape) for amount in pair]
def shift_1d(x):
    """Shift a 3-D tensor one step right along its last axis, filling the first slot with 0."""
    # Same flat pad list as convert_pad_shape([[0, 0], [0, 0], [1, 0]]).
    padded = F.pad(x, [1, 0, 0, 0, 0, 0])
    return padded[:, :, :-1]
def sequence_mask(length, max_length=None):
    """Boolean mask of shape [B, max_length] that is True where index < length[b].

    ``max_length`` defaults to ``length.max()``.
    """
    limit = length.max() if max_length is None else max_length
    steps = torch.arange(limit, dtype=length.dtype, device=length.device)
    return steps[None, :] < length[:, None]
class Encoder(nn.Module):
    """Stack of relative-position self-attention and conv feed-forward sub-blocks.

    Works on channel-first tensors with a [b, 1, t] ``x_mask``. With ``pre_ln`` each
    sub-block normalizes its input and a final LayerNorm follows the stack.
    Otherwise each sub-block normalizes its residual output (post-LN).
    Extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
                 window_size=None, block_length=None, pre_ln=False, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.pre_ln = pre_ln
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Modules are created layer by layer (attn, norm, ffn, norm) to keep init order stable.
        for _ in range(n_layers):
            self.attn_layers.append(MultiHeadAttention(
                hidden_channels, hidden_channels, n_heads, window_size=window_size,
                p_dropout=p_dropout, block_length=block_length))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(FFN(
                hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))
        if pre_ln:
            self.last_ln = LayerNorm(hidden_channels)

    def forward(self, x, x_mask, attn_mask=1):
        """Encode ``x``; ``attn_mask`` (tensor or 1) is combined with the pairwise padding mask."""
        if isinstance(attn_mask, torch.Tensor):
            attn_mask = attn_mask[:, None]
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
        stacks = zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2)
        for attn, attn_norm, ffn, ffn_norm in stacks:
            x = x * x_mask
            # Self-attention sub-block with residual connection.
            residual = x
            h = attn_norm(x) if self.pre_ln else x
            x = residual + self.drop(attn(h, h, attn_mask))
            if not self.pre_ln:
                x = attn_norm(x)
            # Feed-forward sub-block with residual connection.
            residual = x
            h = ffn_norm(x) if self.pre_ln else x
            x = residual + self.drop(ffn(h, x_mask))
            if not self.pre_ln:
                x = ffn_norm(x)
        if self.pre_ln:
            x = self.last_ln(x)
        return x * x_mask
class MultiHeadAttention(nn.Module):
    """Multi-head attention on channel-first tensors with optional relative-position terms.

    Inputs and outputs are laid out as [b, channels, t]; projections are 1x1 Conv1d.
    When ``window_size`` is set, learned embeddings for relative offsets in
    [-window_size, window_size] are added both to the attention logits (keys) and to
    the attention output (values). That path requires self-attention (t_s == t_t).
    When a mask is given, ``block_length`` additionally restricts attention to a band of
    +/- block_length positions. ``proximal_bias`` adds a -log(1 + |i - j|) bias.
    """
    def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
                 block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None  # attention probabilities of the last forward() call
        self.k_channels = channels // n_heads
        # 1x1 convolutions act as per-position linear projections.
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            # One table of 2 * window_size + 1 offsets, shared by all heads when heads_share is True.
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)
        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            # Start the key projection as an exact copy of the query projection.
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        """Attend from queries ``x`` to keys/values ``c``; returns [b, out_channels, t_x]."""
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)
        x, self.attn = self.attention(q, k, v, mask=attn_mask)
        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention per head; returns (output [b, d, t_t], probs [b, h, t_t, t_s])."""
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            # Logits against relative-key embeddings come out in [b, h, l, 2l-1] (relative)
            # layout and are converted to [b, h, l, l] before being added.
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            # Masked positions get a large negative (not -inf) logit.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                # Keep only the band |i - j| <= block_length.
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores * block_mask + -1e4 * (1 - block_mask)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            # Re-express the attention weights by relative offset and mix in relative-value embeddings.
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Return the 2 * length - 1 embeddings for offsets -(length-1)..(length-1).

        Offsets outside the learned window are zero rows, added by padding.
        """
        max_relative_position = 2 * self.window_size + 1  # unused
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """Convert relative-offset logits to absolute (query, key) logits via pad/reshape skewing.

        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """Inverse of the skewing above: absolute (query, key) weights to relative-offset layout.

        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, -1])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length] holding -log(1 + |i - j|)
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
    """Masked two-layer convolutional feed-forward block on [b, c, t] tensors.

    ``conv_1`` uses ``kernel_size`` with same-length padding for odd kernels, and ``conv_2``
    is pointwise. ``activation="gelu"`` selects the x * sigmoid(1.702 x) approximation;
    any other value uses ReLU. Inputs are masked before each conv and on output.
    """

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        hidden = self.conv_1(x * x_mask)
        use_gelu = self.activation == "gelu"
        hidden = hidden * torch.sigmoid(1.702 * hidden) if use_gelu else torch.relu(hidden)
        hidden = self.drop(hidden)
        return self.conv_2(hidden * x_mask) * x_mask
class LayerNorm(nn.Module):
    """Layer normalization over axis 1 (channels) of a channel-first tensor.

    Uses the biased variance and a learnable per-channel scale ``gamma`` and shift ``beta``.
    """

    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        mu = x.mean(dim=1, keepdim=True)
        var = ((x - mu) ** 2).mean(dim=1, keepdim=True)
        normed = (x - mu) * torch.rsqrt(var + self.eps)
        # Broadcast the per-channel affine parameters over every non-channel axis.
        bshape = (1, -1) + (1,) * (x.dim() - 2)
        return normed * self.gamma.view(bshape) + self.beta.view(bshape)
class ConvReluNorm(nn.Module):
    """Residual stack of (Conv1d -> LayerNorm -> ReLU -> Dropout) blocks on [b, c, t] inputs.

    The output projection is zero-initialized, so a freshly built module returns
    ``x * x_mask``. The residual add ``x_org + proj(x)`` presumably needs
    ``in_channels == out_channels`` — confirm for new call sites.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # The message now matches the actual condition (previously it said "larger than 0").
        assert n_layers > 1, "Number of layers should be larger than 1."
        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Zero init makes the block start as an identity (masked) mapping.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Return ``(x + proj(stack(x))) * x_mask``."""
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
class RelTransformerEncoder(nn.Module):
    """Embedding (or raw features) -> optional ConvReluNorm prenet -> relative-position Encoder.

    With ``n_vocab > 0`` the input is token ids (0 = padding). Otherwise it is a feature
    tensor, and all-zero frames count as padding. Output is batch-first [b, t, h].
    """

    def __init__(self,
                 n_vocab,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout=0.0,
                 window_size=4,
                 block_length=None,
                 in_channels=None,
                 prenet=True,
                 pre_ln=True,
                 ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.prenet = prenet
        if n_vocab > 0:
            self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
        if prenet:
            # The prenet defaults to the hidden width when no input width is given.
            in_channels = hidden_channels if in_channels is None else in_channels
            self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
                                    kernel_size=5, n_layers=3, p_dropout=0)
        needs_input_proj = in_channels is not None and in_channels != hidden_channels
        if needs_input_proj:
            self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
        self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout,
                               window_size=window_size, block_length=block_length, pre_ln=pre_ln)

    def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
        """Encode ``x``. The ``x_mask`` argument is ignored; the mask is rebuilt from the input."""
        if self.n_vocab > 0:
            x_lengths = (x > 0).long().sum(-1)
            x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        else:
            x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
        x = torch.transpose(x + other_embeds, 1, -1)  # [b, h, t]
        x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
        if self.prenet:
            x = self.pre(x, x_mask)
            self.prenet_out = x.transpose(1, 2)
        if hasattr(self, 'encoder_inp_proj'):
            x = self.encoder_inp_proj(x) * x_mask
        encoded = self.encoder(x, x_mask, attn_mask)
        return encoded.transpose(1, 2)
================================================
FILE: tts/modules/ar_dur/commons/rot_transformer.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from typing import Optional, Tuple
from torch import nn
from torch.nn import Parameter, Linear
from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
from tts.modules.ar_dur.commons.transformer import TransformerFFNLayer, MultiheadAttention
from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
import torch.nn.functional as F
# Module-level default limits for source/target positions.
DEFAULT_MAX_SOURCE_POSITIONS, DEFAULT_MAX_TARGET_POSITIONS = 3000, 3000
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional embeddings for inputs of any length.

    The row at ``padding_idx`` is all zeros. The lookup table is rebuilt with more
    rows whenever an input needs more than it holds.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(init_size, embedding_dim, padding_idx)
        # Tiny buffer whose dtype/device the table is moved to on every forward.
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build a [num_embeddings, embedding_dim] table: all sines, then all cosines.

        This follows tensor2tensor and differs slightly from Section 3.5 of
        "Attention Is All You Need". Odd dims get an extra zero column.
        """
        half = embedding_dim // 2
        log_step = math.log(10000) / (half - 1)
        inv_freq = torch.exp(torch.arange(half, dtype=torch.float) * -log_step)
        angles = torch.arange(num_embeddings, dtype=torch.float)[:, None] * inv_freq[None, :]
        table = torch.cat([angles.sin(), angles.cos()], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2:
            table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            table[padding_idx, :] = 0
        return table

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Look up embeddings for ``input`` indexed as [bsz, seqlen, ...].

        In incremental mode a single row is returned for every batch item:
        ``padding_idx + (timestep + 1)`` if ``timestep`` is given, else ``padding_idx + seqlen``.
        Otherwise ``positions`` (or ``make_positions(input, padding_idx)``) select the rows.
        """
        bsz, seq_len = input.shape[:2]
        needed_rows = self.padding_idx + 1 + seq_len
        if self.weights is None or needed_rows > self.weights.size(0):
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                needed_rows, self.embedding_dim, self.padding_idx)
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            step = seq_len if timestep is None else timestep.view(-1)[0] + 1
            return self.weights[self.padding_idx + step, :].expand(bsz, 1, -1)

        if positions is None:
            positions = make_positions(input, self.padding_idx)
        rows = self.weights.index_select(0, positions.view(-1))
        return rows.view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
class RotaryEmbeddings(nn.Module):
cos: torch.Tensor
sin: torch.Tensor
theta: torch.Tensor
def __init__(
self,
width: int,
*,
seq_len: int = 40000,
base: int = 10000,
device: Optional[torch.device] = None,
):
"""Rotary embeddings (Su et al., 2021) layer. The rotary embedding
will be precomputed for up to 'seq _len' positions. The embedding
will be recomputed when a longer sequence is found in the input.
:param width:
Rotary embedding dimensionality, must be even.
:param seq_len:
Number of positons to initially precompute.
:param base:
The base used for Θ_i, determines the cycle length of the
embeddings.
:param device: Device on which the module is to be initialized.
"""
super().__init__()
if width % 2:
raise ValueError(f"Width of rotary embeddings must be even, was: {width}")
# Ignore allocations on the meta device as we don't persist our buffer,
# i.e., we don't expect the backing tensor to be replaced with pretrained weights.
if device is not None and device.type == "meta":
device = None
# Θ_i = 10000^(-2(i-1)/d)
theta = torch.pow(
base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
)
self.register_buffer("theta", theta, persistent=False)
self._create_rotary_embed(width=width, length=seq_len)
def _create_rotary_embed(self, *, width: int, length: int):
# mΘ
position = torch.arange(length, device=self.theta.device).unsqueeze(1)
m_theta = position * self.theta.unsqueeze(0)
# We apply both sin and cos twice (see Eq 15, 34), but the ordering
# is changed for compatibility with most common implementations.
m_theta = torch.cat([m_theta, m_theta], dim=-1)
re_cos = m_theta.cos().view([length, width])
re_sin = m_theta.sin().view([length, width])
self.register_buffer("cos", re_cos, persistent=False)
self.register_buffer("sin", re_sin, persistent=False)
def _rotate(self, input: torch.Tensor):
"""Rotate the input tensor by half of its innermost width.
input (Tensor): array to rotate.
RETURNS (Tensor): rotated array.
Shapes:
input - (..., width)
output - (..., width)
"""
half_idx = input.shape[-1] // 2
input_1 = -input[..., half_idx:]
input_2 = input[..., :half_idx]
return torch.cat([input_1, input_2], dim=-1)
def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
    """
    Apply rotary embeddings to an array.

    The cached cos/sin tables are grown (via ``_create_rotary_embed``) when the
    requested positions exceed what is cached.

    :param input: Array to apply the rotary embeddings to.
    :param positions: positions of the inputs. If None, positions
        ``0 .. seq_len - 1`` are used without any gather.
    :return: ``cos * input + sin * rotate(input)``.
    Shapes:
        input - (batch_size, num_heads, seq_len, width_per_head)
        positions - (batch_size, seq_len)
        output - (batch_size, num_heads, seq_len, width_per_head)
    """
    batch_size, _, seq_len, width = input.shape
    if positions is not None:
        needed = int(positions.max()) + 1
        if self.cos.size(-2) < needed:
            self._create_rotary_embed(width=width, length=needed)
        # gather one table row per (batch, position), then restore (B, 1, T, W)
        flat_idx = positions.view(-1)
        target_shape = (batch_size, 1, seq_len, width)
        rot_cos = self.cos[flat_idx].view(target_shape)
        rot_sin = self.sin[flat_idx].view(target_shape)
    else:
        if self.cos.size(-2) < seq_len:
            self._create_rotary_embed(width=width, length=seq_len)
        target_shape = (1, 1, seq_len, width)
        rot_cos = self.cos[:seq_len].view(target_shape)
        rot_sin = self.sin[:seq_len].view(target_shape)
    return rot_cos * input + rot_sin * self._rotate(input)
class RotMultiheadAttention(MultiheadAttention):
    """Multi-head attention with rotary position embeddings applied to queries and keys.

    Scores are computed explicitly (bmm + softmax), so attention weights and the
    pre-softmax logits can be returned. Keys/values cached in ``incremental_state``
    are stored *before* the rotary embedding and are re-rotated on every call.
    """
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
                         add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
                         encoder_decoder_attention=encoder_decoder_attention)
        # Rotary table sized to the per-head width.
        self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)

    def forward(
            self,
            query, key, value,
            spk_pos_ids_flat=None,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            spk_pos_ids_flat (LongTensor, optional): rotary positions for the
                queries, and for the keys when ``incremental_state`` is None.
                None means positions ``0 .. T-1``. Because the rotary module sees
                a leading batch axis of 1, it must hold ``tgt_len`` entries
                (e.g. shape ``(1, tgt_len)``).
                NOTE(review): during incremental decoding, keys use positions
                ``0 .. src_len-1``, so callers presumably must pass the absolute
                query position here. Confirm against the caller.
            key_padding_mask (bool tensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by True/1. Masked scores are
                filled with -1e8.
            incremental_state (dict, optional): per-module cache. Projected
                (pre-rotary) keys/values are appended and stored under
                'prev_key'/'prev_value'.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            static_kv (bool, optional): reuse cached keys/values instead of
                appending (encoder-decoder attention only).
            attn_mask (Tensor, optional): additive mask added to the scores.
                A 2-D ``(tgt_len, src_len)`` mask is broadcast over batch*heads;
                a 3-D ``(bsz, tgt_len, src_len)`` mask is repeated per head.
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
            enc_dec_attn_constraint_mask (Tensor, optional): ``(bsz, heads, src_len)``
                mask; nonzero entries are filled with -1e8.
            reset_attn_weight (bool, optional): True stores this call's attention
                probabilities in ``self.last_attn_probs``; False reuses the stored
                ones instead of this call's; None does neither.

        Returns:
            ``(attn, (attn_weights, attn_logits))`` where ``attn`` is
            ``(tgt_len, bsz, embed_dim)``. ``attn_weights`` is ``(bsz, tgt_len, src_len)``
            (head-averaged), ``(num_heads, bsz, tgt_len, src_len)`` with
            *need_head_weights*, or None without *need_weights*. ``attn_logits`` is
            the masked pre-softmax scores ``(bsz, num_heads, tgt_len, src_len)``.
            With *before_softmax*, returns ``(attn_weights, v)`` instead.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        # Scale before rotation; the rotary map is linear, so the order does not matter.
        q = q * self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            # Append one learned key/value step and widen the masks to match.
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        # (T, B, C) -> (B * heads, T, head_dim)
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        # Apply rot embedding and store incremental_state
        # q[None] adds a batch axis of 1 so RotaryEmbeddings sees (1, B*heads, T, head_dim).
        q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            # Cache keys *before* rotation; they are re-rotated below on every step.
            saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
                bsz, self.num_heads, -1, self.head_dim)
            self._set_input_buffer(incremental_state, saved_state)
        if incremental_state is not None:
            # Cached keys span positions 0 .. src_len-1.
            key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
        else:
            key_pos = spk_pos_ids_flat
        k = self.rotary_embeds(k[None, :], positions=key_pos)[0]

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            # Append an all-zero key/value step that every query can attend to.
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            # Additive mask: 2-D is broadcast, 3-D (bsz, T, S) is repeated per head.
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask

        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # Masked, pre-softmax scores returned to the caller.
        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        # softmax (seq_utils) computes in float32.
        attn_weights_float = softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        if reset_attn_weight is not None:
            if reset_attn_weight:
                # Remember these probabilities for a later call with reset_attn_weight=False.
                self.last_attn_probs = attn_probs.detach()
            else:
                assert self.last_attn_probs is not None
                attn_probs = self.last_attn_probs
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        # (B * heads, T, head_dim) -> (T, B, C)
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, (attn_weights, attn_logits)
class RotMultiheadAttention2(MultiheadAttention):
    """Rotary multi-head attention backed by ``scaled_dot_product_attention``.

    Projections, key/value caching and query rotation follow
    ``RotMultiheadAttention``. Keys, however, are always rotated with positions
    ``0 .. src_len-1``. The score/softmax/value product is delegated to SDPA, so
    no attention weights are materialised. As a result ``need_weights``,
    ``need_head_weights``, ``before_softmax``, ``enc_dec_attn_constraint_mask``,
    ``reset_attn_weight`` and ``add_zero_attn`` have no effect, and the returned
    ``(attn_weights, attn_logits)`` pair is always ``(None, None)``.
    """
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
                         add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
                         encoder_decoder_attention=encoder_decoder_attention)
        self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)

    def _merge_key_padding_mask(self, attn_mask, key_padding_mask, bsz, src_len, dtype):
        """Fold a ``(bsz, src_len)`` key padding mask (True/nonzero = pad) into the SDPA mask.

        Returns a mask broadcastable to ``(bsz * num_heads, tgt_len, src_len)``:
        a boolean "may attend" mask if ``attn_mask`` is boolean, otherwise an
        additive float mask.
        """
        # (bsz, src_len) -> (bsz * num_heads, 1, src_len), matching the b * num_heads + h
        # layout produced by the .view(T, bsz * num_heads, head_dim) reshapes.
        pad = key_padding_mask.bool().view(bsz, 1, 1, src_len).expand(
            bsz, self.num_heads, 1, src_len).reshape(bsz * self.num_heads, 1, src_len)
        if attn_mask is not None and attn_mask.dtype == torch.bool:
            # SDPA boolean masks mark positions that MAY be attended.
            return attn_mask & ~pad
        if attn_mask is not None and attn_mask.is_floating_point():
            dtype = attn_mask.dtype
        # -1e8 matches RotMultiheadAttention; clamp so half precision does not overflow.
        neg = max(-1e8, torch.finfo(dtype).min)
        pad_bias = torch.zeros(pad.shape, dtype=dtype, device=pad.device).masked_fill(pad, neg)
        return pad_bias if attn_mask is None else attn_mask + pad_bias

    def forward(
            self,
            query, key, value,
            spk_pos_ids_flat=None,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            spk_pos_ids_flat (LongTensor, optional): rotary positions for the
                queries (None means ``0 .. tgt_len-1``).
            key_padding_mask (Tensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by True/1.
            incremental_state (dict, optional): per-module cache of projected,
                pre-rotary keys/values ('prev_key'/'prev_value').
            need_weights (bool, optional): accepted for interface compatibility;
                no weights are returned (default: True).
            attn_mask (Tensor, optional): float masks are additive, boolean
                masks follow SDPA semantics (True = may attend). A 2-D mask is
                broadcast; a 3-D ``(bsz, tgt_len, src_len)`` mask is repeated per head.
            before_softmax, need_head_weights, enc_dec_attn_constraint_mask,
            reset_attn_weight: accepted for interface compatibility; ignored.

        Returns:
            ``(attn, (None, None))`` with ``attn`` of shape ``(tgt_len, bsz, embed_dim)``.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        # Apply rot embedding and store incremental_state
        q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
                bsz, self.num_heads, -1, self.head_dim)
            self._set_input_buffer(incremental_state, saved_state)
        key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
        k = self.rotary_embeds(k[None, :], positions=key_pos)[0]

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # Previously validated but never applied; SDPA only sees attn_mask.
            attn_mask = self._merge_key_padding_mask(attn_mask, key_padding_mask, bsz, src_len, q.dtype)

        # SDPA applies the default 1/sqrt(head_dim) scaling (== self.scaling).
        attn = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=False)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

        attn_logits = None
        attn_weights = None
        return attn, (attn_weights, attn_logits)
class RotDecSALayer(nn.Module):
    """Decoder layer with rotary self-attention followed by a left-padded conv FFN.

    Pre-LayerNorm by default; with ``post_ln=True`` each LayerNorm is applied
    after its residual addition instead. There is no encoder-decoder attention.
    """
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
                 kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False, bias=True):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        # Attention projections are always bias-free; `bias` only affects the FFN.
        self.self_attn = RotMultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )
        self.layer_norm2 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size,
            dropout=relu_dropout, act=act, bias=bias)
        self.post_ln = post_ln

    def forward(
            self,
            x,
            encoder_out=None,
            encoder_padding_mask=None,
            incremental_state=None,
            self_attn_mask=None,
            self_attn_padding_mask=None,
            attn_out=None,
            reset_attn_weight=None,
            spk_pos_ids_flat=None,
            **kwargs,
    ):
        """Run self-attention and FFN sub-blocks with residual connections.

        ``encoder_out``, ``encoder_padding_mask``, ``attn_out`` and
        ``reset_attn_weight`` are accepted but unused. ``kwargs['layer_norm_training']``,
        if given, overrides the ``training`` flag of both LayerNorms.

        :return: ``(x, attn_weights)`` where ``attn_weights`` comes from ``self.self_attn``.
        """
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        residual = x
        if not self.post_ln:
            x = self.layer_norm1(x)

        x, (attn_weights, _) = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask,
            spk_pos_ids_flat=spk_pos_ids_flat
        )
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.layer_norm1(x)

        residual = x
        if not self.post_ln:
            x = self.layer_norm2(x)
        x = self.ffn(x, incremental_state=incremental_state)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.layer_norm2(x)
        return x, attn_weights

    def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
        """Drop this layer's cached attention and FFN state from ``incremental_state``.

        ``input``, ``encoder_out`` and ``encoder_padding_mask`` are unused and kept
        for signature compatibility.
        """
        # This layer has no `encoder_attn`; its only attention module is `self_attn`.
        self.self_attn.clear_buffer(incremental_state)
        self.ffn.clear_buffer(incremental_state)

    def set_buffer(self, name, tensor, incremental_state):
        """Store ``tensor`` under ``name`` in this layer's slot of ``incremental_state``."""
        return set_incremental_state(self, incremental_state, name, tensor)
class RotDecSALayer2(RotDecSALayer):
    """``RotDecSALayer`` whose self-attention is ``RotMultiheadAttention2``.

    The parent constructor builds the whole layer (norms, FFN and a
    ``RotMultiheadAttention``), and the attention module is then replaced. No
    ``bias`` argument is forwarded, so the FFN keeps the parent's default.
    """
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9,
                 ffn_hidden_size=1024, act='gelu', post_ln=False):
        super().__init__(
            c,
            num_heads,
            dropout,
            attention_dropout=attention_dropout,
            relu_dropout=relu_dropout,
            kernel_size=kernel_size,
            ffn_hidden_size=ffn_hidden_size,
            act=act,
            post_ln=post_ln,
        )
        # Swap in the SDPA-based attention (still bias-free, as in the parent).
        self.self_attn = RotMultiheadAttention2(
            c, num_heads, dropout=attention_dropout, bias=False, self_attention=True,
        )
class RotTransformerDecoderLayer(nn.Module):
    """Thin wrapper selecting the rotary decoder-layer implementation.

    ``op_version == 1`` uses ``RotDecSALayer`` (explicit attention, honours
    ``bias`` for the FFN). Any other value uses ``RotDecSALayer2`` (SDPA
    attention, FFN bias left at its default). In both cases attention dropout
    is 0 and the FFN dropout equals ``dropout``.
    """
    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
                 op_version=1, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        common = dict(
            dropout=dropout,
            attention_dropout=0.0,
            relu_dropout=dropout,
            kernel_size=kernel_size,
            ffn_hidden_size=ffn_hidden_size,
            post_ln=post_ln,
        )
        if op_version == 1:
            self.op = RotDecSALayer(hidden_size, num_heads, bias=bias, **common)
        else:
            self.op = RotDecSALayer2(hidden_size, num_heads, **common)

    def forward(self, x, **kwargs):
        """Delegate to the wrapped layer."""
        return self.op(x, **kwargs)

    def clear_buffer(self, *args):
        """Delegate cache clearing to the wrapped layer."""
        return self.op.clear_buffer(*args)

    def set_buffer(self, *args):
        """Delegate cache writes to the wrapped layer."""
        return self.op.set_buffer(*args)
================================================
FILE: tts/modules/ar_dur/commons/seq_utils.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import torch
import torch.nn.functional as F
def make_positions(tensor, padding_idx):
    """Number the non-padding entries of each row, starting at ``padding_idx + 1``.

    Entries equal to ``padding_idx`` are mapped to ``padding_idx`` itself.
    Counting runs along dim 1. The result is a LongTensor.
    """
    # Casts are kept deliberately simple (int mask, type_as after cumsum, final
    # .long()) so that both ONNX export and XLA handle them.
    keep = tensor.ne(padding_idx).int()
    running = torch.cumsum(keep, dim=1).type_as(keep)
    positions = (running * keep).long()
    return positions + padding_idx
def softmax(x, dim):
    """Softmax over ``dim``, always computed and returned in float32."""
    return torch.softmax(x, dim, dtype=torch.float32)
def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    """Build a ``(len(lengths), maxlen)`` mask that is True/1 for positions ``< lengths[i]``.

    :param lengths: 1-D tensor of sequence lengths.
    :param maxlen: mask width; defaults to ``lengths.max()``.
    :param dtype: dtype of the returned mask (default ``torch.bool``).
    :return: mask on ``lengths.device`` with the requested dtype.
    """
    if maxlen is None:
        maxlen = lengths.max()
    # 1-based step indices; step s is valid for row i iff s <= lengths[i].
    steps = torch.arange(1, int(maxlen) + 1, device=lengths.device)
    mask = steps.unsqueeze(0) <= lengths.unsqueeze(1)
    # Tensor.type is not in-place; the original discarded this result, ignoring `dtype`.
    return mask.type(dtype)
def weights_nonzero_speech(target):
    """Per-element weights: 1.0 for frames with any nonzero feature, else 0.0.

    :param target: ``(B, T, mel)`` tensor; all-zero frames are treated as padding.
    :return: float tensor of the same shape as ``target``.
    """
    n_feats = target.size(-1)
    frame_energy = target.abs().sum(-1, keepdim=True)
    is_real = (frame_energy != 0).float()
    return is_real.repeat(1, 1, n_feats)
# Per-class counters used to hand out unique instance ids (see below).
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(int)


def _get_full_incremental_state_key(module_instance, key):
    """Return ``'<ClassName>.<instance_id>.<key>'`` for ``module_instance``.

    The first call for an object assigns it a per-class id (stored as
    ``_instance_id``), so different instances never share incremental state.
    """
    cls_name = module_instance.__class__.__name__
    if not hasattr(module_instance, '_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[cls_name] += 1
        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[cls_name]
    return f'{cls_name}.{module_instance._instance_id}.{key}'
def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module.

    Returns the value stored for ``(module, key)``, or None when there is no
    state dict or no entry. The module's key is always resolved, which assigns
    it an instance id on first use.
    """
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is not None and full_key in incremental_state:
        return incremental_state[full_key]
    return None
def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module (no-op when the state dict is None)."""
    if incremental_state is None:
        return
    incremental_state[_get_full_incremental_state_key(module, key)] = value
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf.

    The fill happens in float32 and the result is cast back to ``t``'s type.
    For float32 inputs, ``t.float()`` is ``t`` itself, so ``t`` is filled in place.
    """
    scratch = t.float()
    scratch.fill_(float('-inf'))
    return scratch.type_as(t)
def fill_with_neg_inf2(t):
    """Fill a tensor with the large negative constant -1e8 (a finite stand-in for -inf).

    The fill happens in float32 and the result is cast back to ``t``'s type. For
    float32 inputs this fills ``t`` in place. In float16 the constant overflows to -inf.
    """
    scratch = t.float()
    scratch.fill_(-1e8)
    return scratch.type_as(t)
def select_attn(attn_logits, type='best'):
    """Reduce per-layer, per-head attention logits to a single attention map.

    :param attn_logits: sequence of ``n_layers`` tensors, each
        ``[B, n_head, T_sp, T_txt]``; they are stacked on a new leading axis.
    :param type: ``'best'`` picks, independently for each batch item, the single
        (layer, head) map whose rows are most peaked (largest sum over ``T_sp`` of
        the row-wise max probability). ``'mean'`` averages all (layer, head) maps.
    :return: attention probabilities ``[B, T_sp, T_txt]`` (softmax over ``T_txt``).
    :raises ValueError: if ``type`` is neither ``'best'`` nor ``'mean'``. Previously
        such values silently returned None.
    """
    if type not in ('best', 'mean'):
        raise ValueError(f"select_attn: unknown type {type!r}; expected 'best' or 'mean'")
    encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
    # [n_layers * n_head, B, T_sp, T_txt]
    encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
    if type == 'best':
        # Score each (layer, head) per batch item, then pick its map via gather.
        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
        encdec_attn = encdec_attn.gather(
            0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
        return encdec_attn
    return encdec_attn.mean(0)
def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make a boolean mask that is True on padded positions.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): Reference tensor. If given, the mask is expanded
            to ``xs``'s shape and moved to ``xs.device``. Otherwise the mask is
            ``(B, max(lengths))`` on the CPU.
        length_dim (int, optional): Axis of ``xs`` that holds the time steps
            (must not be 0).

    Returns:
        BoolTensor: True where the position index is ``>= lengths[b]``.

    Raises:
        ValueError: if ``length_dim`` is 0.

    Examples:
        >>> make_pad_mask([5, 3, 2])
        [[0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1],
         [0, 0, 1, 1, 1]]
        >>> make_pad_mask([5, 3, 2], torch.zeros((3, 2, 4)))[1]
        [[0, 0, 0, 1],
         [0, 0, 0, 1]]
        >>> make_pad_mask([5, 3, 2], torch.zeros((3, 6, 6)), 1)[2]
        rows 0-1 are all 0, rows 2-5 are all 1
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
    length_list = lengths if isinstance(lengths, list) else lengths.tolist()
    batch = int(len(length_list))
    width = int(max(length_list)) if xs is None else xs.size(length_dim)

    steps = torch.arange(0, width, dtype=torch.int64).unsqueeze(0).expand(batch, width)
    limits = steps.new(length_list).unsqueeze(-1)
    mask = steps >= limits
    if xs is None:
        return mask

    assert xs.size(0) == batch, (xs.size(0), batch)
    axis = length_dim + xs.dim() if length_dim < 0 else length_dim
    # index = (:, None, ..., None, :, None, ..., None): keep batch and length axes only
    index = tuple(slice(None) if d in (0, axis) else None for d in range(xs.dim()))
    return mask[index].expand_as(xs).to(xs.device)
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make a boolean mask that is True on non-padded (valid) positions.

    This is the logical negation of ``make_pad_mask`` with the same arguments.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): Reference tensor. If given, the mask is expanded
            to ``xs``'s shape.
        length_dim (int, optional): Axis of ``xs`` that holds the time steps
            (must not be 0).

    Returns:
        BoolTensor: True where the position index is ``< lengths[b]``.

    Examples:
        >>> make_non_pad_mask([5, 3, 2])
        [[1, 1, 1, 1, 1],
         [1, 1, 1, 0, 0],
         [1, 1, 0, 0, 0]]
        >>> make_non_pad_mask([5, 3, 2], torch.zeros((3, 2, 6)))[1]
        [[1, 1, 1, 0, 0, 0],
         [1, 1, 1, 0, 0, 0]]
    """
    pad_mask = make_pad_mask(lengths, xs, length_dim)
    return torch.logical_not(pad_mask)
def get_mask_from_lengths(lengths):
    """Return a ``(B, max(lengths))`` bool mask, True where the index is ``< lengths[b]``."""
    longest = torch.max(lengths).item()
    steps = torch.arange(0, longest).to(lengths.device)
    return (steps.unsqueeze(0) < lengths.unsqueeze(1)).bool()
def group_hidden_by_segs(h, seg_ids, max_len):
    """Average hidden frames that share a segment id.

    :param h: [B, T, H]
    :param seg_ids: [B, T] integer ids in ``0..max_len``. Id 0 is collected into
        a slot that is discarded.
    :param max_len: number of segments to return.
    :return: ``(h_ph, counts)`` with ``h_ph`` of shape ``[B, max_len, H]`` (mean per
        segment, 0 for empty segments) and ``counts`` ``[B, max_len]`` (frames per segment).
    """
    B, T, H = h.shape
    scatter_index = seg_ids[:, :, None].repeat([1, 1, H])
    sums = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, scatter_index, h)
    counts = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, h.new_ones(h.shape[:2]))
    # drop slot 0 (frames with seg id 0)
    sums, counts = sums[:, 1:], counts[:, 1:]
    means = sums / counts[:, :, None].clamp(min=1)
    return means, counts
def expand_by_repeat_times(source_encoding, lengths):
    """Repeat row ``i`` of ``source_encoding`` ``lengths[i]`` times, in order.

    source_encoding: [T, C]
    lengths: list of int, [T,], how many times each token should repeat
    return:
        expanded_encoding: [T_expand, C] with T_expand = sum(lengths)
    """
    n_channels = source_encoding.shape[1]
    row_index = [i for i, times in enumerate(lengths) for _ in range(times)]
    row_index = torch.tensor(row_index, dtype=torch.long, device=source_encoding.device)
    gather_index = row_index[:, None].repeat([1, n_channels])
    return torch.gather(source_encoding, 0, gather_index)  # [T_expand, C]
def expand_word2ph(word_encoding, ph2word):
    """Gather a word-level encoding to phoneme level.

    ``word_encoding`` is ``[B, T_word, H]``. It is left-padded with one zero row
    along dim 1, so ``ph2word`` ids are 1-based and id 0 yields a zero vector.
    Returns ``[B, T_ph, H]``.
    """
    padded = F.pad(word_encoding, [0, 0, 1, 0])
    index = ph2word[:, :, None].repeat(1, 1, padded.shape[-1])
    return padded.gather(1, index)
================================================
FILE: tts/modules/ar_dur/commons/transformer.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
from torch.nn import Parameter, Linear
from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
import torch.nn.functional as F
# Default maximum source / target sequence lengths (positions).
# NOTE(review): not referenced in the visible part of this module; presumably
# consumed by callers elsewhere. Confirm before changing.
DEFAULT_MAX_SOURCE_POSITIONS = 3000
DEFAULT_MAX_TARGET_POSITIONS = 3000
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored: row ``padding_idx`` of the table is all zeros,
    and positions from ``make_positions`` start at ``padding_idx + 1``. The table
    is precomputed for ``init_size`` rows and regrown on demand. It is a plain
    attribute (not a buffer) and is moved to the dtype/device of the
    ``_float_tensor`` buffer on every forward.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need":
        the first half of each row is ``sin``, the second half ``cos``, and an odd
        ``embedding_dim`` gets one trailing zero column.

        :return: float tensor ``(num_embeddings, embedding_dim)``.
        """
        half_dim = embedding_dim // 2
        # With a single frequency (half_dim == 1) the divisor would be 0. That lone
        # frequency is exp(0) = 1 regardless of the step, so any positive divisor works.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen].

        :param incremental_state: if not None, return a single ``(bsz, 1, dim)`` step
            for position ``timestep + 1`` (or ``seq_len`` when ``timestep`` is None).
        :param positions: optional explicit position ids ``(bsz, seq_len)``; defaults
            to ``make_positions(input, padding_idx)``.
        :return: ``(bsz, seq_len, dim)`` embeddings (detached), or ``(bsz, 1, dim)``
            in incremental mode.
        """
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        # Make sure the table also covers rows addressed by a timestep or explicit positions;
        # otherwise indexing below would go out of range.
        if incremental_state is not None:
            if timestep is not None:
                max_pos = max(max_pos, self.padding_idx + int(timestep.view(-1)[0]) + 2)
        elif positions is not None and positions.numel() > 0:
            max_pos = max(max_pos, int(positions.max()) + 1)
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
class TransformerFFNLayer(nn.Module):
    """Feed-forward block: Conv1d -> scale -> activation -> dropout -> Linear.

    Inputs and outputs are time-major ``(T, B, C)``. With ``padding='LEFT'`` the
    convolution is left-padded by ``kernel_size - 1``, so output ``t`` depends only
    on inputs ``t - kernel_size + 1 .. t``. The incremental path relies on this.
    """
    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu', bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.act = act
        if padding == 'SAME':
            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size,
                                   padding=kernel_size // 2, bias=bias)
        elif padding == 'LEFT':
            self.ffn_1 = nn.Sequential(
                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
                nn.Conv1d(hidden_size, filter_size, kernel_size, bias=bias)
            )
        else:
            # Previously ffn_1 was left undefined and forward failed with AttributeError.
            raise ValueError(f"Unsupported padding {padding!r}; expected 'SAME' or 'LEFT'")
        self.ffn_2 = Linear(filter_size, hidden_size, bias=bias)

    def forward(self, x, incremental_state=None):
        """Apply the FFN to ``x`` of shape ``(T, B, C)``.

        In incremental mode the last ``kernel_size`` input frames are kept in the
        state under ``'prev_input'``, and only the newest output frame is returned.
        """
        # x: T x B x C
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_input' in saved_state:
                prev_input = saved_state['prev_input']
                x = torch.cat((prev_input, x), dim=0)
            x = x[-self.kernel_size:]
            saved_state['prev_input'] = x
            self._set_input_buffer(incremental_state, saved_state)

        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
        x = x * self.kernel_size ** -0.5

        if incremental_state is not None:
            x = x[-1:]
        if self.act == 'gelu':
            x = F.gelu(x)
        elif self.act == 'relu':
            x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.ffn_2(x)
        return x

    def _get_input_buffer(self, incremental_state):
        """Return this layer's cached state dict (empty dict if none)."""
        return get_incremental_state(
            self,
            incremental_state,
            'f',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        """Store this layer's cached state dict."""
        set_incremental_state(
            self,
            incremental_state,
            'f',
            buffer,
        )

    def clear_buffer(self, incremental_state):
        """Forget the cached ``'prev_input'`` frames, if any."""
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_input' in saved_state:
                del saved_state['prev_input']
            self._set_input_buffer(incremental_state, saved_state)
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional decoding cache.

    Inputs are time-major (``[T, B, C]``). The projection mode is chosen at
    construction time:

    * ``self_attention=True``: query, key and value are all projected from
      ``query`` using the packed ``in_proj_weight``.
    * ``encoder_decoder_attention=True``: the query is projected from ``query``;
      key and value are both projected from ``key``.
    * otherwise: query, key and value are projected from their own inputs.

    When ``incremental_state`` is passed to :meth:`forward`, projected keys,
    values and the key padding mask are cached so step-by-step decoding can
    reuse them; :meth:`clear_buffer` removes that cache.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # Queries are multiplied by 1/sqrt(head_dim) before the dot product.
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                             'value to be of the same size'

        if self.qkv_same_dim:
            # Packed [q; k; v] projection, sliced row-wise by _in_proj.
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        # Use the fused F.multi_head_attention_forward path when possible (off by default).
        self.enable_torch_version = False
        # Attention probabilities stored by forward(reset_attn_weight=True).
        self.last_attn_probs = None

    def reset_parameters(self):
        """Xavier-initialise projection weights; zero the biases; Xavier-normal for bias_k/bias_v."""
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (Bool/ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            incremental_state (dict, optional): decoding cache; projected
                keys/values and the key padding mask are stored in it.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            static_kv (bool, optional): with a cache that already holds keys,
                reuse them instead of projecting ``key``/``value`` again
                (encoder-decoder attention only).
            attn_mask (Tensor, optional): additive mask added to the attention
                logits, typically used to implement causal attention
                (default: None). A 2-D mask is broadcast over batch and heads;
                a 3-D `(batch, tgt_len, src_len)` mask is repeated per head.
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
            enc_dec_attn_constraint_mask (Tensor, optional): `(batch, heads, src_len)`
                mask; nonzero positions are filled with -1e8 before the softmax.
            reset_attn_weight (bool, optional): if True, store this call's
                attention probabilities; if False, use the stored probabilities
                instead of the freshly computed ones.

        Returns:
            ``(attn, (attn_weights, attn_logits))``: ``attn`` is
            `(tgt_len, batch, embed_dim)`; ``attn_weights`` is None unless
            requested, `(batch, tgt_len, src_len)` when averaged or
            `(heads, batch, tgt_len, src_len)` with ``need_head_weights``;
            ``attn_logits`` are the masked pre-softmax scores
            `(batch, heads, tgt_len, src_len)`. With ``before_softmax=True`` the
            pair ``(attn_weights, v)`` is returned instead.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        # Fused PyTorch path. It has no cache, no probability reuse, no
        # constraint mask, no before_softmax and no per-head weights, so it is
        # only taken when none of those features are requested.
        if (self.enable_torch_version and incremental_state is None and not static_kv
                and reset_attn_weight is None and enc_dec_attn_constraint_mask is None
                and not before_softmax and not need_head_weights):
            if self.qkv_same_dim:
                attn, attn_weights = F.multi_head_attention_forward(
                    query, key, value,
                    self.embed_dim, self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias, self.bias_k, self.bias_v,
                    self.add_zero_attn, self.dropout,
                    self.out_proj.weight, self.out_proj.bias,
                    self.training, key_padding_mask, need_weights,
                    attn_mask)
            else:
                attn, attn_weights = F.multi_head_attention_forward(
                    query, key, value,
                    self.embed_dim, self.num_heads,
                    torch.empty([0]),
                    self.in_proj_bias, self.bias_k, self.bias_v,
                    self.add_zero_attn, self.dropout,
                    self.out_proj.weight, self.out_proj.bias,
                    self.training, key_padding_mask, need_weights,
                    attn_mask, use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj_weight,
                    k_proj_weight=self.k_proj_weight,
                    v_proj_weight=self.v_proj_weight)
            # Match the (attn, (attn_weights, attn_logits)) structure of the
            # regular path; the fused kernel does not expose logits.
            return attn, (attn_weights, None)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q = q * self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            # Extend the masks by one (unmasked) column for the appended bias key.
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        # [T, B, C] -> [B * heads, T, head_dim]
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
                prev_key_padding_mask = saved_state['prev_key_padding_mask']
                if static_kv:
                    key_padding_mask = prev_key_padding_mask
                else:
                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)

            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_key_padding_mask'] = key_padding_mask

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            # Append an all-zero key/value position (and matching unmasked mask column).
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask

        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        if reset_attn_weight is not None:
            if reset_attn_weight:
                self.last_attn_probs = attn_probs.detach()
            else:
                assert self.last_attn_probs is not None
                attn_probs = self.last_attn_probs
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, (attn_weights, attn_logits)

    def in_proj_qkv(self, query):
        """Project ``query`` with the packed weight and split it into (q, k, v)."""
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        """Query projection (first ``embed_dim`` rows of the packed weight/bias)."""
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        """Key projection (middle ``embed_dim`` rows of the packed weight/bias)."""
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        """Value projection (last ``embed_dim`` rows of the packed weight/bias)."""
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        """Apply rows ``start:end`` of the packed in-projection to ``input``."""
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def _get_input_buffer(self, incremental_state):
        """Return this module's cache dict ('attn_state' slot), or an empty dict."""
        return get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        """Store ``buffer`` as this module's cache dict ('attn_state' slot)."""
        set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        """Hook for sparse attention patterns; the identity here."""
        return attn_weights

    def clear_buffer(self, incremental_state=None):
        """Remove cached keys, values and key padding mask from ``incremental_state``."""
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            # The padding mask must go together with keys/values: a stale mask
            # would otherwise be reused (static_kv) or concatenated onto the
            # next call's mask.
            for name in ('prev_key', 'prev_value', 'prev_key_padding_mask'):
                saved_state.pop(name, None)
            self._set_input_buffer(incremental_state, saved_state)
class EncSALayer(nn.Module):
    """Pre-LN Transformer encoder layer: self-attention followed by a conv FFN.

    Each sub-layer computes ``x + dropout(sublayer(layer_norm(x)))`` and, when a
    padding mask is given, zeroes padded time steps afterwards. With
    ``num_heads == 0`` the self-attention sub-layer (and its LayerNorm) is omitted.

    Args:
        c: hidden size.
        num_heads: number of attention heads (0 disables self-attention).
        dropout: residual dropout probability.
        attention_dropout: dropout on attention probabilities.
        relu_dropout: dropout inside the FFN.
        kernel_size, padding, act, ffn_hidden_size: forwarded to TransformerFFNLayer.
    """

    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
                 relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu',
                 ffn_hidden_size=1024):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.num_heads = num_heads
        if num_heads > 0:
            self.layer_norm1 = LayerNorm(c)
            self.self_attn = MultiheadAttention(
                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
        self.layer_norm2 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        """Run the layer.

        Args:
            x: time-major input ``[T, B, C]``.
            encoder_padding_mask: ``[B, T]`` mask, true/1 at padded steps. When
                None, no masking is applied.
            **kwargs: ``layer_norm_training`` overrides the LayerNorms'
                ``training`` flag. Other keys (e.g. ``attn_mask``) are accepted
                but not used.

        Returns:
            Tensor ``[T, B, C]``.
        """
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            # layer_norm1 only exists when self-attention is enabled.
            if self.num_heads > 0:
                self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training

        if encoder_padding_mask is None:
            nonpadding = None
        else:
            # [B, T] padding mask -> [T, B, 1] multiplicative keep-mask.
            nonpadding = (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        if self.num_heads > 0:
            residual = x
            x = self.layer_norm1(x)
            x, _, = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask
            )
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
            if nonpadding is not None:
                x = x * nonpadding

        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if nonpadding is not None:
            x = x * nonpadding
        return x
class DecSALayer(nn.Module):
    """Transformer decoder layer with three residual sub-layers.

    The sub-layers are self-attention, optional encoder-decoder attention, and a
    causal (left-padded) conv FFN. LayerNorm is applied before each sub-layer
    by default, or after the residual sum when ``post_ln=True``.
    """

    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
                 kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        self.self_attn = MultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
        self.layer_norm2 = LayerNorm(c)
        self.encoder_attn = MultiheadAttention(
            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False)
        self.layer_norm3 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
        self.post_ln = post_ln

    def forward(self, x, encoder_out=None, encoder_padding_mask=None, incremental_state=None,
                self_attn_mask=None, self_attn_padding_mask=None, attn_out=None,
                reset_attn_weight=None, **kwargs):
        """Run the layer on time-major ``x``.

        Cross-attention runs over ``encoder_out`` when it is given. Otherwise,
        if ``attn_out`` is given, that tensor is passed through the
        encoder-attention value projection instead. ``kwargs['layer_norm_training']``
        overrides the LayerNorms' ``training`` flag.

        Returns:
            ``(x, attn_logits)``. ``attn_logits`` comes from the
            encoder-decoder attention, or is None when ``encoder_out`` is not
            given.
        """
        ln_flag = kwargs.get('layer_norm_training', None)
        if ln_flag is not None:
            for norm in (self.layer_norm1, self.layer_norm2, self.layer_norm3):
                norm.training = ln_flag
        pre_ln = not self.post_ln

        # --- self-attention sub-layer ---
        skip = x
        h = self.layer_norm1(x) if pre_ln else x
        h, _ = self.self_attn(query=h, key=h, value=h,
                              key_padding_mask=self_attn_padding_mask,
                              incremental_state=incremental_state,
                              attn_mask=self_attn_mask)
        x = skip + F.dropout(h, self.dropout, training=self.training)
        if self.post_ln:
            x = self.layer_norm1(x)

        # --- optional encoder-decoder attention sub-layer ---
        attn_logits = None
        if encoder_out is not None or attn_out is not None:
            skip = x
            h = self.layer_norm2(x) if pre_ln else x
            if encoder_out is not None:
                constraint = get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask')
                h, attn = self.encoder_attn(query=h, key=encoder_out, value=encoder_out,
                                            key_padding_mask=encoder_padding_mask,
                                            incremental_state=incremental_state,
                                            static_kv=True,
                                            enc_dec_attn_constraint_mask=constraint,
                                            reset_attn_weight=reset_attn_weight)
                attn_logits = attn[1]
            else:
                h = self.encoder_attn.in_proj_v(attn_out)
            x = skip + F.dropout(h, self.dropout, training=self.training)
            if self.post_ln:
                x = self.layer_norm2(x)

        # --- feed-forward sub-layer ---
        skip = x
        h = self.layer_norm3(x) if pre_ln else x
        h = self.ffn(h, incremental_state=incremental_state)
        x = skip + F.dropout(h, self.dropout, training=self.training)
        if self.post_ln:
            x = self.layer_norm3(x)
        return x, attn_logits

    def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
        """Clear the encoder-attention and FFN caches (the self-attention cache is kept)."""
        self.encoder_attn.clear_buffer(incremental_state)
        self.ffn.clear_buffer(incremental_state)

    def set_buffer(self, name, tensor, incremental_state):
        """Store ``tensor`` under ``name`` in this layer's incremental state."""
        return set_incremental_state(self, incremental_state, name, tensor)
class TransformerEncoderLayer(nn.Module):
    """Thin wrapper around EncSALayer.

    Attention dropout is fixed to 0 and the FFN dropout is tied to ``dropout``.
    """

    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024):
        super().__init__()
        self.hidden_size, self.dropout, self.num_heads = hidden_size, dropout, num_heads
        self.op = EncSALayer(hidden_size, num_heads,
                             dropout=dropout,
                             attention_dropout=0.0,
                             relu_dropout=dropout,
                             kernel_size=kernel_size,
                             ffn_hidden_size=ffn_hidden_size)

    def forward(self, x, **kwargs):
        """Delegate to the wrapped EncSALayer."""
        wrapped = self.op
        return wrapped(x, **kwargs)
class TransformerDecoderLayer(nn.Module):
    """Thin wrapper around DecSALayer.

    Attention dropout is fixed to 0 and the FFN dropout is tied to ``dropout``.
    Buffer management calls are forwarded to the wrapped layer.
    """

    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False):
        super().__init__()
        self.hidden_size, self.dropout, self.num_heads = hidden_size, dropout, num_heads
        self.op = DecSALayer(hidden_size, num_heads,
                             dropout=dropout,
                             attention_dropout=0.0,
                             relu_dropout=dropout,
                             kernel_size=kernel_size,
                             ffn_hidden_size=ffn_hidden_size,
                             post_ln=post_ln)

    def forward(self, x, **kwargs):
        """Delegate to the wrapped DecSALayer."""
        wrapped = self.op
        return wrapped(x, **kwargs)

    def clear_buffer(self, *args):
        """Forward to DecSALayer.clear_buffer."""
        wrapped = self.op
        return wrapped.clear_buffer(*args)

    def set_buffer(self, *args):
        """Forward to DecSALayer.set_buffer."""
        wrapped = self.op
        return wrapped.set_buffer(*args)
class FFTBlocks(nn.Module):
    """Stack of TransformerEncoderLayer blocks operating on ``[B, T, C]`` input.

    Optionally adds a sinusoidal positional embedding (scaled by a learnable
    ``pos_embed_alpha`` when ``use_pos_embed_alpha``) and a final LayerNorm.
    """

    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
                 num_heads=2, use_pos_embed=True, use_last_norm=True,
                 use_pos_embed_alpha=True, ffn_hidden_size=1024):
        super().__init__()
        self.num_layers = num_layers
        embed_dim = self.hidden_size = hidden_size
        self.dropout = dropout
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm
        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            if use_pos_embed_alpha:
                self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
            else:
                self.pos_embed_alpha = 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(self.hidden_size, self.dropout,
                                    kernel_size=ffn_kernel_size, num_heads=num_heads,
                                    ffn_hidden_size=ffn_hidden_size)
            for _ in range(self.num_layers)
        ])
        self.layer_norm = nn.LayerNorm(embed_dim) if self.use_last_norm else None

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param x: [B, T, C]
        :param padding_mask: [B, T]; if None, frames whose features are all zero are treated as padding
        :param attn_mask: forwarded to every layer
        :param return_hiddens: if True, return each layer's output (before the final norm)
        :return: [B, T, C] or [L, B, T, C]
        """
        if padding_mask is None:
            padding_mask = x.abs().sum(-1).eq(0).data
        keep_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            x = x + self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = x.transpose(0, 1) * keep_TB  # [B, T, C] -> [T, B, C]
        per_layer = []
        for block in self.layers:
            x = block(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * keep_TB
            per_layer.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * keep_TB
        if not return_hiddens:
            return x.transpose(0, 1)  # [B, T, C]
        # [L, T, B, C] -> [L, B, T, C]
        return torch.stack(per_layer, 0).transpose(1, 2)
class FastSpeechEncoder(FFTBlocks):
    """Token encoder: scaled token embedding followed by FFTBlocks.

    The FFTBlocks positional embedding is disabled (``use_pos_embed=False``), so
    ``embed_positions`` is created but not applied by ``forward_embedding``.
    Token id 0 is the padding index.
    """

    def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9,
                 dropout=0.0, num_heads=2, ffn_hidden_size=1024):
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
                         use_pos_embed=False, dropout=dropout, ffn_hidden_size=ffn_hidden_size)
        self.embed_tokens = Embedding(dict_size, hidden_size, 0)
        self.embed_scale = math.sqrt(hidden_size)
        self.padding_idx = 0
        self.embed_positions = SinusoidalPositionalEmbedding(
            hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS)

    def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
        """
        :param txt_tokens: [B, T]
        :param attn_mask: forwarded to FFTBlocks.forward
        :param other_embeds: extra embeddings added to the token embeddings
        :return: encoder_out [B x T x C]
        """
        pad = txt_tokens.eq(self.padding_idx).data
        hidden = self.forward_embedding(txt_tokens) + other_embeds  # [B, T, H]
        if self.num_layers <= 0:
            return hidden
        return super().forward(hidden, pad, attn_mask=attn_mask)

    def forward_embedding(self, txt_tokens):
        """Scaled token embedding (+ positions when enabled), followed by dropout."""
        emb = self.embed_scale * self.embed_tokens(txt_tokens)
        if self.use_pos_embed:
            emb = emb + self.embed_positions(txt_tokens)
        return F.dropout(emb, p=self.dropout, training=self.training)
================================================
FILE: tts/modules/llm_dit/cfm.py
================================================
# MIT License
# Copyright (c) 2023 Alexander Tong
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Alexander Tong]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text # available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
# This modified file is released under the same license.
import math
import torch
from typing import Union
from torch.distributions import LogisticNormal
class LogitNormalTrainingTimesteps:
    """Draw training timesteps in (0, 1) from a ``LogisticNormal(loc, scale)``.

    Only the first coordinate of each sample is kept. For scalar ``loc`` and
    ``scale`` this coordinate is the sigmoid of a normal sample, i.e. a
    logit-normal variate.

    Args:
        T: must be positive; stored as ``self.T`` but not used when sampling.
        loc: mean of the underlying normal.
        scale: standard deviation of the underlying normal.
    """

    def __init__(self, T=1000.0, loc=0.0, scale=1.0):
        assert T > 0
        self.T = T
        self.dist = LogisticNormal(loc, scale)

    def sample(self, size, device):
        """Sample timesteps of shape ``size`` and move them to ``device``."""
        draws = self.dist.sample(size)
        first = draws[..., 0]
        return first.to(device)
def pad_t_like_x(t, x):
    """Reshape a per-sample time vector so that it broadcasts against ``x``.

    Parameters
    ----------
    x : Tensor, shape (bs, *dim)
        represents the source minibatch
    t : FloatTensor, shape (bs), or a Python float/int

    Returns
    -------
    t : Tensor of shape (bs, 1, ..., 1) with ``x.dim()`` dimensions,
        or the scalar unchanged.

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    is_scalar = isinstance(t, (float, int))
    if not is_scalar:
        target_shape = (-1,) + (1,) * (x.dim() - 1)
        return t.reshape(target_shape)
    return t
class ConditionalFlowMatcher:
    r"""Base class for conditional flow matching methods. This class implements the independent
    conditional flow matching methods from [1] and serves as a parent class for all other flow
    matching methods.

    It implements:
    - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
    - conditional flow matching ut(x1|x0) = x1 - x0
    - score function $\nabla log p_t(x|x0, x1)$
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

        Parameters
        ----------
        sigma : Union[float, int]
            constant standard deviation of the probability path.
        """
        self.sigma = sigma
        # Default sampler for t when the caller does not provide one.
        self.time_sampler = LogitNormalTrainingTimesteps()

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        t = pad_t_like_x(t, x0)
        return t * x1 + (1 - t) * x0

    def compute_sigma_t(self, t):
        """
        Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        standard deviation sigma (constant, independent of t)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """
        Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(t)
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t, xt
        return x1 - x0

    def sample_noise_like(self, x):
        """Standard normal noise with the shape, dtype and device of ``x``."""
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """
        Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
        and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor, shape (bs)
            represents the time levels
            if None, drawn from ``self.time_sampler`` (a logit-normal
            distribution on (0, 1)), not from a uniform distribution
        return_noise : bool
            return the noise sample epsilon

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : conditional vector field ut(x1|x0) = x1 - x0
        (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        if t is None:
            # Logit-normal timesteps (the original uniform alternative was
            # torch.rand(x0.shape[0]).type_as(x0)).
            t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)
        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        if return_noise:
            return t, xt, ut, eps
        else:
            return t, xt, ut

    def compute_lambda(self, t):
        """Compute the lambda function, see Eq.(23) [4].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        lambda : score weighting function, 2 * sigma_t / (sigma ** 2 + 1e-8)

        References
        ----------
        [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
        """
        sigma_t = self.compute_sigma_t(t)
        return 2 * sigma_t / (self.sigma**2 + 1e-8)
class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
    r"""Trigonometric interpolants of Albergo et al. (2023), see [3].

    Only the path mean and the conditional vector field are overridden; the
    remaining behaviour is inherited from ConditionalFlowMatcher.

    [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
    """

    def compute_mu_t(self, x0, x1, t):
        r"""Mean of the probability path (Eq.5 in [3]).

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1
        """
        t = pad_t_like_x(t, x0)
        angle = math.pi / 2 * t
        return torch.cos(angle) * x0 + torch.sin(angle) * x1

    def compute_conditional_flow(self, x0, x1, t, xt):
        r"""Conditional vector field (Eq.21 in [3]).

        ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0)

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            unused; accepted for interface compatibility
        """
        del xt
        t = pad_t_like_x(t, x0)
        angle = math.pi / 2 * t
        return math.pi / 2 * (torch.cos(angle) * x1 - torch.sin(angle) * x0)
================================================
FILE: tts/modules/llm_dit/dit.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
from tts.modules.ar_dur.commons.layers import Embedding
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
from tts.modules.ar_dur.ar_dur_predictor import expand_states
from tts.modules.llm_dit.transformer import Transformer
from tts.modules.llm_dit.time_embedding import TimestepEmbedding
class Diffusion(nn.Module):
def __init__(self):
    super().__init__()
    # ---- hard-coded hyper-parameters ----
    # conditioning
    self.local_cond_dim = 512
    self.ctx_mask_dim = 16
    self.in_channels = 32
    self.out_channels = 32
    # LLM backbone
    self.encoder_dim = 1024
    self.encoder_n_layers = 24
    self.encoder_n_heads = 16
    self.max_seq_len = 16384
    self.multiple_of = 256

    dim = self.encoder_dim
    # Local conditioning: context mask projection + (latent, mask-emb) projection.
    self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
    self.local_cond_project = nn.Linear(self.out_channels + self.ctx_mask_dim, self.local_cond_dim)

    self.encoder = Transformer(self.encoder_n_layers, dim, self.encoder_n_heads, self.max_seq_len)
    self.x_prenet = nn.Linear(self.in_channels, dim)
    self.prenet = nn.Linear(self.local_cond_dim, dim)
    self.postnet = nn.Linear(dim, self.out_channels)

    self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)

    # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
    # which is licensed under the MIT License.
    self.f5_time_embed = TimestepEmbedding(dim)

    # Text (phone) encoder.
    self.ph_encoder = RelTransformerEncoder(
        302, dim, dim, dim * 2, 4, 6, 3, 0.0, prenet=True, pre_ln=True)
    self.tone_embed = Embedding(32, dim, padding_idx=0)
    self.ph_pos_embed = PosEmb(dim)
    # Two strided convolutions (stride 2, kernel 4) over the expanded linguistic features.
    self.ling_pre_net = torch.nn.Sequential(*[
        torch.nn.Conv1d(dim, dim, kernel_size=2 * stride, stride=stride, padding=stride // 2)
        for stride in (2, 2)
    ])
def forward(self, inputs, sigmas=None, x_noisy=None):
    """Training pass: sample a flow-matching pair and predict the flow.

    ``sigmas`` is unused and any incoming ``x_noisy`` is ignored (it is rebuilt
    from ``inputs``).

    Returns:
        ``(pred, target)``: the model output and the conditional flow ``ut``.
    """
    ctx_mask = inputs['ctx_mask']

    # Local conditioning (prompt_latent + spk_embed).
    masked_ctx = inputs['lat_ctx'] * ctx_mask
    mask_emb = self.ctx_mask_proj(ctx_mask)
    local_cond = self.local_cond_project(torch.cat([masked_ctx, mask_emb], dim=-1))

    # Diffusion target latent; it plays the role of x1 in CFM.
    x1 = inputs['lat']
    noise = torch.randn_like(x1)
    t, xt, target = self.flow_matcher.sample_location_and_conditional_flow(noise, x1)

    t = t.bfloat16()
    # Positions where ctx_mask is 1 are zeroed in the noisy input.
    noisy = (xt * (1 - ctx_mask)).bfloat16()

    # Linguistic features, expanded with inputs['mel2ph'].
    ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
    ling = self.ling_pre_net(expand_states(ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)

    hidden = self.x_prenet(noisy) + self.prenet(local_cond) + ling
    encoded = self.encoder(hidden, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
    return self.postnet(encoded), target
def forward_ling_encoder(self, txt_tokens, tone_tokens):
    """Encode phoneme tokens together with tone and position embeddings.

    Token id 0 is treated as padding: its extra embeddings and its output
    features are zeroed.

    Returns:
        Phoneme-level features from `self.ph_encoder`, multiplied by the non-padding mask.
    """
    # [B, T_phone, 1] mask: 1.0 for real tokens, 0.0 for padding
    keep = (txt_tokens > 0).float()[:, :, None]
    positions = torch.arange(0, txt_tokens.shape[1])[None,].to(txt_tokens.device)
    extra_embeds = (self.tone_embed(tone_tokens) + self.ph_pos_embed(positions)) * keep
    encoded = self.ph_encoder(txt_tokens, other_embeds=extra_embeds)
    return encoded * keep
def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=[1.0,1.0]):
    """One denoiser evaluation followed by multi-condition classifier-free guidance.

    The batch is split into three equal chunks: (speaker+text, text-only,
    unconditional). With w_txt = seq_cfg_w[0] and w_spk = seq_cfg_w[1], the
    result is uncond + w_txt * (txt - uncond) + w_spk * (spk_txt - txt).
    `dur` is accepted but unused.
    """
    # Positions covered by the context (prompt) mask are zeroed in the noisy input.
    masked = x * (1 - ctx_mask)
    hidden = self.x_prenet(masked) + self.prenet(local_cond) + x_ling
    time_emb = self.f5_time_embed(timesteps)
    # every position is attendable during inference
    full_mask = torch.ones((hidden.size(0), hidden.size(1)), device=hidden.device)
    velocity = self.postnet(self.encoder(hidden, time_emb, attn_mask=full_mask))
    spk_txt, txt_only, uncond = velocity.chunk(3)
    txt_w = seq_cfg_w[0]
    spk_w = seq_cfg_w[1]
    return uncond + txt_w * (txt_only - uncond) + spk_w * (spk_txt - txt_only)
@torch.no_grad()
def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwargs):
    """Sample a latent sequence with a sway-scheduled ODE loop and AMO sampling.

    Args:
        inputs: dict read for keys "phone", "tone", "dur", "lat_ctx" and "ctx_mask".
            It is not modified. Previously 'lat_ctx' was zeroed in place for batch
            rows >= 1.
            NOTE(review): the batch presumably holds the 3 CFG branches that
            `_forward` chunks apart — confirm against the caller.
        timesteps: number of solver steps.
        seq_cfg_w: [text_weight, speaker_weight] guidance weights forwarded to `_forward`.
        **kwargs: accepted and ignored.

    Returns:
        A tensor started from noise of shape (1, frames, self.out_channels) and
        refined for `timesteps` steps.
    """
    # txt embedding, expanded to frame level by the durations
    x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
    x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)
    # speaker embedding; clone so zeroing the CFG rows does not mutate the caller's tensor
    ctx_feature = inputs['lat_ctx'].clone()
    ctx_feature[1:, :, :] = 0  # prefix spk cfg: only batch row 0 keeps the prompt latent
    ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])
    # local conditioning.
    local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
    local_cond = self.local_cond_project(local_cond)

    # Euler-style ODE solver
    bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
    # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
    # which is licensed under the MIT License.
    sway_sampling_coef = -1.0
    t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
    if sway_sampling_coef is not None:
        t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)

    # AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
    def amo_sampling(z_t, t, t_next, v):
        # Upcast to avoid precision issues when computing prev_sample
        z_t = z_t.to(torch.float32)
        # Constant definition in Algorithm 1
        s = t_next
        c = 3
        # Line 7 in Algorithm 1: overshoot to o (capped at 1) along the predicted velocity
        o = min(t_next + c * (t_next - t), 1)
        pred_z_o = z_t + (o - t) * v
        # Line 11 in Algorithm 1: pull back to s, re-injecting fresh noise
        a = s / o
        b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
        noise_i = torch.randn(size=z_t.shape, device=z_t.device)
        z_t_next = a * pred_z_o + b * noise_i
        return z_t_next.to(v.dtype)

    x = torch.randn([1, frm_len, self.out_channels], device=device)
    for step_index in range(timesteps):
        x = x.to(torch.float32)
        sigma = t_schedule[step_index].to(x_ling.dtype)
        sigma_next = t_schedule[step_index + 1]
        # the single sample is replicated across the CFG batch
        model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
        x = amo_sampling(x, sigma, sigma_next, model_out)
        # Cast sample back to model compatible dtype
        x = x.to(model_out.dtype)
    return x
================================================
FILE: tts/modules/llm_dit/time_embedding.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
class SinusPositionEmbedding(nn.Module):
    """Map scalar timesteps to `dim` sinusoidal features laid out as [sin | cos]."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        """x: 1-D tensor of positions/timesteps, multiplied by `scale` before embedding."""
        half = self.dim // 2
        # log-spaced frequencies from 1 down to 1/10000
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device).float() * -step)
        angles = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep features followed by a Linear -> SiLU -> Linear MLP to `dim`."""

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep):  # noqa: F821
        # sinusoid features are computed in float32; match the incoming timestep dtype
        features = self.time_embed(timestep).to(timestep.dtype)
        return self.time_mlp(features)  # b d
================================================
FILE: tts/modules/llm_dit/transformer.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Rotary-embedding table: complex64 tensor of shape (end, dim // 2) with unit magnitude.

    Entry [p, k] is exp(i * p * theta ** (-2k / dim)).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    return torch.polar(torch.ones_like(angles), angles)  # complex64
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View `freqs_cis` so it broadcasts against `x` along dims 1 and -1.

    Requires freqs_cis.shape == (x.shape[1], x.shape[-1]); all other dims become size 1.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target = [1] * rank
    target[1] = x.shape[1]
    target[-1] = x.shape[-1]
    return freqs_cis.view(*target)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key feature pairs by the rotary table; outputs keep the input dtypes.

    Consecutive pairs in the last dim are treated as complex numbers and multiplied
    by `freqs_cis`, which is broadcast over dims 1 and -1 of the complex view.
    """
    def _pairs_as_complex(t):
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _pairs_as_complex(xq)
    k_complex = _pairs_as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rot = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rot = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rot.type_as(xq), k_rot.type_as(xk)
class AdaLNZero(nn.Module):
    """Adaptive LayerNorm: modulates x with shift/scale predicted from a conditioning vector.

    forward returns (modulated_x, gate_msa, shift_mlp, scale_mlp, gate_mlp); the last
    four are (batch, dim) tensors for the caller's residual/FFN paths.
    """

    def __init__(self, dim):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        params = self.linear(self.silu(emb))
        (shift_attn, scale_attn, gate_attn,
         shift_ffn, scale_ffn, gate_ffn) = params.chunk(6, dim=1)
        modulated = self.norm(x) * (1 + scale_attn.unsqueeze(1)) + shift_attn.unsqueeze(1)
        return modulated, gate_attn, shift_ffn, scale_ffn, gate_ffn
class AdaLNZero_Out(nn.Module):
    """Final adaptive LayerNorm: norm(x) * (1 + scale) + shift, with scale/shift predicted from `emb`."""

    def __init__(self, dim):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        scale, shift = self.linear(self.silu(emb)).chunk(2, dim=1)
        # broadcast the per-sample (batch, dim) terms over the sequence axis
        return self.norm(x) * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings via F.scaled_dot_product_attention."""

    def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
        super().__init__()
        # Key/value heads equal query heads (n_rep == 1); max_seq_len is not used here.
        self.encoder_n_kv_heads = encoder_n_heads
        model_parallel_size = 1
        self.n_local_heads = encoder_n_heads // model_parallel_size
        self.n_local_kv_heads = self.encoder_n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = encoder_dim // encoder_n_heads
        self.wq = nn.Linear(
            encoder_dim,
            encoder_n_heads * self.head_dim,
        )
        self.wk = nn.Linear(
            encoder_dim,
            self.encoder_n_kv_heads * self.head_dim,
        )
        self.wv = nn.Linear(
            encoder_dim,
            self.encoder_n_kv_heads * self.head_dim,
        )
        self.wo = nn.Linear(
            encoder_n_heads * self.head_dim,
            encoder_dim,
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Args:
            x: (batch, seqlen, dim) input.
            start_pos: unused here; kept for interface compatibility.
            freqs_cis: complex rotary table of shape (seqlen, head_dim // 2).
            mask: optional (batch, seqlen) key mask, broadcast over heads and query
                positions. None disables masking (previously None raised a TypeError).
                NOTE(review): SDPA adds a floating-point mask to the attention
                scores, so a 0/1 float mask does not exclude positions. Use a bool
                mask for hard masking.

        Returns:
            (batch, seqlen, dim) output of the `wo` projection.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = xk.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
        values = xv.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
        attn_mask = None if mask is None else mask[:, None, None, :]
        output = F.scaled_dot_product_attention(xq, keys, values, attn_mask, is_causal=False)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
class FeedForward(nn.Module):
    """Two-layer MLP (w1 -> SiLU -> w2) whose hidden width is rounded up to a multiple of `multiple_of`."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        width = hidden_dim
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        n_blocks = (width + multiple_of - 1) // multiple_of
        width = n_blocks * multiple_of
        self.w1 = nn.Linear(dim, width)
        self.w2 = nn.Linear(width, dim)

    def forward(self, x):
        hidden = F.silu(self.w1(x))
        return self.w2(hidden)
class TransformerBlock(nn.Module):
    """Transformer layer with timestep-conditioned (AdaLN-Zero style) gating on attention and FFN."""

    def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
        super().__init__()
        self.encoder_n_heads = encoder_n_heads
        self.encoder_dim = encoder_dim
        self.head_dim = encoder_dim // encoder_n_heads
        self.attention = Attention(encoder_dim, encoder_n_heads, max_seq_len)
        self.feed_forward = FeedForward(
            dim=encoder_dim,
            hidden_dim=2 * encoder_dim,
            multiple_of=256,
            ffn_dim_multiplier=None,
        )
        # yields the modulated attention input plus gate/shift/scale terms for both paths
        self.attention_norm = AdaLNZero(encoder_dim)
        self.ffn_norm = nn.LayerNorm(encoder_dim, elementwise_affine=False, eps=1e-6)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Args:
            x: (batch, seq, dim) hidden states.
            t: (batch, dim) conditioning (timestep) embedding.
            start_pos: forwarded to the attention module.
            freqs_cis: rotary table for the current sequence positions.
            mask: optional attention mask forwarded to the attention module.

        Returns:
            Hidden states of the same shape as `x`.
        """
        attn_in, gate_attn, ffn_shift, ffn_scale, ffn_gate = self.attention_norm(x, emb=t)
        attended = self.attention(attn_in, start_pos, freqs_cis, mask=mask)
        h = x + gate_attn.unsqueeze(1) * attended
        ffn_in = self.ffn_norm(h) * (1 + ffn_scale[:, None]) + ffn_shift[:, None]
        return h + ffn_gate.unsqueeze(1) * self.feed_forward(ffn_in)
class Transformer(nn.Module):
    """Stack of timestep-modulated transformer blocks with rotary position embeddings."""

    def __init__(self, encoder_n_layers, encoder_dim, encoder_n_heads, max_seq_len):
        super().__init__()
        # Decoder
        self.layers = torch.nn.ModuleList()
        for _ in range(encoder_n_layers):
            self.layers.append(TransformerBlock(encoder_dim, encoder_n_heads, max_seq_len))
        self.norm = AdaLNZero_Out(encoder_dim)
        self.out_proj = nn.Linear(encoder_dim, encoder_dim)
        # Rope embedding, stored as a real view so the buffer moves with .to(device)
        freqs_cis = precompute_freqs_cis(
            encoder_dim // encoder_n_heads, max_seq_len
        )
        self.register_buffer("freqs_cis", torch.view_as_real(freqs_cis), persistent=False)

    def forward(self, x, t, attn_mask, start_pos=0, do_checkpoint=False):
        """
        Args:
            x: (batch, seq, dim) hidden states.
            t: (batch, dim) conditioning embedding.
            attn_mask: attention mask forwarded to every block.
            start_pos: offset into the rotary table.
            do_checkpoint: when True, and the module is training with grad enabled,
                each block runs under activation checkpointing to save memory.
                Callers pass this keyword, so it must be accepted.

        Returns:
            (batch, seq, dim) output after the final adaptive norm and projection.
        """
        freqs_cis = torch.view_as_complex(self.freqs_cis.float())[start_pos: start_pos + x.size(1)]
        use_checkpoint = do_checkpoint and self.training and torch.is_grad_enabled()
        if use_checkpoint:
            from torch.utils.checkpoint import checkpoint
        for layer in self.layers:
            if use_checkpoint:
                x = checkpoint(layer, x, t, start_pos, freqs_cis, attn_mask, use_reentrant=False)
            else:
                x = layer(x, t, start_pos, freqs_cis, attn_mask)
        x = self.norm(x, t)
        x = self.out_proj(x)
        return x
================================================
FILE: tts/modules/wavvae/decoder/diag_gaussian.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian whose parameters hold [mean, logvar] stacked along dim 1.

    All statistics are element-wise (no reduction), except `nll`, which sums over `dims`.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        # first half of dim 1 is the mean, second half the log-variance
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # clamp to keep exp() finite
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zero spread: sample() then returns the mean
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator=None) -> torch.Tensor:
        """Reparameterized draw: mean + std * eps, eps ~ N(0, I)."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        """Element-wise KL(self || other); with `other=None` the target is N(0, I).

        Fixed: the 0.5 factor now scales every term. Previously it multiplied only mean**2.
        """
        if self.deterministic:
            return torch.tensor([0.0], device=self.parameters.device)
        if other is None:
            return 0.5 * (torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar)
        return 0.5 * (
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar
        )

    def nll(self, sample, dims) -> torch.Tensor:
        """Negative log-likelihood of `sample`, summed over `dims`."""
        if self.deterministic:
            return torch.tensor([0.0], device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
================================================
FILE: tts/modules/wavvae/decoder/hifigan_modules.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from torch.nn.utils import weight_norm, remove_weight_norm
from torch.nn import Conv1d
import numpy as np
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw the weight of any module whose class name contains "Conv" from N(mean, std)."""
    if "Conv" not in m.__class__.__name__:
        return
    m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """'Same' padding for a stride-1 dilated convolution: dilation * (kernel_size - 1) / 2, truncated."""
    span = kernel_size * dilation - dilation
    return int(span / 2)
class Upsample(nn.Module):
    """Upsample the time axis by `r` and halve the channels (mult -> mult // 2).

    The input is first passed through sin(x) + x. It then goes through two branches
    whose outputs are summed:
      * nearest-neighbour upsampling, LeakyReLU, reflect pad 3, weight-normalized k=7 conv;
      * LeakyReLU and a weight-normalized transposed conv with stride r.
    Both branches produce r * T frames.
    """

    def __init__(self, mult, r):
        super(Upsample, self).__init__()
        self.r = r
        out_channels = mult // 2
        self.upsample = nn.Sequential(
            nn.Upsample(mode="nearest", scale_factor=r),
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mult, out_channels, kernel_size=7, stride=1)),
        )
        # kernel is 2 * max(r, 5); padding/output_padding make the output length exactly r * T
        k = max(r, 5)
        transposed = nn.ConvTranspose1d(
            mult, out_channels,
            kernel_size=k * 2, stride=r,
            padding=k - r // 2,
            output_padding=r % 2,
        )
        self.trans_upsample = nn.Sequential(nn.LeakyReLU(0.2), nn.utils.weight_norm(transposed))

    def forward(self, x):
        x = torch.sin(x) + x
        nearest_branch = self.upsample(x)
        transposed_branch = self.trans_upsample(x)
        return nearest_branch + transposed_branch
class Downsample(nn.Module):
    """LeakyReLU followed by a weight-normalized strided conv (stride r) that doubles the channels."""

    def __init__(self, mult, r):
        super(Downsample, self).__init__()
        self.r = r
        k = max(r, 5)
        strided = nn.Conv1d(
            mult, mult * 2,
            kernel_size=k * 2, stride=r,
            padding=k - r // 2,
        )
        self.trans_downsample = nn.Sequential(nn.LeakyReLU(0.2), nn.utils.weight_norm(strided))

    def forward(self, x):
        return self.trans_downsample(x)
def weights_init(m):
    """Init by class name: Conv* weights ~ N(0, 0.02); BatchNorm2d weight ~ N(1, 0.02) with zero bias."""
    name = m.__class__.__name__
    if "Conv" in name:
        m.weight.data.normal_(0.0, 0.02)
        return
    if "BatchNorm2d" in name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def weights_zero_init(m):
    """Zero the weight (and bias, when present) of modules whose class name contains "Conv".

    Fixed: convs created with bias=False (m.bias is None) no longer raise AttributeError.
    NOTE(review): for weight-normalized convs the effective weight is recomputed from
    weight_g/weight_v on forward, so zeroing `weight.data` may not persist — confirm intent.
    """
    if "Conv" not in m.__class__.__name__:
        return
    m.weight.data.fill_(0.0)
    if m.bias is not None:
        m.bias.data.fill_(0.0)
def WNConv1d(*args, **kwargs):
    """nn.Conv1d built from the given arguments and wrapped with (legacy) weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, **kwargs):
    """nn.ConvTranspose1d built from the given arguments and wrapped with (legacy) weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
class Audio2Mel(nn.Module):
    """Waveform -> normalized log-mel spectrogram clamped to [-4, 4].

    Fixes for current library versions:
      * `librosa.filters.mel` is called with keyword arguments; librosa >= 0.10
        rejects the positional form.
      * `torch.stft` is called with `return_complex=True`; torch 2.x requires it
        for real input. The result is viewed as real, so the math is unchanged.
    `device` is accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        hop_length=300,
        sampling_rate=24000,
        n_mel_channels=80,
        mel_fmin=0.,
        mel_fmax=None,
        frame_size=0.05,
        device='cpu'
    ):
        super().__init__()
        ##############################################
        # FFT Parameters                             #
        ##############################################
        # window length is sampling_rate * frame_size samples; n_fft is the next power of two
        self.n_fft = int(np.power(2., np.ceil(np.log(sampling_rate * frame_size) / np.log(2))))
        window = torch.hann_window(int(sampling_rate * frame_size)).float()
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=self.n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
        )  # Mel filter (by librosa)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)
        self.hop_length = hop_length
        self.win_length = int(sampling_rate * frame_size)
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels

    def forward(self, audio):
        """audio: waveform with a singleton dim 1 (squeezed before the STFT)."""
        fft = torch.stft(
            audio.squeeze(1),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            return_complex=True,
        )
        real_part, imag_part = torch.view_as_real(fft).unbind(-1)
        # floor the power before sqrt/log for numerical safety
        magnitude = torch.sqrt(torch.clamp(real_part ** 2 + imag_part ** 2, min=1e-5))
        mel_output = torch.matmul(self.mel_basis, magnitude)
        log_mel_spec = 20 * torch.log10(torch.clamp(mel_output, min=1e-5)) - 20
        norm_mel = (log_mel_spec + 115.) / 115.
        mel_comp = torch.clamp(norm_mel * 8. - 4., -4., 4.)
        return mel_comp
class ResnetBlock(nn.Module):
    """Residual block: 1x1 weight-normalized shortcut plus a dilated k=3 conv path.

    The conv path is LeakyReLU -> reflect pad(dilation) -> WN conv k=3 -> LeakyReLU -> WN conv k=1.
    """

    def __init__(self, dim, dilation=1, dim_in=None):
        super().__init__()
        in_channels = dim if dim_in is None else dim_in
        self.block = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(dilation),
            WNConv1d(in_channels, dim, kernel_size=3, dilation=dilation),
            nn.LeakyReLU(0.2),
            WNConv1d(dim, dim, kernel_size=1),
        )
        self.shortcut = WNConv1d(in_channels, dim, kernel_size=1)

    def forward(self, x):
        skip = self.shortcut(x)
        return skip + self.block(x)
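'''
Based on the HiFi-GAN V2 structure (https://arxiv.org/pdf/2010.05646.pdf).
"Multi-scale" here mainly means different kernel sizes: three parallel convolution modules,
each internally applying a different series of dilation sizes, interleaved with ordinary (non-dilated) convolution layers.
'''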
class ResBlockMRFV2(torch.nn.Module):
    """Three residual stages, each a dilated conv followed by an undilated conv (HiFi-GAN style).

    Stage i uses dilation[i]. All convs keep the channel count and length ("same" padding).
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlockMRFV2, self).__init__()

        def _conv(d):
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=d,
                                      padding=get_padding(kernel_size, d)))

        dilations = (dilation[0], dilation[1], dilation[2])
        self.convs1 = nn.ModuleList([_conv(d) for d in dilations])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for dilated, plain in zip(self.convs1, self.convs2):
            branch = plain(F.leaky_relu(dilated(F.leaky_relu(x, 0.2)), 0.2))
            x = branch + x
        return x

    def remove_weight_norm(self):
        # the name resolves to torch.nn.utils.remove_weight_norm (module-level import), not this method
        for conv in (*self.convs1, *self.convs2):
            remove_weight_norm(conv)
class ResBlockMRFV2Inter(torch.nn.Module):
    """Average of three parallel ResBlockMRFV2 branches with kernel sizes 3, 7 and 11.

    `kernel_size` is accepted but unused.
    """

    def __init__(self, channels, kernel_size=3):
        super(ResBlockMRFV2Inter, self).__init__()
        self.block1 = ResBlockMRFV2(channels)
        self.block2 = ResBlockMRFV2(channels, 7)
        self.block3 = ResBlockMRFV2(channels, 11)

    def forward(self, x):
        total = self.block1(x) + self.block2(x)
        total = total + self.block3(x)
        return total / 3
class Generator(nn.Module):
    """Waveform decoder: input conv, then per ratio an Upsample (halving channels) plus an MRF block.

    It ends with an output conv to `num_band` channels, followed by Tanh, or a 1x1
    conv when `args.use_tanh` is false. `n_residual_layers` and `device` are
    accepted but unused.
    """

    def __init__(self, input_size_, ngf, n_residual_layers, num_band, args, ratios=[5, 5, 4, 3], onnx_export=False,
                 device='cpu'):
        super().__init__()
        self.hop_length = args.frame_shift
        self.args = args
        self.onnx_export = onnx_export

        # ------------- Define upsample layers ----------------
        mult = int(2 ** len(ratios))
        layers = [
            nn.ReflectionPad1d(3),
            WNConv1d(input_size_, mult * ngf, kernel_size=7, padding=0),
        ]
        # Upsample to raw audio scale
        for r in ratios:
            layers.append(Upsample(mult * ngf, r))
            layers.append(ResBlockMRFV2Inter(mult * ngf // 2))
            mult //= 2
        layers.extend([
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(ngf, num_band, kernel_size=7, padding=0),
            nn.Tanh(),
        ])
        if not args.use_tanh:
            layers[-1] = nn.Conv1d(num_band, num_band, 1)
            layers[-2].apply(weights_zero_init)

        self.model_up = nn.Sequential(*layers)
        self.apply(weights_init)

    def forward(self, mel, step=None):
        """mel: (batch, channels, time); with onnx_export it is (batch, time, channels) and is transposed first.

        `step` is unused.
        """
        if self.onnx_export:
            mel = mel.transpose(1, 2)
        return self.model_up(mel)
================================================
FILE: tts/modules/wavvae/decoder/seanet_encoder.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from torch import nn
from tts.modules.wavvae.encoder.common_modules.seanet import SEANetEncoder
class Encoder(nn.Module):
    """Thin wrapper around a SEANetEncoder (512-dim output, 2 LSTM layers, weight norm, reflect padding)."""

    def __init__(
        self,
        dowmsamples: List[int] = [6, 5, 5, 4, 2],
    ):
        super().__init__()
        self.frame_rate = 25  # not used
        self.encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                     dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
                                     kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
                                     true_skip=False, compress=2)

    def forward(self, audio: torch.Tensor):
        # presumably (batch, samples) waveform, e.g. (16, 24000); a channel axis is inserted at dim 1
        mono = audio.unsqueeze(1)
        return self.encoder(mono)
================================================
FILE: tts/modules/wavvae/decoder/wavvae_v3.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from torch import nn
import torch.nn.functional as F
from tts.modules.wavvae.decoder.seanet_encoder import Encoder
from tts.modules.wavvae.decoder.diag_gaussian import DiagonalGaussianDistribution
from tts.modules.wavvae.decoder.hifigan_modules import Generator, Upsample
class WavVAE_V3(nn.Module):
    """Waveform VAE with a 32-channel latent.

    Encoding: SEANet encoder (512 ch) -> Linear(512, 64) -> Gaussian (32 mean + 32 logvar).
    Decoding: Linear(32, 320) -> Upsample(320, 4) -> 160-channel Generator.
    `hparams['melgan_config']` supplies the Generator's `args` attributes.
    """

    def __init__(self, hparams=None):
        super().__init__()
        self.encoder = Encoder(dowmsamples=[6, 5, 4, 4, 2])
        self.proj_to_z = nn.Linear(512, 64)
        self.proj_to_decoder = nn.Linear(32, 320)
        melgan_cfg = hparams['melgan_config']
        args = argparse.Namespace()
        vars(args).update(melgan_cfg)
        self.latent_upsampler = Upsample(320, 4)
        self.decoder = Generator(
            input_size_=160, ngf=128, n_residual_layers=4,
            num_band=1, args=args, ratios=[5, 4, 4, 3])

    def encode_latent(self, audio):
        """Encode waveform into a 25 Hz latent: a posterior sample laid out as (b, t, latent_channel)."""
        return self.encode(audio).sample().permute(0, 2, 1)

    def encode(self, audio):
        """Return the diagonal-Gaussian posterior over latents (channels on dim 1)."""
        features = self.encoder(audio).permute(0, 2, 1)
        stats = self.proj_to_z(features).permute(0, 2, 1)
        return DiagonalGaussianDistribution(stats)

    def decode(self, latent):
        """latent: (b, t, 32) -> waveform via the upsampler and Generator."""
        projected = self.proj_to_decoder(latent).permute(0, 2, 1)
        return self.decoder(self.latent_upsampler(projected))

    def forward(self, audio):
        posterior = self.encode(audio)
        latent = posterior.sample().permute(0, 2, 1)  # (b, t, latent_channel)
        return self.decode(latent), posterior
================================================
FILE: tts/modules/wavvae/encoder/common_modules/conv.py
================================================
# MIT License
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
# This modified file is released under the same license.
"""Convolutional layers wrappers and utilities."""
import math
import typing as tp
import warnings
import einops
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
# Normalization names accepted by apply_parametrization_norm / get_norm_module.
CONV_NORMALIZATIONS = frozenset({
    'none', 'weight_norm', 'spectral_norm',
    'time_layer_norm', 'layer_norm', 'time_group_norm',
})
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Wrap `module` with weight or spectral normalization when `norm` requests it; otherwise return it as is."""
    assert norm in CONV_NORMALIZATIONS
    wrappers = {'weight_norm': weight_norm, 'spectral_norm': spectral_norm}
    wrapper = wrappers.get(norm)
    return module if wrapper is None else wrapper(module)
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Build the normalization layer applied after a conv.

    Returns ConvLayerNorm for 'layer_norm', a single-group GroupNorm for
    'time_group_norm' (non-causal only), and nn.Identity otherwise.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm not in ('layer_norm', 'time_group_norm'):
        return nn.Identity()
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    if causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")
    assert isinstance(module, nn.modules.conv._ConvNd)
    return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Right padding needed so the last conv window is complete (the frame count rounds up)."""
    length = x.shape[-1]
    frames = math.ceil((length - kernel_size + padding_total) / stride + 1)
    covered = (frames - 1) * stride + (kernel_size - padding_total)
    return covered - length
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Pad the last dim of `x`; 'reflect' also works for inputs shorter than the padding.

    Args:
        x: tensor padded along its last dimension.
        paddings: (left, right) amounts, both non-negative.
        mode: 'zero' (the default), or any `F.pad` mode. 'zero' is mapped to
            'constant' with `value`. F.pad has no 'zero' mode, so the default
            previously raised.
        value: fill value for constant padding.

    For 'reflect', short inputs are temporarily zero-extended on the right so the
    reflection is valid; the extension is trimmed off afterwards.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'zero':
        mode = 'constant'
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm for conv layouts: the time axis (last dim) is moved next to batch, normalized, then moved back.

    Fixed: `forward` previously ended with a bare `return` and so returned None.
    """

    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        # 'b ... t -> b t ...': normalize over the trailing (channel) dims
        x = x.movedim(-1, 1)
        x = super().forward(x)
        # 'b t ... -> b ... t'
        x = x.movedim(1, -1)
        return x
class NormConv1d(nn.Module):
    """nn.Conv1d with an optional parametrization (weight/spectral norm) and an optional following norm layer."""

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        base_conv = nn.Conv1d(*args, **kwargs)
        self.conv = apply_parametrization_norm(base_conv, norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
class SConv1d(nn.Module):
    """Conv1d that pads its input internally: causal (all left) or centered (extra on the left for odd totals).

    Padding covers the effective (dilated) kernel minus the stride, plus any right
    padding needed to complete the last window.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        _batch, _channels, _time = x.shape  # unpacking also enforces a 3-D (B, C, T) input
        inner = self.conv.conv
        step = inner.stride[0]
        effective_kernel = (inner.kernel_size[0] - 1) * inner.dilation[0] + 1
        padding_total = effective_kernel - step
        extra = get_extra_padding_for_conv1d(x, effective_kernel, step, padding_total)
        if self.causal:
            pads = (padding_total, extra)
        else:
            right = padding_total // 2
            pads = (padding_total - right, right + extra)
        return self.conv(pad1d(x, pads, mode=self.pad_mode))
================================================
FILE: tts/modules/wavvae/encoder/common_modules/lstm.py
================================================
# MIT License
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
# This modified file is released under the same license.
"""LSTM layers module."""
from torch import nn
class SLSTM(nn.Module):
    """LSTM wrapper that takes convolution-style input and drops the hidden state.

    Args:
        dimension: channel count; used as both the LSTM input size and hidden size.
        num_layers: number of stacked LSTM layers.
        skip: when True, the input is added to the LSTM output (residual connection).
    """

    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        """Run the LSTM along the last axis of ``x``.

        ``nn.LSTM`` here is not batch_first, so ``x`` (batch, channels, time) is
        reordered to (time, batch, channels) for the LSTM, and the output is
        reordered back to (batch, channels, time) *before* the residual add.
        """
        seq_first = x.permute(2, 0, 1)
        out, _hidden = self.lstm(seq_first)
        out = out.permute(1, 2, 0)
        return out + x if self.skip else out
================================================
FILE: tts/modules/wavvae/encoder/common_modules/seanet.py
================================================
# MIT License
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
# This modified file is released under the same license.
"""Encodec SEANet-based encoder and decoder implementation."""
import typing as tp
import numpy as np
import torch.nn as nn
from .conv import SConv1d
from .lstm import SLSTM
class SEANetResnetBlock(nn.Module):
    """Residual block of the SEANet (Encodec-style) encoder.

    Main branch: one ``activation -> SConv1d`` pair per entry of
    ``kernel_sizes``/``dilations``. The first conv maps ``dim`` channels to
    ``dim // compress`` and the last maps back to ``dim``. The output is
    ``shortcut(x) + block(x)``, where the shortcut is the identity when
    ``true_skip`` is True and a 1x1 ``SConv1d`` otherwise.

    Args:
        dim: input/output channel count.
        kernel_sizes: kernel size of each conv in the main branch.
        dilations: dilation of each conv; must have the same length as ``kernel_sizes``.
        activation: name of a ``torch.nn`` activation class, e.g. ``'ELU'``.
        activation_params: keyword arguments for the activation constructor.
        norm: normalization name forwarded to ``SConv1d``.
        norm_params: forwarded to ``SConv1d`` as ``norm_kwargs``.
        causal: forwarded to ``SConv1d``.
        pad_mode: padding mode forwarded to ``SConv1d``.
        compress: channel reduction factor of the hidden convs.
        true_skip: use an identity shortcut instead of a 1x1 conv.
    """

    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act_cls = getattr(nn, activation)
        hidden = dim // compress
        last_idx = len(kernel_sizes) - 1
        layers = []
        for idx, (ksize, dil) in enumerate(zip(kernel_sizes, dilations)):
            # Submodules are created in the same order as before, so the
            # nn.Sequential indices (and therefore state_dict keys) are unchanged.
            layers.append(act_cls(**activation_params))
            layers.append(SConv1d(dim if idx == 0 else hidden,
                                  dim if idx == last_idx else hidden,
                                  kernel_size=ksize, dilation=dil,
                                  norm=norm, norm_kwargs=norm_params,
                                  causal=causal, pad_mode=pad_mode))
        self.block = nn.Sequential(*layers)
        shortcut: nn.Module
        if true_skip:
            shortcut = nn.Identity()
        else:
            shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                               causal=causal, pad_mode=pad_mode)
        self.shortcut = shortcut

    def forward(self, x):
        """Return ``shortcut(x) + block(x)``."""
        residual = self.shortcut(x)
        return residual + self.block(x)
class SEANetEncoder(nn.Module):
    """SEANet convolutional encoder (Encodec style).

    Layout of ``self.model``:
      1. an input ``SConv1d`` mapping ``channels`` -> ``n_filters``;
      2. for each stride in ``reversed(ratios)``: ``n_residual_layers`` residual
         blocks, then ``activation`` plus a strided ``SConv1d`` that downsamples
         by the stride and doubles the channel count;
      3. an optional ``SLSTM`` (when ``lstm`` is non-zero; its value is the layer count);
      4. ``activation`` plus a final ``SConv1d`` mapping to ``dimension`` channels.

    ``self.hop_length`` is the product of all strides (total downsampling factor).
    """

    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        # Strides are applied in reverse order of the given `ratios`.
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        act_cls = getattr(nn, activation)
        conv_kwargs = dict(norm=norm, norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)

        width = n_filters  # current channel count: n_filters * 2**stage
        layers: tp.List[nn.Module] = [SConv1d(channels, width, kernel_size, **conv_kwargs)]
        for stride in self.ratios:
            # Residual blocks at the current resolution, dilation growing per block.
            layers.extend(
                SEANetResnetBlock(width, kernel_sizes=[residual_kernel_size, 1],
                                  dilations=[dilation_base ** depth, 1],
                                  norm=norm, norm_params=norm_params,
                                  activation=activation, activation_params=activation_params,
                                  causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)
                for depth in range(n_residual_layers)
            )
            # Strided conv: downsample in time by `stride`, double the channels.
            layers.append(act_cls(**activation_params))
            layers.append(SConv1d(width, width * 2, kernel_size=stride * 2, stride=stride, **conv_kwargs))
            width *= 2
        if lstm:
            layers.append(SLSTM(width, num_layers=lstm))
        layers.append(act_cls(**activation_params))
        layers.append(SConv1d(width, dimension, last_kernel_size, **conv_kwargs))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` by running it through the sequential layer stack."""
        return self.model(x)
================================================
FILE: tts/utils/audio_utils/align.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def mel2token_to_dur(mel2token, T_txt=None, max_dur=None):
    """Count, for every token index, how many frames map to it.

    Args:
        mel2token: per-frame token indices, shape [T] or [B, T]; a torch
            tensor or anything ``torch.LongTensor`` accepts (e.g. a list or
            numpy array). Index 0 is counted but dropped from the output
            (presumably padding / unassigned frames).
        T_txt: number of tokens; defaults to ``mel2token.max()``.
        max_dur: if given, durations are clamped to at most this value.

    Returns:
        Durations for token indices 1..T_txt, shape [T_txt] or [B, T_txt]
        (matching the input rank). The result is a numpy array when the input
        was not a torch tensor.
    """
    from_non_tensor = not isinstance(mel2token, torch.Tensor)
    if from_non_tensor:
        mel2token = torch.LongTensor(mel2token)
    if T_txt is None:
        T_txt = mel2token.max()
    unbatched = len(mel2token.shape) == 1
    if unbatched:
        mel2token = mel2token.unsqueeze(0)
    batch, _ = mel2token.shape
    counts = mel2token.new_zeros(batch, T_txt + 1)
    counts = counts.scatter_add(1, mel2token, torch.ones_like(mel2token))
    dur = counts[:, 1:]
    if max_dur is not None:
        dur = dur.clamp(max=max_dur)
    if from_non_tensor:
        dur = dur.numpy()
    return dur[0] if unbatched else dur
================================================
FILE: tts/utils/audio_utils/io.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import subprocess
import numpy as np
from scipy.io import wavfile
import pyloudnorm as pyln
from pydub import AudioSegment
def to_wav_bytes(wav, sr, norm=False):
    """Encode a float waveform as 16-bit PCM WAV bytes.

    Args:
        wav: numpy array of samples (converted to float).
        sr: sample rate written into the WAV header.
        norm: if True, loudness-normalize to -18.0 with a BS.1770 meter
            (pyloudnorm) before encoding.

    Returns:
        The complete WAV file contents as ``bytes``. If the peak absolute value
        is >= 1 the signal is rescaled so its peak becomes 0.95 before the
        conversion to int16.
    """
    samples = wav.astype(float)
    if norm:
        meter = pyln.Meter(sr)  # BS.1770 loudness meter
        samples = pyln.normalize.loudness(samples, meter.integrated_loudness(samples), -18.0)
    peak = np.abs(samples).max()
    if peak >= 1:
        # Avoid int16 overflow after scaling by 32767.
        samples = samples / peak * 0.95
    buffer = io.BytesIO()
    wavfile.write(buffer, sr, (samples * 32767).astype(np.int16))
    return buffer.getvalue()
def save_wav(wav_bytes, path):
    """Write WAV bytes to disk, converting to MP3 when requested.

    The bytes are written to ``path`` with its last four characters replaced by
    ``.wav``. If ``path`` ends with ``.mp3``, that WAV file is then passed to
    ``to_mp3``, which converts it and deletes the WAV.
    """
    stem = path[:-4]
    with open(stem + '.wav', 'wb') as fh:
        fh.write(wav_bytes)
    if path.endswith('.mp3'):
        to_mp3(stem)
def to_mp3(out_path):
    """Convert ``<out_path>.wav`` to ``<out_path>.mp3`` with ffmpeg, then delete the WAV.

    ``out_path`` may be given with or without a trailing ``.wav``.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails or cannot be found.
    """
    if out_path[-4:] == '.wav':
        out_path = out_path[:-4]
    wav_path = f'{out_path}.wav'
    mp3_path = f'{out_path}.mp3'
    # An argument list (no shell) means paths containing quotes, `$` or
    # backticks cannot break or inject into the command, and it also works on
    # Windows, where the previous `rm -f` shell call is unavailable.
    cmd = ['ffmpeg', '-threads', '1', '-loglevel', 'error', '-i', wav_path,
           '-vn', '-b:a', '192k', '-y', '-hide_banner', '-async', '1', mp3_path]
    try:
        subprocess.check_call(cmd, stdin=subprocess.PIPE)
    except FileNotFoundError as exc:
        # Keep the old contract: a missing ffmpeg used to surface as a non-zero
        # shell exit status (127), i.e. CalledProcessError.
        raise subprocess.CalledProcessError(127, cmd) from exc
    # Equivalent of `rm -f`: a WAV that is already gone is not an error.
    try:
        os.remove(wav_path)
    except FileNotFoundError:
        pass
def convert_to_wav(wav_path):
    """Convert an audio file to WAV next to the original, using pydub.

    If ``wav_path`` does not exist, a message is printed and nothing else
    happens. A path that already ends in ``.wav`` is left alone. Otherwise the
    output is written to the same path with its extension replaced by ``.wav``.
    Always returns None.
    """
    if not os.path.exists(wav_path):
        print(f"The file '{wav_path}' does not exist.")
        return
    if wav_path.endswith(".wav"):
        return
    out_path = os.path.splitext(wav_path)[0] + ".wav"
    AudioSegment.from_file(wav_path).export(out_path, format="wav")
    print(f"Converted '{wav_path}' to '{out_path}'")
def convert_to_wav_bytes(audio_binary):
    """Decode encoded audio bytes with pydub and re-encode them as WAV.

    The input format is detected by pydub (``AudioSegment.from_file`` without an
    explicit format).

    Returns:
        io.BytesIO: the WAV data, rewound to position 0.
    """
    segment = AudioSegment.from_file(io.BytesIO(audio_binary))
    out = io.BytesIO()
    segment.export(out, format="wav")
    out.seek(0)
    return out
def combine_audio_segments(segments, crossfade_duration=0.16, sr=24000):
    """Smoothly combine audio segments using crossfade transitions.

    Each pair of consecutive segments is overlapped by ``crossfade_duration``
    seconds. The tail of the audio combined so far is faded out with the falling
    half of a Hann window, the head of the next segment is faded in with the
    rising half, and the two are summed.

    Args:
        segments: sequence of numpy sample arrays (presumably 1-D), sliced and
            concatenated along the first axis.
        crossfade_duration: overlap length in seconds; 0 (or negative) means
            plain concatenation.
        sr: sample rate in Hz, used to convert ``crossfade_duration`` to samples.

    Returns:
        The combined array. A single segment is returned as-is (not copied), and
        an empty ``segments`` yields an empty float array.
    """
    window_length = int(sr * crossfade_duration)
    hanning_window = np.hanning(2 * window_length)
    combined_audio = None
    for segment in segments:
        if combined_audio is None:
            combined_audio = segment
            continue
        # Never overlap more samples than either side actually has. Fixed-length
        # slicing used to crash for short segments and for a zero crossfade
        # (where `x[:-0]` is empty).
        overlap_len = max(0, min(window_length, len(combined_audio), len(segment)))
        if overlap_len == 0:
            combined_audio = np.concatenate([combined_audio, segment])
            continue
        window = hanning_window if overlap_len == window_length else np.hanning(2 * overlap_len)
        overlap = (combined_audio[-overlap_len:] * window[overlap_len:]
                   + segment[:overlap_len] * window[:overlap_len])
        combined_audio = np.concatenate(
            [combined_audio[:-overlap_len], overlap, segment[overlap_len:]]
        )
    if combined_audio is None:
        return np.zeros(0)
    return combined_audio
================================================
FILE: tts/utils/audio_utils/plot.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
# Matplotlib colours for the f0 curves drawn by spec_to_figure, indexed by the
# curve's position in the `f0s` dict.
LINE_COLORS = ['w', 'r', 'orange', 'k', 'cyan', 'm', 'b', 'lime', 'g', 'brown', 'navy']
def spec_to_figure(spec, vmin=None, vmax=None, title='', f0s=None, dur_info=None, figsize=(12, 6)):
    """Plot a 2-D array (e.g. a mel spectrogram) as a matplotlib figure.

    ``spec`` is drawn transposed with ``pcolor``, so its first axis is the x axis.

    Args:
        spec: 2-D numpy array or torch tensor.
        vmin, vmax: colour limits passed to ``pcolor``.
        title: figure title.
        f0s: optional curve (array/tensor) or dict ``{label: curve}``. Curves are
            drawn on a twin y-axis limited to [0, 1000]; ``None`` entries are skipped.
        dur_info: optional dict with ``'txt'`` (labels) and ``'dur_gt'``
            (durations), plus optionally ``'dur_pred'``. Cumulative durations are
            drawn as vertical lines: ground truth in blue (lower half), prediction
            in red (upper half), each labelled with ``txt``.
        figsize: matplotlib figure size.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    def _as_numpy(value):
        return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

    spec = _as_numpy(spec)
    half = spec.shape[1] // 2
    fig = plt.figure(figsize=figsize)
    plt.title(title)
    plt.pcolor(spec.T, vmin=vmin, vmax=vmax)

    if dur_info is not None:
        assert isinstance(dur_info, dict)
        labels = dur_info['txt']
        gt_bounds = np.cumsum(_as_numpy(dur_info['dur_gt'])).astype(int)
        for idx, bound in enumerate(gt_bounds):
            plt.text(bound, (idx % 8 + 1) * 4, labels[idx])
            plt.vlines(bound, 0, half // 2, colors='b')  # blue is gt
        plt.xlim(0, gt_bounds[-1])
        if 'dur_pred' in dur_info:
            pred_bounds = np.cumsum(_as_numpy(dur_info['dur_pred'])).astype(int)
            for idx, bound in enumerate(pred_bounds):
                plt.text(bound, half + (idx % 8 + 1) * 4, labels[idx])
                plt.vlines(bound, half, half * 1.5, colors='r')  # red is pred
            plt.xlim(0, max(gt_bounds[-1], pred_bounds[-1]))

    if f0s is not None:
        twin_ax = plt.gca().twinx()
        if not isinstance(f0s, dict):
            f0s = {'f0': f0s}
        for idx, (label, curve) in enumerate(f0s.items()):
            if curve is None:
                continue
            curve = _as_numpy(curve)
            twin_ax.plot(np.arange(len(curve)) + 0.5, curve, label=label, c=LINE_COLORS[idx],
                         linewidth=1, alpha=0.5)
        twin_ax.set_ylim(0, 1000)
        twin_ax.legend()
    return fig
def align_to_figure(align, dur_info):
    """Plot a 2-D alignment matrix with optional ground-truth duration markers.

    ``align`` (numpy array or torch tensor) is drawn transposed with colour
    limits [0, 1]. When ``dur_info`` is given (dict with ``'txt'`` and
    ``'dur_gt'``), each cumulative duration, halved, is marked with a blue
    vertical line and the corresponding ``txt`` label in red at row ``i``.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    if isinstance(align, torch.Tensor):
        align = align.cpu().numpy()
    n_rows = align.shape[1]
    fig = plt.figure(figsize=(12, 6))
    plt.pcolor(align.T, vmin=0, vmax=1)
    if dur_info is None:
        return fig
    assert isinstance(dur_info, dict)
    labels = dur_info['txt']
    durations = dur_info['dur_gt']
    if isinstance(durations, torch.Tensor):
        durations = durations.cpu().numpy()
    # NOTE(review): the `// 2` presumably converts dur_gt frames to the
    # alignment's x resolution — confirm against the caller.
    boundaries = np.cumsum(durations).astype(int) // 2
    for row, x in enumerate(boundaries):
        plt.text(x, row, labels[row], color='red')
        plt.vlines(x, 0, n_rows, colors='b')  # blue is gt
    return fig
================================================
FILE: tts/utils/commons/ckpt_utils.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import glob
import os
import re
import subprocess
import traceback
import torch
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
@contextlib.contextmanager
def dist_load(path):
    """Context manager yielding a path from which ``path`` can be loaded.

    ``path`` itself is yielded when torch.distributed is not initialised, when
    the world size is 1, or when ``path`` already resolves under ``/dev/shm``.
    Otherwise local rank 0 copies it to ``/dev/shm/<exp_name>/<basename>``, all
    ranks meet at a barrier and receive that copy's path, meet again on exit,
    and rank 0 then deletes the copy.
    """
    if not dist.is_initialized() or dist.get_world_size() == 1 or os.path.realpath(path).startswith('/dev/shm'):
        yield path
    else:
        import shlex
        from tts.utils.commons.hparams import hparams
        try:
            from tts.utils.commons.trainer import LOCAL_RANK
        except ImportError:
            # `tts.utils.commons.trainer` is not shipped with this package; fall
            # back to the LOCAL_RANK environment variable set by torch launchers.
            LOCAL_RANK = int(os.environ.get('LOCAL_RANK', 0))
        tmpdir = '/dev/shm'
        assert len(os.path.basename(path)) > 0
        shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
        if LOCAL_RANK == 0:
            # Quote every path so names with spaces/shell metacharacters work.
            subprocess.check_call(
                f'mkdir -p {shlex.quote(os.path.dirname(shm_ckpt_path))}; '
                f'cp -Lr {shlex.quote(path)} {shlex.quote(shm_ckpt_path)}', shell=True)
        dist.barrier()
        yield shm_ckpt_path
        dist.barrier()
        if LOCAL_RANK == 0:
            subprocess.check_call(f'rm -rf {shlex.quote(shm_ckpt_path)}', shell=True)
def torch_load_dist(path, map_location='cpu'):
    """``torch.load`` a checkpoint, staging it through ``dist_load``.

    The ``dist_load`` context (including its exit barrier/cleanup) is left
    before the loaded object is returned.
    """
    with dist_load(path) as local_path:
        return torch.load(local_path, map_location=map_location)
def get_last_checkpoint(work_dir, steps=None):
    """Load the highest-step checkpoint found by ``get_all_ckpts(work_dir, steps)``.

    Returns:
        ``(checkpoint, path)``, or ``(None, None)`` when no checkpoint matches.
    """
    candidates = get_all_ckpts(work_dir, steps)
    if not candidates:
        return None, None
    newest = candidates[0]
    return torch_load_dist(newest, map_location='cpu'), newest
def get_all_ckpts(work_dir, steps=None):
    """List ``model_ckpt_steps_<N>.ckpt`` files in ``work_dir``, highest N first.

    With ``steps`` None or 0, every such file is returned. Otherwise only the
    file for exactly that step is returned (empty list if it does not exist).
    """
    # Escape work_dir so directories containing `[`, `*` or `?` match literally.
    base = glob.escape(work_dir)
    if steps is None or steps == 0:
        ckpt_path_pattern = f'{base}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{base}/model_ckpt_steps_{steps}.ckpt'
    # Raw string: the former non-raw '\_' / '\d' escapes are invalid string
    # escapes (SyntaxWarning on Python 3.12+).
    step_re = re.compile(r'.*steps_(\d+)\.ckpt')
    return sorted(glob.glob(ckpt_path_pattern),
                  key=lambda x: -int(step_re.findall(x)[0]))
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
              silent=False, load_opt=False, opts=None, steps=None, checkpoint=None, ckpt_path='', delete_unmatch=True):
    """Load model (and optionally optimizer) weights from a training checkpoint.

    Args:
        cur_model: an ``nn.Module`` (``DistributedDataParallel`` is unwrapped),
            or a list of them parallel to ``model_name``.
        ckpt_base_dir: a checkpoint file, or a directory. For a directory,
            ``model_only_last.ckpt`` is tried first (unless ``load_opt``), then
            the newest ``model_ckpt_steps_*.ckpt``.
        model_name: key(s) into ``checkpoint["state_dict"]``. A dotted name
            ``"a.b"`` selects the ``b.``-prefixed entries of ``state_dict["a"]``
            with that prefix removed.
        force: if True, a missing checkpoint raises AssertionError; otherwise a
            message is printed.
        strict: forwarded to ``load_state_dict``.
        silent: suppress the "loaded ..." log lines.
        load_opt: also restore ``checkpoint['optimizer_states']`` into ``opts``.
        opts: optimizers, parallel to ``checkpoint['optimizer_states']``.
        steps: training step to select when searching a directory.
        checkpoint: an already-loaded checkpoint dict (skips the file lookup).
        ckpt_path: path shown in log messages when ``checkpoint`` is given.
        delete_unmatch: with ``strict=False``, first try a strict load and, if
            it fails, drop checkpoint tensors whose shape differs from the model's.

    Returns:
        ``checkpoint.get('global_step', 0)`` on success. Returns None when no
        checkpoint is found and ``force`` is False, or when an entry of ``opts``
        is None.
    """
    if checkpoint is None:
        if os.path.isfile(ckpt_base_dir):
            base_dir = os.path.dirname(ckpt_base_dir)
            ckpt_path = ckpt_base_dir
            checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
        else:
            base_dir = ckpt_base_dir
            if load_opt:
                checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
            else:
                ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
                if os.path.exists(ckpt_path):
                    checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
                else:
                    checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
    if checkpoint is not None:
        # Strip DDP ('module.') and torch.compile ('_orig_mod.') wrappers.
        # NOTE(review): str.replace removes these substrings anywhere in a key
        # (e.g. 'submodule.x' -> 'subx'); confirm no real key contains them.
        state_dict_all = {
            k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()}
        if not isinstance(cur_model, list):
            cur_models = [cur_model]
            model_names = [model_name]
        else:
            cur_models = cur_model
            model_names = model_name
        for model_name, cur_model in zip(model_names, cur_models):
            if isinstance(cur_model, DistributedDataParallel):
                cur_model = cur_model.module
            device = next(cur_model.parameters()).device
            if '.' not in model_name:
                state_dict = state_dict_all[model_name]
            else:
                base_model_name = model_name.split('.')[0]
                rest_model_name = model_name[len(base_model_name) + 1:]
                state_dict = {
                    k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items()
                    if k.startswith(f'{rest_model_name}.')}
            state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
            if not strict and delete_unmatch:
                try:
                    cur_model.load_state_dict(state_dict, strict=True)
                    if not silent:
                        print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.")
                except RuntimeError:
                    # Strict load failed (missing/unexpected keys or size
                    # mismatch). Drop tensors whose shape differs from the
                    # model's, then fall through to the non-strict load below.
                    # (Previously a bare `except:` also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    cur_model_state_dict = cur_model.state_dict()
                    cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in
                                            cur_model_state_dict.items()}
                    unmatched_keys = []
                    for key, param in state_dict.items():
                        if key in cur_model_state_dict:
                            new_param = cur_model_state_dict[key]
                            if new_param.shape != param.shape:
                                unmatched_keys.append(key)
                                print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
                                      "ckpt model: ", param.shape)
                    for key in unmatched_keys:
                        del state_dict[key]
            load_results = cur_model.load_state_dict(state_dict, strict=strict)
            cur_model.to(device)
            if not silent:
                print(f"| loaded '{model_name}' from '{ckpt_path}'.")
                missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys
                print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
        if load_opt:
            optimizer_states = checkpoint['optimizer_states']
            assert len(opts) == len(optimizer_states)
            for optimizer, opt_state in zip(opts, optimizer_states):
                opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
                if optimizer is None:
                    # NOTE(review): returns None immediately, skipping later
                    # optimizers and the global_step return; `continue` may
                    # have been intended.
                    return
                try:
                    optimizer.load_state_dict(opt_state)
                    # Move optimizer state tensors to the (last) model's device.
                    for state in optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.to(device)
                except ValueError:
                    print(f"| WARMING: optimizer {optimizer} parameters not match !!!")
        return checkpoint.get('global_step', 0)
    else:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            assert False, e_msg
        else:
            print(e_msg)
def load_with_size_mismatch(model, state_dict, prefix=""):
    """Load ``state_dict`` into ``model``, skipping tensors whose size differs.

    ``prefix`` is removed from the start of each checkpoint key before the key
    is matched against ``model.state_dict()``. Keys unknown to the model are
    ignored, keys with a different tensor size are reported and skipped, and
    the remainder is loaded with ``strict=False``. Missing/unexpected keys are
    printed.
    """
    def _strip(key):
        # Remove only a *leading* prefix; `key.replace(prefix, "")` also deleted
        # the prefix text wherever it occurred inside a key.
        return key[len(prefix):] if prefix and key.startswith(prefix) else key

    current_model_dict = model.state_dict()
    mismatch_keys = set()
    new_state_dict = {}
    for k, v in state_dict.items():
        name = _strip(k)
        if name not in current_model_dict:
            continue
        if v.size() == current_model_dict[name].size():
            new_state_dict[name] = v
        else:
            mismatch_keys.add(name)
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    print("| mismatch keys: ", mismatch_keys)
    if len(missing_keys) > 0:
        print(f"| missing_keys in dit: {missing_keys}")
    if len(unexpected_keys) > 0:
        print(f"| unexpected_keys in dit: {unexpected_keys}")
================================================
FILE: tts/utils/commons/hparams.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
import ast
import yaml
# While True, the next set_hparams() call with print_hparams and global_hparams
# enabled dumps the merged config as JSON and then sets this flag to False.
global_print_hparams = True
# Process-wide hyper-parameter dict; set_hparams(global_hparams=True) clears it
# and refills it in place.
hparams = {}
class Args:
    """Minimal attribute container standing in for an ``argparse.Namespace``.

    Every keyword argument passed to the constructor becomes an attribute.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)
def override_config(old_config: dict, new_config: dict):
    """Recursively merge ``new_config`` into ``old_config`` in place.

    Dicts present on both sides are merged key by key; every other value from
    ``new_config`` replaces the old one. If ``new_config`` has a truthy
    ``'__replace'`` entry, ``old_config`` is cleared first, so the new dict
    replaces it wholesale. The ``'__replace'`` entry itself is copied like any
    other key.
    """
    if new_config.get('__replace', False):
        old_config.clear()
    for k, v in new_config.items():
        # Recurse only when both sides are dicts. A dict overriding a scalar or
        # None (e.g. `a: null` in a base config) used to crash with TypeError.
        if isinstance(v, dict) and isinstance(old_config.get(k), dict):
            override_config(old_config[k], v)
        else:
            old_config[k] = v
def traverse_dict(d, func, ctx):
    """Replace every non-dict leaf ``v`` of the nested dict ``d`` with ``func(v, ctx)``.

    Nested dicts are walked recursively; ``d`` is modified in place.
    """
    for key in list(d):
        value = d[key]
        if not isinstance(value, dict):
            d[key] = func(value, ctx)
            continue
        traverse_dict(value, func, ctx)
def _eval_config_expr(expression, context):
    """Evaluate a ``${...}`` config expression without using ``eval``.

    Supported: literals, names looked up in ``context``, arithmetic and unary
    operators, subscripts, and list/tuple/dict displays. Anything else (calls,
    attribute access, comprehensions, ...) raises ValueError; an unknown name
    raises NameError.
    """
    import operator
    bin_ops = {
        ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul,
        ast.Div: operator.truediv, ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod, ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg, ast.Not: operator.not_}
    index_cls = getattr(ast, 'Index', None)  # wrapper node on Python < 3.9 only

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.Name):
            if node.id in context:
                return context[node.id]
            raise NameError(f"name '{node.id}' is not defined in config")
        if isinstance(node, ast.BinOp) and type(node.op) in bin_ops:
            return bin_ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_eval(node.operand))
        if isinstance(node, (ast.List, ast.Tuple)):
            items = [_eval(elt) for elt in node.elts]
            return items if isinstance(node, ast.List) else tuple(items)
        if isinstance(node, ast.Dict) and None not in node.keys:
            return {_eval(k): _eval(v) for k, v in zip(node.keys, node.values)}
        if isinstance(node, ast.Subscript):
            index = node.slice
            if index_cls is not None and isinstance(index, index_cls):
                index = index.value
            return _eval(node.value)[_eval(index)]
        raise ValueError(f"unsupported expression in config: {type(node).__name__}")

    return _eval(ast.parse(expression.strip(), mode='eval'))


def parse_config(v, context=None):
    """Resolve special string values in a config tree.

    - ``'^path'``: replaced by the config loaded from ``path`` via ``load_config``.
    - ``'${expr}'``: replaced by the value of ``expr``. Bare names refer to keys
      of ``context``, e.g. ``'${hidden_size * 4}'``. Only literals, names,
      arithmetic, subscripts and list/tuple/dict displays are allowed.

    Any other value is returned unchanged.
    """
    if context is None:
        context = {}
    if isinstance(v, str):
        if v.startswith('^'):
            return load_config(v[1:], [], set())
        match = re.match(r"\${(.*)}", v)
        if match:
            expression = match.group(1)
            # `ast.literal_eval` takes a single argument, so the previous
            # `ast.literal_eval(expression, {}, context)` always raised
            # TypeError. Use a restricted evaluator that can see `context`.
            return _eval_config_expr(expression, context)
    return v
def remove_meta_key(d):
    """Recursively delete meta entries (keys starting with ``'__'``) from ``d`` in place.

    Only entries holding non-dict values are removed. A ``'__'`` key whose value
    is a dict is kept and recursed into.
    """
    for key in list(d):
        value = d[key]
        if isinstance(value, dict):
            remove_meta_key(value)
        elif key[:2] == '__':
            del d[key]
def load_config(config_fn, config_chains, loaded_configs):
    """Load a YAML config, resolving its ``base_config`` inheritance depth-first.

    Each entry of ``base_config`` (a path or list of paths; paths starting with
    ``'.'`` are relative to ``config_fn``'s directory) is loaded first, unless
    it is already in ``loaded_configs``. The results are merged with
    ``override_config``, and this file's own values are applied last.

    Args:
        config_fn: path of the YAML file.
        config_chains: list to which each loaded file path is appended, in
            load order (bases before children).
        loaded_configs: set of paths already visited; updated in place.

    Returns:
        The merged config dict, or ``{}`` (with a warning) if ``config_fn``
        does not exist.
    """
    if not os.path.exists(config_fn):
        print(f"| WARN: {config_fn} not exist.", )
        return {}
    # Explicit UTF-8: the platform default encoding (e.g. GBK/cp1252 on
    # Windows) fails on configs containing non-ASCII text.
    with open(config_fn, encoding='utf-8') as f:
        # An empty YAML file parses to None; treat it as an empty config.
        hparams_ = yaml.safe_load(f) or {}
    loaded_configs.add(config_fn)
    if 'base_config' in hparams_:
        ret_hparams = {}
        if not isinstance(hparams_['base_config'], list):
            hparams_['base_config'] = [hparams_['base_config']]
        for c in hparams_['base_config']:
            if c.startswith('.'):
                c = f'{os.path.dirname(config_fn)}/{c}'
                c = os.path.normpath(c)
            if c not in loaded_configs:
                override_config(ret_hparams, load_config(c, config_chains, loaded_configs))
        override_config(ret_hparams, hparams_)
    else:
        ret_hparams = hparams_
    config_chains.append(config_fn)
    return ret_hparams
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
    """Build the hyper-parameter dict from config files and overrides.

    When both ``config`` and ``exp_name`` are empty, the arguments are parsed
    from the command line instead (``--config``, ``--exp_name``,
    ``-hp/--hparams``, ``--infer``, ...). Unknown arguments are ignored.

    Sources, later ones winning:
      1. ``config`` together with its ``base_config`` chain (``load_config``);
      2. ``<exp_name>/config.yaml``, if present and ``--reset`` is not set
         (shallow update);
      3. ``hparams_str`` overrides such as ``"a=1,b.c=2,d=[1 1 1]"``.
    ``^path`` / ``${expr}`` values are resolved via ``parse_config`` before the
    step-3 overrides. Keys starting with ``'__'`` are removed at the end.

    Args:
        config: path of a YAML config file.
        exp_name: experiment directory; becomes ``work_dir`` and is created
            unless running in inference mode.
        hparams_str: comma-separated ``key=value`` overrides; dotted keys
            address nested dicts.
        print_hparams: print diagnostic messages and (once per process) the
            final config.
        global_hparams: also replace the contents of the module-level
            ``hparams`` dict with the result.

    Returns:
        The merged hyper-parameter dict.
    """
    if config == '' and exp_name == '':
        parser = argparse.ArgumentParser(description='')
        parser.add_argument('--config', type=str, default='',
                            help='location of the data corpus')
        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
        parser.add_argument('-hp', '--hparams', type=str, default='',
                            help='location of the data corpus')
        parser.add_argument('--infer', action='store_true', help='infer')
        parser.add_argument('--validate', action='store_true', help='validate')
        parser.add_argument('--reset', action='store_true', help='reset hparams')
        parser.add_argument('--remove', action='store_true', help='remove old ckpt')
        parser.add_argument('--debug', action='store_true', help='debug')
        parser.add_argument('--start_rank', type=int, default=-1,
                            help='the start rank id for DDP, keep 0 when single-machine multi-GPU')
        parser.add_argument('--world_size', type=int, default=-1,
                            help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU')
        parser.add_argument('--init_method', type=str, default='tcp', help='method to init ddp, use tcp or file')
        parser.add_argument('--master_addr', type=str, default='', help='')
        parser.add_argument('--ddp_dir', type=str, default='', help='')
        args, unknown = parser.parse_known_args()
        if print_hparams:
            print("| set_hparams Unknow hparams: ", unknown)
    else:
        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
                    infer=False, validate=False, reset=False, debug=False, remove=False,
                    start_rank=-1, world_size=-1, init_method='tcp', ddp_dir='', master_addr='')
    global hparams
    assert args.config != '' or args.exp_name != ''
    if args.config != '':
        assert os.path.exists(args.config), f"{args.config} not exists"

    # Config previously saved in the experiment directory, if any.
    saved_hparams = {}
    args_work_dir = ''
    if args.exp_name != '':
        args_work_dir = f'{args.exp_name}'
        ckpt_config_path = f'{args_work_dir}/config.yaml'
        if os.path.exists(ckpt_config_path):
            # Explicit UTF-8 so configs with non-ASCII text also load on Windows.
            with open(ckpt_config_path, encoding='utf-8') as f:
                saved_hparams_ = yaml.safe_load(f)
                if saved_hparams_ is not None:
                    saved_hparams.update(saved_hparams_)

    hparams_ = {}
    config_chains = []
    if args.config != '':
        hparams_.update(load_config(args.config, config_chains, set()))
        if len(config_chains) > 1 and print_hparams:
            print('| Hparams chains: ', config_chains)
    if not args.reset:
        hparams_.update(saved_hparams)
    traverse_dict(hparams_, parse_config, hparams_)
    hparams_['work_dir'] = args_work_dir

    # Support config overriding in command line. Support list type config overriding.
    # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
    if args.hparams != "":
        for new_hparam in args.hparams.split(","):
            # maxsplit=1 so that override values may themselves contain '='.
            k, v = new_hparam.split("=", 1)
            v = v.strip("\'\" ")
            config_node = hparams_
            for k_ in k.split(".")[:-1]:
                config_node = config_node[k_]
            k = k.split(".")[-1]
            if k in config_node:
                if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
                    if type(config_node[k]) == list:
                        # "[1 1 1]" -> "[1,1,1]"; '^' stands in for a double quote.
                        v = v.replace(" ", ",").replace('^', "\"")
                        if '|' in v:
                            # "a|b|c": list items cast to the type of the existing first item.
                            tp = type(config_node[k][0]) if len(config_node[k]) else str
                            config_node[k] = [tp(x) for x in v.split("|") if x != '']
                            continue
                    config_node[k] = ast.literal_eval(v)
                else:
                    config_node[k] = type(config_node[k])(v)
            else:
                # New key: keep the string unless it parses as float/int/bool.
                config_node[k] = v
                try:
                    config_node[k] = float(v)
                except ValueError:
                    pass
                try:
                    config_node[k] = int(v)
                except ValueError:
                    pass
                if v.lower() in ['false', 'true']:
                    config_node[k] = v.lower() == 'true'

    if args_work_dir != '' and not args.infer:
        os.makedirs(hparams_['work_dir'], exist_ok=True)

    hparams_['infer'] = args.infer
    hparams_['debug'] = args.debug
    hparams_['validate'] = args.validate
    hparams_['exp_name'] = args.exp_name
    hparams_['start_rank'] = args.start_rank  # useful for multi-machine training
    hparams_['world_size'] = args.world_size
    hparams_['init_method'] = args.init_method
    hparams_['ddp_dir'] = args.ddp_dir
    hparams_['master_addr'] = args.master_addr

    remove_meta_key(hparams_)
    global global_print_hparams
    if global_hparams:
        hparams.clear()
        hparams.update(hparams_)
    if print_hparams and global_print_hparams and global_hparams:
        print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True))
        global_print_hparams = False
    return hparams_
================================================
FILE: tts/utils/text_utils/dict.json
================================================
{"phone": ["C0a", "C0ai", "C0air", "C0an", "C0ang", "C0angr", "C0anr", "C0ao", "C0aor", "C0ar", "C0b", "C0c", "C0ch", "C0d", "C0e", "C0ei", "C0eir", "C0en", "C0eng", "C0engr", "C0enr", "C0er", "C0f", "C0g", "C0h", "C0i", "C0ia", "C0ian", "C0iang", "C0iangr", "C0ianr", "C0iao", "C0iaor", "C0iar", "C0ie", "C0ier", "C0ii", "C0iii", "C0iiir", "C0iir", "C0in", "C0ing", "C0ingr", "C0inr", "C0io", "C0iong", "C0iongr", "C0iou", "C0iour", "C0ir", "C0j", "C0k", "C0l", "C0m", "C0n", "C0ng", "C0o", "C0ong", "C0ongr", "C0or", "C0ou", "C0our", "C0p", "C0q", "C0r", "C0s", "C0sh", "C0t", "C0u", "C0ua", "C0uai", "C0uair", "C0uan", "C0uang", "C0uangr", "C0uanr", "C0uar", "C0uei", "C0ueir", "C0uen", "C0ueng", "C0uengr", "C0uenr", "C0uo", "C0uor", "C0ur", "C0v", "C0van", "C0vanr", "C0ve", "C0ver", "C0vn", "C0vnr", "C0vr", "C0x", "C0z", "C0zh", "C0_", "E0aa", "E0ae", "E0ah", "E0ao", "E0aw", "E0ax", "E0ay", "E0b", "E0ch", "E0d", "E0dh", "E0eh", "E0ehr", "E0er", "E0ey", "E0f", "E0g", "E0hh", "E0ih", "E0iy", "E0iyr", "E0jh", "E0k", "E0l", "E0m", "E0n", "E0ng", "E0oh", "E0ow", "E0oy", "E0p", "E0r", "E0s", "E0sh", "E0t", "E0th", "E0uh", "E0uw", "E0uwr", "E0v", "E0w", "E0y", "E0z", "E0zh", "sil", "…", "、", "。", "《", "》", "【", "】", "!", """, "#", "$", "%", "'", "''", "(", ")", "*", ",", ":", ";", "?", "\", "^", "_", "`", "{", "}", "~"], "tone": ["0", "1", "10", "11", "12", "13", "15", "17", "2", "3", "4", "5", "6", "7", "8", "9"], "wordCategory": ["0", "B", "E", "M", "S"], "prosody": ["0", "1", "2", "3", "4"], "focus": ["0", "1"], "intonation": ["0", "1", "2"], "phraseAccent": ["0", "H-", "L-"], "boundaryTone": ["0", "H%", "L%"], "accentType": ["!H*", "0", "H*", "L*", "L*+H", "L+H*"]}
================================================
FILE: tts/utils/text_utils/ph_tone_convert.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
def map_phone_to_tokendict(item, pad_bos_eos=True):
    """Merge Chinese phone ids and tone ids into one token-id sequence.

    Per the original notes, the phone dict ends at 173 (146~173 are
    punctuation) and Chinese phones occupy ids 3..100. Each Chinese phone is
    mapped to ``(phone - 3) * 6 + 200 + tone`` (range 201..788), where the
    tone_dict ids 4, 11, 12, 13, 14, 15 are first renumbered to tones 1..6.
    All other ids pass through unchanged. The input tensors are not modified.

    Args:
        item: dict with ``'txt_token'`` and ``'tone'`` integer tensors of equal shape.
        pad_bos_eos: prepend 798 and append 799 along the last dimension.

    Returns:
        The merged token tensor.
    """
    tone_id_to_number = {4: 1, 11: 2, 12: 3, 13: 4, 14: 5, 15: 6}

    phone = item['txt_token']
    merged_phone = phone.clone()
    raw_tone = item['tone']
    tone = raw_tone.clone()
    # Masks come from the untouched raw tones; none of the new values (1..6)
    # is a later source id, so this matches sequential in-place renumbering.
    for tone_id, tone_number in tone_id_to_number.items():
        tone[raw_tone == tone_id] = tone_number

    is_chinese = (phone >= 3) & (phone <= 100)
    merged_phone[is_chinese] = (merged_phone[is_chinese] - 3) * 6 + 200 + tone[is_chinese]
    if pad_bos_eos:
        merged_phone = F.pad(merged_phone, (1, 0), mode='constant', value=798)  # BOS
        merged_phone = F.pad(merged_phone, (0, 1), mode='constant', value=799)  # EOS
    return merged_phone
def split_ph_timestamp(ph_timestamp):
    """Split an interleaved [token, time, token, time, ...] sequence of shape [T].

    Even positions hold token ids, odd positions hold cumulative timestamps.
    Values >= 800 are first shifted down by 800. NOTE: this subtraction is done
    in place on ``ph_timestamp``.

    Merged Chinese tokens (200..788) are decoded back to phone id
    ``(t - 201) // 6 + 3`` and tone id (tone 1 -> 4, tones 2..6 -> 11..15). Any
    other token is kept as the phone id with tone 3. Durations are differences
    between consecutive timestamps, starting from 0.

    Returns:
        (ph_seq, tone_seq, dur_seq, last): three LongTensors plus the last
        element of the shifted input.
    """
    def _decode(token):
        if 200 <= token <= 788:
            offset = token - 200 - 1
            tone = offset % 6 + 1
            return offset // 6 + 3, (4 if tone == 1 else tone + 9)
        return token, 3

    ph_timestamp[ph_timestamp >= 800] -= 800
    phones, tones, durs = [], [], []
    prev_time = 0
    for pos, value in enumerate(ph_timestamp):
        if pos % 2:
            durs.append(value - prev_time)
            prev_time = value
        else:
            ph, tone = _decode(value)
            phones.append(ph)
            tones.append(tone)
    assert len(phones) == len(durs), f"{len(phones)}, {len(durs)}"
    return torch.LongTensor(phones), torch.LongTensor(tones), torch.LongTensor(durs), ph_timestamp[-1]
def split_ph(ph_seq):
    """Decode a merged token sequence (no timestamps) into phone and tone ids.

    Tokens 200..788 are decoded to phone id ``(t - 201) // 6 + 3`` and tone id
    (tone 1 -> 4, tones 2..6 -> 11..15). Any other token is kept as-is with tone 3.

    Returns:
        (ph_seq, tone_seq) as LongTensors.
    """
    phones, tones = [], []
    for token in ph_seq:
        if not (200 <= token <= 788):
            phones.append(token)
            tones.append(3)  # non-Chinese tokens get tone id 3
            continue
        offset = token - 200 - 1
        tone = offset % 6 + 1
        phones.append(offset // 6 + 3)
        tones.append(4 if tone == 1 else tone + 9)
    assert len(phones) == len(tones)
    return torch.LongTensor(phones), torch.LongTensor(tones)
================================================
FILE: tts/utils/text_utils/split_text.py
================================================
# -*- coding: utf-8 -*-
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def chunk_text_chinese(text, limit=60):
    """Split Chinese text into chunks of roughly ``limit`` CJK characters.

    Characters accumulate until ``limit`` CJK ideographs (U+4E00..U+9FFF) have
    been collected. The chunk is then cut after the last punctuation mark it
    already contains. If it contains none, it is extended up to the next
    punctuation mark in the remaining text. Chunks are only ever cut after
    punctuation, so a chunk can exceed ``limit``, and text with no further
    punctuation ends up in the final chunk.

    Returns:
        List of chunk strings (empty for empty input).
    """
    # CJK unified ideographs
    chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
    # Punctuation marks after which a chunk may end
    punctuation = r",。!?;:,.!?;:"
    result = []  # finished chunks
    current_chunk = []  # characters of the chunk being built
    chinese_count = 0  # CJK characters in current_chunk
    i = 0
    while i < len(text):
        char = text[i]
        current_chunk.append(char)
        if chinese_pattern.match(char):
            chinese_count += 1
        if chinese_count >= limit:
            # Look backwards for the most recent punctuation mark in the chunk.
            for j in range(len(current_chunk) - 1, -1, -1):
                if current_chunk[j] in punctuation:
                    result.append(''.join(current_chunk[:j + 1]))
                    current_chunk = current_chunk[j + 1:]
                    chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c))
                    break
            else:
                # No punctuation in the chunk: extend it to the next punctuation mark ahead.
                for k in range(i + 1, len(text)):
                    if text[k] in punctuation:
                        result.append(''.join(current_chunk) + text[i + 1:k + 1])
                        current_chunk = []
                        chinese_count = 0
                        i = k
                        break
                else:
                    # No punctuation anywhere in the rest of the text, so no
                    # later cut is possible. Take the remainder as one chunk
                    # instead of rescanning it for every remaining character,
                    # which made long unpunctuated input quadratic.
                    current_chunk.extend(text[i + 1:])
                    break
        i += 1
    # Flush whatever is left.
    if current_chunk:
        result.append(''.join(current_chunk))
    return result
def chunk_text_english(text, max_chars=130):
    """
    Split text into chunks of roughly at most ``max_chars`` UTF-8 bytes.

    The text is first cut into sentences at two kinds of boundary:

    - after ASCII ``; : , . ! ?`` when followed by whitespace (the
      whitespace is dropped);
    - directly after full-width CJK punctuation (semicolon, colon, comma,
      full stop, exclamation, question), with no whitespace required.

    Sentences are then greedily packed into chunks. A sentence whose last
    character is a single UTF-8 byte is followed by a space when packed.
    The size check compares UTF-8 byte lengths and does not count that
    added space. A single sentence longer than ``max_chars`` becomes its
    own, oversized chunk.

    Args:
        text (str): The text to be split.
        max_chars (int): Soft limit on the UTF-8 byte length of a chunk.

    Returns:
        List[str]: Chunks with surrounding whitespace stripped.
    """
    chunks = []
    current_chunk = ""
    # Full-width marks are written as escapes (U+FF1B, U+FF1A, U+FF0C,
    # U+3002, U+FF01, U+FF1F). Splitting right after an ASCII mark that is
    # not followed by whitespace would tear apart tokens such as "1,000"
    # or "10:30", so the zero-width branch is limited to CJK marks.
    sentences = re.split(
        r"(?<=[;:,.!?])\s+|(?<=[\uff1b\uff1a\uff0c\u3002\uff01\uff1f])", text
    )
    for sentence in sentences:
        # Add a separating space only after single-byte (ASCII) endings.
        piece = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
        if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
            current_chunk += piece
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = piece
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
def chunk_text_chinesev2(text, limit=60, look_ahead_limit=30):
    """
    Split Chinese text into chunks, preferring to end each chunk at a
    sentence mark, then at a clause mark, then at an enumeration comma /
    ellipsis / dash / tilde, and only cutting at an arbitrary position as a
    last resort.

    Punctuation conversion: before chunking, half-width ``. , ! ? ; :`` are
    converted to their full-width Chinese forms. The exception is a mark
    that sits between two digits (``2.5``, ``1,000``, ``3:00``). Such a mark
    is kept as-is and is never used as a break point.

    Scanning: the text is scanned left to right. Once the current chunk
    holds ``limit`` CJK ideographs (U+4E00..U+9FFF), each mark tier is tried
    in priority order. Within a tier, the next ``look_ahead_limit - 1``
    characters are searched forward first, then the current chunk backward.
    If no tier yields a mark, the chunk is cut right after the current
    character.

    Args:
        text: Text to split.
        limit: Number of CJK ideographs in a chunk that triggers a break.
        look_ahead_limit: Bounds the forward search to the next
            ``look_ahead_limit - 1`` characters.

    Returns:
        List of chunks. Joined together they equal the punctuation-converted
        text.
    """
    # Match CJK ideographs.
    chinese_pattern = re.compile(r'[\u4e00-\u9fff]')

    # Half-width -> full-width punctuation. Values are escapes so that the
    # map cannot degrade into an identity map if this file is normalised.
    punctuation_map = {
        '.': '\u3002',  # ideographic full stop
        ',': '\uff0c',  # full-width comma
        '!': '\uff01',  # full-width exclamation mark
        '?': '\uff1f',  # full-width question mark
        ';': '\uff1b',  # full-width semicolon
        ':': '\uff1a',  # full-width colon
    }
    # Any mapped mark, unless it is both preceded and followed by a digit.
    convertible = re.compile(r'(?<!\d)[.,!?;:]|[.,!?;:](?!\d)')
    # Converting up front gives the same break decisions as converting each
    # chunk afterwards, because each half-width mark sits in the same tier
    # as its full-width form. It also gives the digit check full context.
    text = convertible.sub(lambda m: punctuation_map[m.group()], text)

    # Break-mark tiers, highest priority first. Any half-width . , ! ? ; :
    # still present is digit-sandwiched, so none of them are listed.
    tiers = (
        '\u3002\uff01\uff1f',          # sentence end: full stop, exclamation, question
        '\uff0c\uff1b\uff1a',          # clause: comma, semicolon, colon
        '\u3001\u2026\u2014-\uff5e~',  # enumeration comma, ellipsis, dashes, tildes
    )

    result = []
    n = len(text)
    start = 0          # index where the current chunk begins
    chinese_count = 0  # CJK ideographs in text[start:i + 1]
    i = 0
    while i < n:
        if chinese_pattern.match(text[i]):
            chinese_count += 1
        if chinese_count >= limit:
            for marks in tiers:
                # 1) Forward: the next look_ahead_limit - 1 characters.
                ahead = next((p for p in range(i + 1, min(i + look_ahead_limit, n))
                              if text[p] in marks), None)
                if ahead is not None:
                    result.append(text[start:ahead + 1])
                    start = ahead + 1
                    chinese_count = 0
                    i = ahead
                    break
                # 2) Backward: the latest mark already inside the chunk.
                behind = next((p for p in range(i, start - 1, -1)
                               if text[p] in marks), None)
                if behind is not None:
                    result.append(text[start:behind + 1])
                    start = behind + 1
                    chinese_count = sum(1 for c in text[start:i + 1]
                                        if chinese_pattern.match(c))
                    break
            else:
                # Last resort: no mark of any tier, cut right after text[i].
                result.append(text[start:i + 1])
                start = i + 1
                chinese_count = 0
        i += 1

    # Append whatever is left over.
    if start < n:
        result.append(text[start:])
    return result
if __name__ == '__main__':
    # Manual smoke test: run each chunker on a sample and print the chunks.
    sample_cn = "哇塞!家人们,你们太好运了。我居然发现了一个宝藏零食大礼包,简直适合所有人的口味!有香辣的,让你舌尖跳舞;有盐焗的,咸香可口;还有五香的,香气四溢。就连怀孕的姐妹都吃得津津有味!整整三十包啊!什么手撕蟹柳、辣子鸡、嫩豆干、手撕素肉、鹌鹑蛋、小肉枣肠、猪肉腐、魔芋、魔芋丝等等,应有尽有。香辣土豆爽辣过瘾,各种素肉嚼劲十足,鹌鹑蛋营养美味,真的太多太多啦,...家人们,现在价格太划算了,赶紧下单。"
    sample_en = "Washington CNN When President Donald Trump declared in the House Chamber this week that executives at the nation’s top automakers were “so excited” about their prospects amid his new tariff regime, it did not entirely reflect the conversation he’d held with them earlier that day."
    long_cn = "欢迎收听《TED Talks Daily》,在这里,我们每天为您带来新思想,激发您的好奇心。我是您的主持人,Elise Hugh。当我们去看医生时,医生会评估我们的身体健康状况,检查我们的生命体征,可能还会关注我们的胆固醇水平,确保我们整体处于健康状态。医生可能还会通过一系列问题来检查我们的心理健康。然而,人际交往专家Casley Killam指出,我们在理解健康时忽略了一个关键指标,那就是我们的社会健康。在2024年的演讲中,她解释了为什么人际关系如此重要,以及忽视它可能带来的代价。几年前,我认识的一位女士,我们暂且称她为Maya,在短时间内经历了许多重大变化。她结婚了,和丈夫因工作搬到了一个陌生的城市,在那里她谁也不认识。她开始了一份在家办公的新工作,同时还要应对父亲新确诊的痴呆症。为了应对这些变化带来的压力,Maya加倍关注自己的身心健康。她几乎每天都锻炼,吃健康的食物,每周去看一次心理医生。这些措施确实有帮助,她的身体变得更加强壮,心理也更具韧性,但效果有限。她仍然感到困扰,经常在半夜失眠,白天感到注意力不集中,缺乏动力。Maya做了医生通常建议我们做的所有事情来保持身心健康,但似乎还缺少些什么。如果我告诉你,Maya所缺少的东西,也是全球数十亿人所缺少的,甚至可能也是你所缺少的呢?如果我告诉你,缺乏它会削弱我们为保持健康所做的其他努力,甚至可能缩短你的寿命呢?我研究这个问题已经超过十年,我发现,我们传统上对健康的理解是不完整的。通过将健康主要视为身体和心理的健康,我们忽略了我认为是我们这个时代最大的挑战和机遇——社会健康。身体健康关乎我们的身体,心理健康关乎我们的思想,而社会健康则关乎我们的人际关系。如果你以前没有听说过这个词,那是因为它还没有进入主流词汇,但它同样重要。Maya在她的新家还没有归属感。她不再亲自见到她的家人、朋友或同事,她经常一连几周只和丈夫共度时光。她的故事告诉我们,如果我们只照顾身体和心理,而不关注人际关系,我们就无法完全健康,无法真正茁壮成长。与Maya类似,全球有数亿人连续几周不与任何朋友或家人交谈。全球范围内,有四分之一的人感到孤独。20%的成年人觉得他们没有任何人可以求助。想想看,你遇到的每五个人中,可能有一个人觉得自己孤立无援。这不仅令人心碎,也是一场公共卫生危机。"
    print(chunk_text_chinese(sample_cn))
    print(chunk_text_english(sample_en))
    for piece in chunk_text_chinesev2(long_cn):
        print(piece)
================================================
FILE: tts/utils/text_utils/text_encoder.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import six
from six.moves import range # pylint: disable=redefined-builtin
# Reserved token strings.
# NOTE(review): PAD, EOS and UNK are all the empty string here. Consequences:
#   - RESERVED_TOKENS is ["", "", ""];
#   - PAD_ID == EOS_ID == UNK_ID == 0;
#   - in a vocabulary built with reserved tokens, all three map to one id.
# The tensor2tensor code this module derives from uses "<pad>", "<EOS>" and
# "<UNK>". These literals may have been lost (e.g. stripped as markup).
# Confirm against the vocabulary the model was trained with before changing
# them, because the values determine the token ids fed to the model.
PAD = ""
EOS = ""
UNK = ""
# Segment separator. TokenTextEncoder falls back to the EOS id when the
# vocabulary has no such token.
SEG = "|"
# Punctuation characters; not referenced inside this module (presumably
# imported by callers).
PUNCS = '!,.?;:'
RESERVED_TOKENS = [PAD, EOS, UNK]
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
PAD_ID = RESERVED_TOKENS.index(PAD)  # Intended 0
EOS_ID = RESERVED_TOKENS.index(EOS)  # Intended 1 (evaluates to 0 while EOS == PAD)
UNK_ID = RESERVED_TOKENS.index(UNK)  # Intended 2 (evaluates to 0 while UNK == PAD)
# Byte-string versions of the reserved tokens. On Python 3 only PAD and EOS
# are included (UNK is omitted).
if six.PY2:
    RESERVED_TOKENS_BYTES = RESERVED_TOKENS
else:
    RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
# Regular expression for unescaping token strings.
# '\u' is converted to '_'
# '\\' is converted to '\'
# '\213;' is converted to unichr(213)
# (The unescape routine that uses this regex is not part of this module.)
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
# Characters that have escape meaning in token strings; not referenced in
# this module.
_ESCAPE_CHARS = set(u"\\_u;0123456789")
def strip_ids(ids, ids_to_strip):
    """Return ``ids`` as a new list with any trailing run of ids that are
    members of ``ids_to_strip`` removed. Leading and interior occurrences
    are kept."""
    result = list(ids)
    end = len(result)
    # Walk back over the trailing run of strippable ids.
    while end > 0 and result[end - 1] in ids_to_strip:
        end -= 1
    del result[end:]
    return result
class TextEncoder(object):
    """Base class that maps between integer ids and human-readable strings.

    Ids below ``num_reserved_ids`` are reserved and render as entries of
    RESERVED_TOKENS. In this base implementation every other token is a
    decimal integer, shifted up by ``num_reserved_ids`` when encoding and
    back down when decoding.
    """

    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
        self._num_reserved_ids = num_reserved_ids

    @property
    def num_reserved_ids(self):
        return self._num_reserved_ids

    def encode(self, s):
        """Turn a whitespace-separated string of integers into shifted ids.

        Each token is parsed with ``int`` and offset by ``num_reserved_ids``,
        so results lie in ``[num_reserved_ids, vocab_size)``; ids below that
        are reserved. EOS is not appended.

        Args:
            s: human-readable string to be converted.

        Returns:
            ids: list of integers.
        """
        shift = self._num_reserved_ids
        return [int(token) + shift for token in s.split()]

    def decode(self, ids, strip_extraneous=False):
        """Turn a sequence of ids into a space-joined string.

        EOS is not expected in ids.

        Args:
            ids: list of integers to be converted.
            strip_extraneous: if true, trailing reserved ids (EOS, PAD, ...)
                are removed first.

        Returns:
            s: human-readable string.
        """
        if not strip_extraneous:
            return " ".join(self.decode_list(ids))
        reserved_ids = list(range(self._num_reserved_ids or 0))
        return " ".join(self.decode_list(strip_ids(ids, reserved_ids)))

    def decode_list(self, ids):
        """Render each id as a string, for visualising individual ids.

        Reserved ids become their RESERVED_TOKENS entry; all other ids are
        shifted back down by ``num_reserved_ids``.

        Args:
            ids: list of integers to be converted.

        Returns:
            strs: list of human-readable strings.
        """
        reserved = self._num_reserved_ids

        def render(id_):
            if 0 <= id_ < reserved:
                return RESERVED_TOKENS[int(id_)]
            return id_ - reserved

        return [str(render(id_)) for id_ in ids]

    @property
    def vocab_size(self):
        raise NotImplementedError()
class TokenTextEncoder(TextEncoder):
    """Encoder based on a user-supplied vocabulary (file or list)."""

    def __init__(self,
                 vocab_filename,
                 reverse=False,
                 vocab_list=None,
                 replace_oov=None,
                 num_reserved_ids=NUM_RESERVED_TOKENS):
        """Initialize from a file or list, one token per line.

        Handling of reserved tokens works as follows:
        - When initializing from a list, we add reserved tokens to the vocab.
        - When initializing from a file, we do not add reserved tokens to the vocab.
        - When saving vocab files, we save reserved tokens to the file.

        Args:
            vocab_filename: If not None, the full filename to read vocab from. If this
                is not None, then vocab_list should be None. The file is read as UTF-8.
            reverse: Boolean indicating if tokens should be reversed during encoding
                and decoding.
            vocab_list: If not None, a list of elements of the vocabulary. If this is
                not None, then vocab_filename should be None.
            replace_oov: If not None, every out-of-vocabulary token seen when
                encoding will be replaced by this string (which must be in vocab).
            num_reserved_ids: Number of IDs to save for reserved tokens such as
                PAD/EOS/UNK. It is only stored on the base class. A vocabulary
                built from a list always reserves len(RESERVED_TOKENS) ids.

        Raises:
            AssertionError: if neither vocab_filename nor vocab_list is given.
            KeyError: if the resulting vocabulary lacks PAD, EOS or UNK.
        """
        super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
        self._reverse = reverse
        self._replace_oov = replace_oov
        if vocab_filename:
            self._init_vocab_from_file(vocab_filename)
        else:
            assert vocab_list is not None
            self._init_vocab_from_list(vocab_list)
        self.pad_index = self.token_to_id[PAD]
        self.eos_index = self.token_to_id[EOS]
        self.unk_index = self.token_to_id[UNK]
        # Fall back to the EOS id when the vocab has no explicit segment token.
        self.seg_index = self.token_to_id[SEG] if SEG in self.token_to_id else self.eos_index

    def encode(self, s):
        """Convert a space-separated string of tokens (or a token list) to ids.

        When ``replace_oov`` is set, unknown tokens are replaced by it;
        otherwise an unknown token raises KeyError. The id list is reversed
        when ``reverse`` was set.
        """
        if isinstance(s, str):
            sentence = s
            tokens = sentence.strip().split()
        else:
            tokens = s
        if self._replace_oov is not None:
            tokens = [t if t in self.token_to_id else self._replace_oov
                      for t in tokens]
        ret = [self.token_to_id[tok] for tok in tokens]
        return ret[::-1] if self._reverse else ret

    def decode(self, ids, strip_eos=False, strip_padding=False):
        """Convert ids back to a space-separated token string.

        If ``strip_padding`` is set, everything from the first PAD id onward
        is dropped. Then, if ``strip_eos`` is set, everything from the first
        EOS id onward is dropped.
        """
        if strip_padding and self.pad() in list(ids):
            pad_pos = list(ids).index(self.pad())
            ids = ids[:pad_pos]
        if strip_eos and self.eos() in list(ids):
            eos_pos = list(ids).index(self.eos())
            ids = ids[:eos_pos]
        return " ".join(self.decode_list(ids))

    def decode_list(self, ids):
        """Map each id to its token (in reverse order if ``reverse`` was set)."""
        seq = reversed(ids) if self._reverse else ids
        return [self._safe_id_to_token(i) for i in seq]

    @property
    def vocab_size(self):
        return len(self.id_to_token)

    def __len__(self):
        return self.vocab_size

    def _safe_id_to_token(self, idx):
        """Token for ``idx``, or the placeholder "ID_<idx>" if unknown."""
        return self.id_to_token.get(idx, "ID_%d" % idx)

    def _init_vocab_from_file(self, filename):
        """Load vocab from a UTF-8 text file, one (stripped) token per line.

        Reserved tokens are not added. __init__ looks up PAD/EOS/UNK
        afterwards, so the file must already contain them.

        Args:
            filename: The file to load vocabulary from.
        """
        # An explicit encoding makes the result independent of the platform
        # locale (e.g. cp936/cp1252 on Windows) for non-ASCII tokens.
        with open(filename, encoding="utf-8") as f:
            tokens = [token.strip() for token in f]
        self._init_vocab(iter(tokens), add_reserved_tokens=False)

    def _init_vocab_from_list(self, vocab_list):
        """Initialize tokens from a list of tokens.

        It is ok if reserved tokens appear in the vocab list. They will be
        removed. The set of tokens in vocab_list should be unique.

        Args:
            vocab_list: A list of tokens.
        """
        def token_gen():
            for token in vocab_list:
                if token not in RESERVED_TOKENS:
                    yield token

        self._init_vocab(token_gen())

    def _init_vocab(self, token_generator, add_reserved_tokens=True):
        """Initialize vocabulary with tokens from token_generator.

        Reserved tokens, if added, take ids 0..len(RESERVED_TOKENS)-1 and the
        generated tokens follow. When two ids share a token string,
        token_to_id keeps the last one.
        """
        self.id_to_token = {}
        non_reserved_start_index = 0
        if add_reserved_tokens:
            self.id_to_token.update(enumerate(RESERVED_TOKENS))
            non_reserved_start_index = len(RESERVED_TOKENS)
        self.id_to_token.update(
            enumerate(token_generator, start=non_reserved_start_index))
        # token_to_id is the reverse of id_to_token.
        self.token_to_id = dict((v, k) for k, v in six.iteritems(self.id_to_token))

    def pad(self):
        return self.pad_index

    def eos(self):
        return self.eos_index

    def unk(self):
        return self.unk_index

    def seg(self):
        return self.seg_index

    def store_to_file(self, filename):
        """Write vocab file to disk (UTF-8).

        Vocab files have one token per line. The file ends in a newline.
        Reserved tokens are written to the vocab file as well.

        Args:
            filename: Full path of the file to store the vocab to.
        """
        # Same encoding as _init_vocab_from_file so files round-trip.
        with open(filename, "w", encoding="utf-8") as f:
            for i in range(len(self.id_to_token)):
                f.write(self.id_to_token[i] + "\n")

    def sil_phonemes(self):
        """All vocab tokens that is_sil_phoneme classifies as silence/punctuation."""
        return [p for p in self.id_to_token.values() if is_sil_phoneme(p)]
def build_token_encoder(token_list_file):
    """Build a TokenTextEncoder from a JSON file holding the token list.

    The JSON top level is presumably an array of token strings (it is
    passed as ``vocab_list``). Reserved tokens are prepended by the encoder.
    Out-of-vocabulary tokens are replaced by the module's UNK token when
    encoding.

    Args:
        token_list_file: Path to the JSON vocabulary file (read as UTF-8).

    Returns:
        TokenTextEncoder built from the list.
    """
    # JSON is UTF-8 by specification. The context manager closes the handle,
    # which the previous bare json.load(open(...)) leaked.
    with open(token_list_file, encoding="utf-8") as f:
        token_list = json.load(f)
    # UNK is always part of a list-built vocab, so OOV replacement can
    # never itself raise KeyError.
    return TokenTextEncoder(None, vocab_list=token_list, replace_oov=UNK)
def is_sil_phoneme(p):
    """Return True for tokens treated as silence/non-speech.

    These are: the empty string, anything whose first character is not
    alphabetic (punctuation, digits, symbols), and the literal markers
    'sil', 'sp' and 'XX'.
    """
    if p == '':
        # Checked first so that p[0] below is always valid.
        return True
    return (not p[0].isalpha()) or p in ('sil', 'sp', 'XX')
================================================
FILE: workflow-examples/单人语音.json
================================================
{"id":"f4285961-b2b3-477d-8027-364fea281241","revision":0,"last_node_id":27,"last_link_id":74,"nodes":[{"id":15,"type":"PreviewAudio","pos":[1494.0950927734375,-45.68428039550781],"size":[270,88],"flags":{},"order":3,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":73},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null}],"outputs":[],"properties":{"Node name for S&R":"PreviewAudio"},"widgets_values":[]},{"id":11,"type":"MegaTTS3SpeakersPreview","pos":[814.6781616210938,-40.99158477783203],"size":[315,78],"flags":{},"order":0,"mode":0,"inputs":[{"localized_name":"speaker","name":"speaker","type":"COMBO","widget":{"name":"speaker"},"link":null}],"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[71]},{"localized_name":"npy_file","name":"npy_file","type":"STRING","links":[]}],"properties":{"Node name for S&R":"MegaTTS3SpeakersPreview"},"widgets_values":["中文女-已知风格\\御姐配音.wav"]},{"id":13,"type":"MultiLinePromptMG","pos":[822.5565795898438,91.6104965209961],"size":[292.10406494140625,190.9091796875],"flags":{},"order":1,"mode":0,"inputs":[{"localized_name":"multi_line_prompt","name":"multi_line_prompt","type":"STRING","widget":{"name":"multi_line_prompt"},"link":null}],"outputs":[{"localized_name":"text","name":"text","type":"STRING","links":[72]}],"properties":{"Node name for S&R":"MultiLinePromptMG"},"widgets_values":["MegaTTS 真开源版本来了,效果666, 我爱你!I love you!“我爱你”的英语是“I love you”\n\n2.5平方电线,共465篇,约315万字, 2002年的第一场雪,下在了2003年\n\nHigh-quality voice cloning, supports Chinese and English, and can perform cross-lingual 
cloning\n\n亲爱的,你最好了,不要生气了嘛,人家真的不是故意的啦,原谅我吧,好不好嘛"]},{"id":19,"type":"SaveAudioMP3","pos":[1500.0428466796875,108.09651184082031],"size":[270,136],"flags":{},"order":4,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":74},{"localized_name":"filename_prefix","name":"filename_prefix","type":"STRING","widget":{"name":"filename_prefix"},"link":null},{"localized_name":"quality","name":"quality","type":"COMBO","widget":{"name":"quality"},"link":null},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null}],"outputs":[],"properties":{"Node name for S&R":"SaveAudioMP3"},"widgets_values":["单人","V0"]},{"id":27,"type":"MegaTTS3Run","pos":[1150.0458984375,-12.683919906616211],"size":[315,210],"flags":{},"order":2,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":71},{"localized_name":"text","name":"text","type":"STRING","link":72},{"localized_name":"dialogue_audio_s2","name":"dialogue_audio_s2","shape":7,"type":"AUDIO","link":null},{"localized_name":"audio_npy_file","name":"audio_npy_file","shape":7,"type":"STRING","link":null},{"localized_name":"audio_s2_npy_file","name":"audio_s2_npy_file","shape":7,"type":"STRING","link":null},{"localized_name":"time_step","name":"time_step","type":"INT","widget":{"name":"time_step"},"link":null},{"localized_name":"p_w","name":"p_w","type":"FLOAT","widget":{"name":"p_w"},"link":null},{"localized_name":"t_w","name":"t_w","type":"FLOAT","widget":{"name":"t_w"},"link":null},{"localized_name":"unload_model","name":"unload_model","type":"BOOLEAN","widget":{"name":"unload_model"},"link":null}],"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[73,74]}],"title":"Mega TTS3 Run","properties":{"Node name for 
S&R":"MegaTTS3Run"},"widgets_values":[32,1.6,2.5,false]}],"links":[[71,11,0,27,0,"AUDIO"],[72,13,0,27,1,"STRING"],[73,27,0,15,0,"AUDIO"],[74,27,0,19,0,"AUDIO"]],"groups":[],"config":{},"extra":{"ds":{"scale":1.1000000000000003,"offset":[-508.3255535295901,162.01651221912473]},"frontendVersion":"1.16.9","VHS_latentpreview":false,"VHS_latentpreviewrate":0,"VHS_MetadataImage":true,"VHS_KeepIntermediate":true},"version":0.4}
================================================
FILE: workflow-examples/双人会话.json
================================================
{"id":"eb4d7439-617f-4e18-beba-af841ebc9d1a","revision":0,"last_node_id":22,"last_link_id":54,"nodes":[{"id":11,"type":"MegaTTS3SpeakersPreview","pos":[812.1293334960938,-43.54039764404297],"size":[315,78],"flags":{},"order":0,"mode":0,"inputs":[{"localized_name":"speaker","name":"speaker","type":"COMBO","widget":{"name":"speaker"},"link":null}],"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[48]},{"localized_name":"npy_file","name":"npy_file","type":"STRING","links":[50]}],"properties":{"Node name for S&R":"MegaTTS3SpeakersPreview"},"widgets_values":["中文未分男女\\磁性-中音-中-00001.wav"]},{"id":17,"type":"MegaTTS3SpeakersPreview","pos":[810.3323974609375,284.3142395019531],"size":[315,78],"flags":{},"order":1,"mode":0,"inputs":[{"localized_name":"speaker","name":"speaker","type":"COMBO","widget":{"name":"speaker"},"link":null}],"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[53]},{"localized_name":"npy_file","name":"npy_file","type":"STRING","links":[54]}],"properties":{"Node name for S&R":"MegaTTS3SpeakersPreview"},"widgets_values":["中文女-已知风格\\御姐配音.wav"]},{"id":13,"type":"MultiLinePromptMG","pos":[821.70703125,91.6104965209961],"size":[287.00628662109375,133.1352081298828],"flags":{},"order":2,"mode":0,"inputs":[{"localized_name":"multi_line_prompt","name":"multi_line_prompt","type":"STRING","widget":{"name":"multi_line_prompt"},"link":null}],"outputs":[{"localized_name":"text","name":"text","type":"STRING","links":[49]}],"properties":{"Node name for S&R":"MultiLinePromptMG"},"widgets_values":["[S1] MegaTTS 真开源版本来了,效果666\n[S2] 晕 xuan4 是一种 gan3 觉\n[S1] 我爱你!I love you!“我爱你”的英语是“I love you”\n[S2] 2.5平方电线,共465篇,约315万字\n[S1] 
2002年的第一场雪,下在了2003年"]},{"id":15,"type":"PreviewAudio","pos":[1496.6439208984375,-0.654519259929657],"size":[270,88],"flags":{},"order":4,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":51},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null}],"outputs":[],"properties":{"Node name for S&R":"PreviewAudio"},"widgets_values":[]},{"id":19,"type":"SaveAudioMP3","pos":[1507.689453125,157.37432861328125],"size":[270,136],"flags":{},"order":5,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":52},{"localized_name":"filename_prefix","name":"filename_prefix","type":"STRING","widget":{"name":"filename_prefix"},"link":null},{"localized_name":"quality","name":"quality","type":"COMBO","widget":{"name":"quality"},"link":null},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null}],"outputs":[],"properties":{"Node name for S&R":"SaveAudioMP3"},"widgets_values":["双人对话","V0"]},{"id":22,"type":"MegaTTS3Run","pos":[1157.6923828125,-2.4885244369506836],"size":[315,210],"flags":{},"order":3,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":48},{"localized_name":"text","name":"text","type":"STRING","link":49},{"localized_name":"dialogue_audio_s2","name":"dialogue_audio_s2","shape":7,"type":"AUDIO","link":53},{"localized_name":"audio_npy_file","name":"audio_npy_file","shape":7,"type":"STRING","link":50},{"localized_name":"audio_s2_npy_file","name":"audio_s2_npy_file","shape":7,"type":"STRING","link":54},{"localized_name":"time_step","name":"time_step","type":"INT","widget":{"name":"time_step"},"link":null},{"localized_name":"p_w","name":"p_w","type":"FLOAT","widget":{"name":"p_w"},"link":null},{"localized_name":"t_w","name":"t_w","type":"FLOAT","widget":{"name":"t_w"},"link":null},{"localized_name":"unload_model","name":"unload_model","type":"BOOLEAN","widget":{"name":"unload_model"},"link":null}]
,"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[51,52]}],"title":"Mega TTS3 Run","properties":{"Node name for S&R":"MegaTTS3Run"},"widgets_values":[32,1.6,2.5,false]}],"links":[[48,11,0,22,0,"AUDIO"],[49,13,0,22,1,"STRING"],[50,11,1,22,3,"STRING"],[51,22,0,15,0,"AUDIO"],[52,22,0,19,0,"AUDIO"],[53,17,0,22,2,"AUDIO"],[54,17,1,22,4,"STRING"]],"groups":[],"config":{},"extra":{"ds":{"scale":1.1000000000000003,"offset":[-476.88972922206165,158.6180469758719]},"frontendVersion":"1.16.9","VHS_latentpreview":false,"VHS_latentpreviewrate":0,"VHS_MetadataImage":true,"VHS_KeepIntermediate":true},"version":0.4}