Full Code of billwuhao/ComfyUI_MegaTTS3 for AI

Repository: billwuhao/ComfyUI_MegaTTS3
Branch: main
Commit: bad7ccb58afc
Files: 40
Total size: 284.9 KB

Directory structure:
gitextract_pb8s4_gh/

├── .github/
│   └── workflows/
│       └── publish_action.yml
├── .gitignore
├── LICENSE
├── README-CN.md
├── README.md
├── __init__.py
├── megatts3node.py
├── pyproject.toml
├── requirements.txt
├── tts/
│   ├── frontend_function.py
│   ├── modules/
│   │   ├── aligner/
│   │   │   └── whisper_small.py
│   │   ├── ar_dur/
│   │   │   ├── ar_dur_predictor.py
│   │   │   └── commons/
│   │   │       ├── layers.py
│   │   │       ├── nar_tts_modules.py
│   │   │       ├── rel_transformer.py
│   │   │       ├── rot_transformer.py
│   │   │       ├── seq_utils.py
│   │   │       └── transformer.py
│   │   ├── llm_dit/
│   │   │   ├── cfm.py
│   │   │   ├── dit.py
│   │   │   ├── time_embedding.py
│   │   │   └── transformer.py
│   │   └── wavvae/
│   │       ├── decoder/
│   │       │   ├── diag_gaussian.py
│   │       │   ├── hifigan_modules.py
│   │       │   ├── seanet_encoder.py
│   │       │   └── wavvae_v3.py
│   │       └── encoder/
│   │           └── common_modules/
│   │               ├── conv.py
│   │               ├── lstm.py
│   │               └── seanet.py
│   └── utils/
│       ├── audio_utils/
│       │   ├── align.py
│       │   ├── io.py
│       │   └── plot.py
│       ├── commons/
│       │   ├── ckpt_utils.py
│       │   └── hparams.py
│       └── text_utils/
│           ├── dict.json
│           ├── ph_tone_convert.py
│           ├── split_text.py
│           └── text_encoder.py
└── workflow-examples/
    ├── 单人语音.json
    └── 双人会话.json

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/publish_action.yml
================================================
name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - master
      - main
    paths:
      - "pyproject.toml"

jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    steps:
      - name: Check out code
        uses: actions/checkout@v4
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@main
        with:
          ## Add your own personal access token to your GitHub repository secrets and reference it here.
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc


================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2025] ByteDance

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README-CN.md
================================================
[中文](README-CN.md) | [English](README.md) 

# ComfyUI 的 MegaTTS3 声音克隆节点

声音克隆质量非常高, 支持中英文, 并可跨语言克隆. **支持自定义音色!!! 超长文本!!! 双人对话!!! Windows 正常安装 pynini, 不再是阉割版 TTS!!!**.

## 📣 更新

[2025-06-07]⚒️: v2.0.0. **支持自定义音色, 支持超长文本, 支持双人对话, Windows 正常安装 pynini, 不再是阉割版 TTS!**.

```
[S1] MegaTTS 真开源版本来了,效果666
[S2] 晕 xuan4 是一种 gan3 觉
[S1] 我爱你!I love you!“我爱你”的英语是“I love you”
[S2] 2.5平方电线,共465篇,约315万字
[S1] 2002年的第一场雪,下在了2003年
```

https://github.com/user-attachments/assets/b734e6bd-9303-4311-b3a4-618241ca6535

[2025-04-28]⚒️: 新增预览音色节点, 先预览音色, 满意再进行克隆. 感谢 @chenpipi0807 的 idea😍. 可在 `speakers` 文件夹下分门别类建更多文件夹.

[2025-04-06]⚒️: 发布 v1.0.0.

## 使用

- 单人克隆(超长文本用空行隔开):

![image](https://github.com/billwuhao/ComfyUI_MegaTTS3/blob/main/images/2025-04-06_13-52-57.png)

- 双人对话:

![image](https://github.com/billwuhao/ComfyUI_MegaTTS3/blob/main/images/2025-04-06_14-49-12.png)

## 安装

- **Windows 先安装以下依赖**:

[pynini-windows-wheels](https://github.com/billwuhao/pynini-windows-wheels/releases/tag/v2.1.6.post1) 下载相应 python 版本的 pynini 轮子.

示例:
```
D:\AIGC\python\py310\python.exe -m pip install pynini-2.1.6.post1-cp3xx-cp3xx-win_amd64.whl
D:\AIGC\python\py310\python.exe -m pip install importlib_resources
D:\AIGC\python\py310\python.exe -m pip install "WeTextProcessing>=1.0.4" --no-deps
```

- **然后正常进行下列安装**:
```
cd ComfyUI/custom_nodes
git clone https://github.com/billwuhao/ComfyUI_MegaTTS3.git
cd ComfyUI_MegaTTS3
pip install -r requirements.txt

# python_embeded
./python_embeded/python.exe -m pip install -r requirements.txt
```

## 模型下载

- 模型和音色需要手动下载放到 `ComfyUI\models\TTS` 路径下:

[MegaTTS3](https://huggingface.co/ByteDance/MegaTTS3/tree/main)  整个文件夹全部下载放到 `TTS` 文件夹下.

- **VAE 编码模型, 加微信公众号获取, 放到 `TTS\MegaTTS3\wavvae` 文件夹下, 即可自定义音色而无需 `.npy` 文件**:

<img src="https://github.com/billwuhao/ComfyUI_MegaTTS3/blob/main/images/gzh.webp" alt="" width="200" height="200">

  - [Google 云盘](https://drive.google.com/drive/folders/1p9GNdNJqeK_94lIJW8lew_G3EazU-9Wx?usp=sharing)

- 请将音频放到 `TTS\speakers` 目录下. 我将会把所有 TTS 节点的说话者音频全部统一放到 `ComfyUI\models\TTS\speakers` 路径下, 这些节点包括 `IndexTTS, CSM, Dia, KokoroTTS, MegaTTS, QuteTTS, SparkTTS, StepAudioTTS` 等.

结构如下:

```
.
│  .gitattributes
│  config.json
│  README.md
│
├─aligner_lm
│      config.yaml
│      model_only_last.ckpt
│
├─diffusion_transformer
│      config.yaml
│      model_only_last.ckpt
│
├─duration_lm
│      config.yaml
│      model_only_last.ckpt
│
├─g2p
│      added_tokens.json
│      config.json
│      generation_config.json
│      latest
│      merges.txt
│      model.safetensors
│      special_tokens_map.json
│      tokenizer.json
│      tokenizer_config.json
│      trainer_state.json
│      vocab.json
│
└─wavvae
        config.yaml
        decoder.ckpt
        model_only_last.ckpt
```

## 鸣谢

- [MegaTTS3](https://github.com/bytedance/MegaTTS3)

## 打赏

您的赞赏是我最大的动力! 感谢您支持我一杯咖啡!

<img src="https://github.com/billwuhao/ComfyUI_MegaTTS3/blob/main/images/20250607012102.jpg" alt="" width="200" height="200">


================================================
FILE: README.md
================================================
[中文](README-CN.md) | [English](README.md)

# MegaTTS3 Voice Cloning Nodes for ComfyUI

High-quality voice cloning, supporting both Chinese and English, with cross-lingual cloning capabilities. **Supports custom voice cloning!!! Extra-long text!!! Two-person dialogue!!! Full pynini installation on Windows, no more stripped-down TTS!!!**.

## 📣 Updates

[2025-06-07]⚒️: v2.0.0. **Supports custom voice cloning, extra-long text, two-person dialogue, and full pynini installation on Windows, no more stripped-down TTS!**.

```
[S1] MegaTTS 真开源版本来了,效果666
[S2] 晕 xuan4 是一种 gan3 觉
[S1] 我爱你!I love you!“我爱你”的英语是“I love you”
[S2] 2.5平方电线,共465篇,约315万字
[S1] 2002年的第一场雪,下在了2003年
```

https://github.com/user-attachments/assets/b734e6bd-9303-4311-b3a4-618241ca6535

[2025-04-28]⚒️: Added a voice preview node. Preview the voice first, then clone if you're satisfied. Thanks to @chenpipi0807 for the idea😍. You can create categorized subfolders within the `speakers` folder.

[2025-04-06]⚒️: Released v1.0.0.

## Usage

- Single-person cloning (separate long text with blank lines):

![image](https://github.com/billwuhao/ComfyUI_MegaTTS3/blob/main/images/2025-04-06_13-52-57.png)

- Two-person dialogue:

![image](https://github.com/billwuhao/ComfyUI_MegaTTS3/blob/main/images/2025-04-06_14-49-12.png)
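The two input conventions above (blank lines between chunks of long text, `[S1]`/`[S2]` tags for dialogue) can be sketched in a few lines of Python. This is only an illustration of the text format: `split_paragraphs` and `parse_dialogue` are hypothetical helper names, and the nodes' own parsing may differ.

```python
import re
from typing import List, Tuple

# Hypothetical helpers for illustration; not taken from this repository.
SPEAKER_TAG = re.compile(r"^\[(S[12])\]\s*(.*)$")


def split_paragraphs(text: str) -> List[str]:
    """Split long text into chunks separated by one or more blank lines."""
    parts = re.split(r"\n\s*\n", text.strip())
    return [p.strip() for p in parts if p.strip()]


def parse_dialogue(text: str) -> List[Tuple[str, str]]:
    """Return (speaker, utterance) pairs from lines tagged [S1] or [S2]."""
    turns = []
    for line in text.splitlines():
        match = SPEAKER_TAG.match(line.strip())
        if match:
            # Lines without a speaker tag are ignored in this sketch.
            turns.append((match.group(1), match.group(2)))
    return turns
```

Writing one tagged line per turn, as in the v2.0.0 example, keeps the speaker assignment unambiguous.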

## Installation

- **For Windows, install the following dependencies first**:

Download the pynini wheel matching your Python version from [pynini-windows-wheels](https://github.com/billwuhao/pynini-windows-wheels/releases/tag/v2.1.6.post1).

Example:
```
D:\AIGC\python\py310\python.exe -m pip install pynini-2.1.6.post1-cp3xx-cp3xx-win_amd64.whl
D:\AIGC\python\py310\python.exe -m pip install importlib_resources
D:\AIGC\python\py310\python.exe -m pip install "WeTextProcessing>=1.0.4" --no-deps
```
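The `cp3xx` placeholder in the wheel name has to match the interpreter that runs ComfyUI. A minimal sketch for working out the expected filename, assuming the release files follow the naming pattern shown in the example (the helper name `pynini_wheel_name` is illustrative, not part of this repository):

```python
import sys


def pynini_wheel_name(version: str = "2.1.6.post1") -> str:
    """Build the expected wheel filename for the running interpreter.

    Assumes the pattern pynini-<version>-cp3xx-cp3xx-win_amd64.whl.
    """
    tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    return f"pynini-{version}-{tag}-{tag}-win_amd64.whl"


if __name__ == "__main__":
    # Run this with the same python.exe that ComfyUI uses.
    print(pynini_wheel_name())
```

If the printed name does not appear on the release page, that Python version probably has no prebuilt wheel there.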

- **Then, proceed with the normal installation**:
```
cd ComfyUI/custom_nodes
git clone https://github.com/billwuhao/ComfyUI_MegaTTS3.git
cd ComfyUI_MegaTTS3
pip install -r requirements.txt

# For python_embeded
./python_embeded/python.exe -m pip install -r requirements.txt
```

## Model Download

- Models and voices need to be downloaded manually and placed in the `ComfyUI\models\TTS` directory:

[MegaTTS3](https://huggingface.co/ByteDance/MegaTTS3/tree/main) Download the entire folder and place it in the `TTS` directory.

- **For the VAE encoder model, which enables custom voice cloning without `.npy` files, please follow our WeChat Official Account to obtain it. Place it in the `TTS\MegaTTS3\wavvae` folder**:

<img src="https://github.com/billwuhao/ComfyUI_MegaTTS3/blob/main/images/gzh.webp" alt="" width="200" height="200">

  - [Google Drive](https://drive.google.com/drive/folders/1p9GNdNJqeK_94lIJW8lew_G3EazU-9Wx?usp=sharing)

- Place reference audio in the `TTS\speakers` directory. Speaker audio for all of my TTS nodes will be unified under `ComfyUI\models\TTS\speakers`; these nodes include `IndexTTS, CSM, Dia, KokoroTTS, MegaTTS, QuteTTS, SparkTTS, StepAudioTTS`, etc.

The structure is as follows:

```
.
│  .gitattributes
│  config.json
│  README.md
│
├─aligner_lm
│      config.yaml
│      model_only_last.ckpt
│
├─diffusion_transformer
│      config.yaml
│      model_only_last.ckpt
│
├─duration_lm
│      config.yaml
│      model_only_last.ckpt
│
├─g2p
│      added_tokens.json
│      config.json
│      generation_config.json
│      latest
│      merges.txt
│      model.safetensors
│      special_tokens_map.json
│      tokenizer.json
│      tokenizer_config.json
│      trainer_state.json
│      vocab.json
│
└─wavvae
        config.yaml
        decoder.ckpt
        model_only_last.ckpt
```
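Before starting ComfyUI, it can help to check that the download matches the layout above. The following is a rough sketch under stated assumptions: the file list is copied from the listing, the choice of which files to treat as required is a guess, and accepting either `decoder.ckpt` or `model_only_last.ckpt` under `wavvae` reflects that the encoder checkpoint is obtained separately. `missing_model_files` is a hypothetical helper, not part of the node.

```python
from pathlib import Path
from typing import List

# Assumed-required files, copied from the directory listing above.
REQUIRED_FILES = [
    "aligner_lm/config.yaml",
    "aligner_lm/model_only_last.ckpt",
    "diffusion_transformer/config.yaml",
    "diffusion_transformer/model_only_last.ckpt",
    "duration_lm/config.yaml",
    "duration_lm/model_only_last.ckpt",
    "g2p/config.json",
    "g2p/model.safetensors",
    "g2p/tokenizer.json",
    "wavvae/config.yaml",
]

# Assumption: at least one of these WavVAE checkpoints must be present.
WAVVAE_CHECKPOINTS = ["wavvae/decoder.ckpt", "wavvae/model_only_last.ckpt"]


def missing_model_files(megatts3_dir: str) -> List[str]:
    """Return the entries that are absent under the MegaTTS3 folder."""
    root = Path(megatts3_dir)
    missing = [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]
    if not any((root / rel).is_file() for rel in WAVVAE_CHECKPOINTS):
        missing.append(" or ".join(WAVVAE_CHECKPOINTS))
    return missing
```

Calling `missing_model_files(r"ComfyUI\models\TTS\MegaTTS3")` returns an empty list when every entry is in place.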


## Credits

- [MegaTTS3](https://github.com/bytedance/MegaTTS3)

## Donation

Your appreciation is my greatest motivation! Thank you for supporting me with a cup of coffee!

<img src="https://github.com/billwuhao/ComfyUI_MegaTTS3/blob/main/images/20250607012102.jpg" alt="" width="200" height="200">


================================================
FILE: __init__.py
================================================
from .megatts3node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']

================================================
FILE: megatts3node.py
================================================
import json
import os
import librosa
import numpy as np
import torch
import torchaudio
from typing import List, Union, Optional
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
from langdetect import detect as classify_language
import pyloudnorm as pyln
import folder_paths
import gc
import re
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator
from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit
from tts.utils.audio_utils.io import convert_to_wav_bytes, combine_audio_segments
from tts.utils.commons.ckpt_utils import load_ckpt
from tts.utils.commons.hparams import set_hparams, hparams
from tts.utils.text_utils.text_encoder import TokenTextEncoder
from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english, chunk_text_chinesev2


models_dir = folder_paths.models_dir
model_path = os.path.join(models_dir, "TTS")
speakers_dir = os.path.join(model_path, "speakers")
cache_dir = folder_paths.get_temp_directory()

def get_all_files(
    root_dir: str,
    return_type: str = "list",
    extensions: Optional[List[str]] = None,
    exclude_dirs: Optional[List[str]] = None,
    relative_path: bool = False
) -> Union[List[str], dict]:
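    """
    Recursively collect all file paths under a directory.

    :param root_dir: Root directory to walk.
    :param return_type: Return type: "list" (flat list) or "dict" (grouped by directory).
    :param extensions: Optional list of file extensions to keep (e.g. ['.py', '.txt']).
    :param exclude_dirs: Directory names to skip (e.g. ['__pycache__', '.git']).
    :param relative_path: Whether to return paths relative to root_dir.
    :return: A list of file paths, or a dict mapping each directory to its files.
    """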
    file_paths = []
    file_dict = {}
    
    # Normalize the directory path
    root_dir = os.path.normpath(root_dir)
    
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # Prune excluded directories in place so os.walk skips them
        if exclude_dirs:
            dirnames[:] = [d for d in dirnames if d not in exclude_dirs]
        
        current_files = []
        for filename in filenames:
            # Filter by extension (case-insensitive)
            if extensions:
                if not any(filename.lower().endswith(ext.lower()) for ext in extensions):
                    continue
            
            # Build the full path
            full_path = os.path.join(dirpath, filename)
            
            # Convert to a path relative to root_dir if requested
            if relative_path:
                full_path = os.path.relpath(full_path, root_dir)
            
            current_files.append(full_path)
        
        if return_type == "dict":
            # Key the dict by the relative or absolute directory path
            dict_key = os.path.relpath(dirpath, root_dir) if relative_path else dirpath
            if current_files:
                file_dict[dict_key] = current_files
        else:
            file_paths.extend(current_files)
    
    return file_dict if return_type == "dict" else file_paths

def get_speakers():
    if not os.path.exists(speakers_dir):
        os.makedirs(speakers_dir, exist_ok=True)
        return []
    # Extension matching in get_all_files is case-insensitive, so lowercase entries suffice.
    speakers = get_all_files(speakers_dir, extensions=[".wav", ".mp3", ".flac", ".mp4"], relative_path=True)
    return speakers


class MegaTTS3DiTInfer():
    def __init__(
            self,
            device=None,
            ckpt_root=os.path.join(model_path, "MegaTTS3"),
            dit_exp_name='diffusion_transformer',
            frontend_exp_name='aligner_lm',
            wavvae_exp_name='wavvae',
            dur_ckpt_path='duration_lm',
            g2p_exp_name='g2p',
            precision=torch.float16,
            **kwargs
        ):
        self.sr = 24000
        self.fm = 8
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.precision = precision

        # build models
        self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name)
        self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name)
        self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name)
        self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path)
        self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name)
        self.build_model(self.device)

        # init text normalizer
        self.zh_normalizer = ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False)
        self.en_normalizer = EnNormalizer(overwrite_cache=False)

        # loudness meter
        self.loudness_meter = pyln.Meter(self.sr)
        
        self.ph_ref = None
        self.tone_ref = None
        self.mel2ph_ref = None
        self.vae_latent = None
        self.ctx_dur_tokens = None
        self.incremental_state_dur_prompt = None

        self.audio_bytes = None
        
    def clean(self):
        # Drop model references so their memory can be reclaimed below.
        self.dur_model = None
        self.dit = None
        self.g2p_model = None
        self.wavvae_en = None
        self.wavvae_de = None
        self.aligner_lm = None

        self.audio_bytes = None
        self.ph_ref = None
        self.tone_ref = None
        self.mel2ph_ref = None
        self.vae_latent = None
        self.ctx_dur_tokens = None
        self.incremental_state_dur_prompt = None

        gc.collect()
        torch.cuda.empty_cache()

    def build_model(self, device):
        set_hparams(exp_name=self.dit_exp_name, print_hparams=False)

        ''' Load Dict '''
        current_dir = os.path.dirname(os.path.abspath(__file__))
        # Use a context manager so the file handle is closed after loading.
        with open(f"{current_dir}/tts/utils/text_utils/dict.json", encoding='utf-8-sig') as f:
            ling_dict = json.load(f)
        self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='<UNK>') for k in ['phone', 'tone']}
        self.token_encoder = token_encoder = self.ling_dict['phone']
        ph_dict_size = len(token_encoder)

        ''' Load Duration LM '''
        from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor
        hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False)
        hp_dur_model['frames_multiple'] = hparams['frames_multiple']
        self.dur_model = ARDurPredictor(
            hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'],
            hp_dur_model['dur_model_layers'], ph_dict_size,
            hp_dur_model['dur_code_size'],
            use_rot_embed=hp_dur_model.get('use_rot_embed', False))
        self.length_regulator = LengthRegulator()
        load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model')
        self.dur_model.eval()
        self.dur_model.to(device)

        ''' Load Diffusion Transformer '''
        from tts.modules.llm_dit.dit import Diffusion
        self.dit = Diffusion()
        load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False)
        self.dit.eval()
        self.dit.to(device)
        self.cfg_mask_token_phone = 302 - 1
        self.cfg_mask_token_tone = 32 - 1

        ''' Load Frontend LM '''
        from tts.modules.aligner.whisper_small import Whisper
        self.aligner_lm = Whisper()
        load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model')
        self.aligner_lm.eval()
        self.aligner_lm.to(device)
        self.kv_cache = None
        self.hooks = None

        ''' Load G2P LM'''
        from transformers import AutoTokenizer, AutoModelForCausalLM
        g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right")
        g2p_tokenizer.padding_side = "right"
        self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device)
        self.g2p_tokenizer = g2p_tokenizer
        self.speech_start_idx = g2p_tokenizer.encode('<Reserved_TTS_0>')[0]

        ''' Wav VAE '''
        self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False)
        from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3

        self.wavvae_en = WavVAE_V3(hparams=hp_wavvae)
        self.wavvae_de = WavVAE_V3(hparams=hp_wavvae)

        if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'):
            load_ckpt(self.wavvae_en, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True)
            self.has_vae_encoder = True
            self.wavvae_en.eval()
            self.wavvae_en.to(device)
        else:
            load_ckpt(self.wavvae_de, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False)
            self.has_vae_encoder = False
            self.wavvae_de.eval()
            self.wavvae_de.to(device)

        self.vae_stride = hp_wavvae.get('vae_stride', 4)
        self.hop_size = hp_wavvae.get('hop_size', 4)
    
    def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs):
        if self.audio_bytes != audio_bytes:
            self.audio_bytes = audio_bytes
            wav_bytes = convert_to_wav_bytes(audio_bytes)

            ''' Load wav '''
            wav, _ = librosa.core.load(wav_bytes, sr=self.sr)
            # Pad wav if necessary
            ws = hparams['win_size']
            if len(wav) % ws < ws - 1:
                wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32)
            wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32)
            self.loudness_prompt = self.loudness_meter.integrated_loudness(wav.astype(float))

            ''' obtain alignments with aligner_lm '''
            ph_ref, tone_ref, mel2ph_ref = align(self, wav)

            self.kv_cache = None
            self.hooks = None

            with torch.inference_mode():
                ''' Forward WaveVAE to obtain: prompt latent '''
                if self.has_vae_encoder:
                    if latent_file is None:
                        wav = torch.FloatTensor(wav)[None].to(self.device)
                        vae_latent = self.wavvae_en.encode_latent(wav)
                    else:
                        vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
                    vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
                else:
                    assert latent_file is not None, "WaveVAE encode model does not exist, an npy file must be provided!!!"
                    vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
                    vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
            
                ''' Duration Prompting '''
                self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None
                incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref)

                self.ph_ref = ph_ref.to(self.device)
                self.tone_ref = tone_ref.to(self.device)
                self.mel2ph_ref = mel2ph_ref.to(self.device)
                self.vae_latent = vae_latent.to(self.device)
                self.ctx_dur_tokens = ctx_dur_tokens.to(self.device)
                self.incremental_state_dur_prompt = incremental_state_dur_prompt

    def forward(self, texts, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs):

        with torch.inference_mode():
            ''' Generating '''
            waveforms = []
            for input_text in texts:
                wav_pred_ = []
                language_type = classify_language(input_text)
                if language_type == 'en':
                    input_text = self.en_normalizer.normalize(input_text)
                    text_segs = chunk_text_english(input_text, max_chars=130)
                else:
                    input_text = self.zh_normalizer.normalize(input_text)
                    text_segs = chunk_text_chinesev2(input_text, limit=60)

                for seg_i, text in enumerate(text_segs):
                    ''' G2P '''
                    ph_pred, tone_pred = g2p(self, text)

                    ''' Duration Prediction '''
                    mel2ph_pred = dur_pred(self, self.ctx_dur_tokens, self.incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first=seg_i==0, is_final=seg_i==len(text_segs)-1)
                    
                    inputs = prepare_inputs_for_dit(self, self.mel2ph_ref, mel2ph_pred, self.ph_ref, self.tone_ref, ph_pred, tone_pred, self.vae_latent)
                    # Speech dit inference
                    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
                        x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float()
                    
                    # WavVAE decode
                    x[:, :self.vae_latent.size(1)] = self.vae_latent
                    if self.has_vae_encoder:
                        wav_pred = self.wavvae_en.decode(x)[0,0].to(torch.float32)
                    else:
                        wav_pred = self.wavvae_de.decode(x)[0,0].to(torch.float32)
                    
                    ''' Post-processing '''
                    # Trim prompt wav
                    wav_pred = wav_pred[self.vae_latent.size(1)*self.vae_stride*self.hop_size:].cpu().numpy()
                    # Normalize the generated wav to the prompt wav's loudness (reuses self.loudness_meter)
                    loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float))
                    wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, self.loudness_prompt)
                    if np.abs(wav_pred).max() >= 1:
                        wav_pred = wav_pred / np.abs(wav_pred).max() * 0.95

                    # Collect the segment; segments are joined by combine_audio_segments below
                    wav_pred_.append(wav_pred)

                    gc.collect()
                    torch.cuda.empty_cache()

                wav_pred = combine_audio_segments(wav_pred_, sr=self.sr).astype(np.float32)
                waveform = torch.tensor(wav_pred)
                waveforms.append(waveform.cpu())

            return torch.cat(waveforms, dim=0), self.sr
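

# --- Illustrative sketch, not part of the original node ---
# The post-processing step in `forward` rescales a segment to a 0.95 peak whenever
# loudness normalization leaves any sample at |x| >= 1. The helper below restates that
# rule on a plain list of floats so it can be checked in isolation. The name
# `limit_peak_sketch` and its `ceiling` parameter are hypothetical, introduced only
# for this example.
def limit_peak_sketch(samples, ceiling=0.95):
    """Scale samples so the peak equals `ceiling`, but only if the peak reaches 1.0."""
    peak = max((abs(s) for s in samples), default=0.0)
    if peak >= 1:
        # Same order of operations as the node: divide by the peak, then apply the ceiling
        return [s / peak * ceiling for s in samples]
    # Quiet segments pass through untouched
    return list(samples)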


class MegaTTS3SpeakersPreview:
    @classmethod
    def INPUT_TYPES(s):
        speakers = get_speakers()
        return {
            "required": {"speaker":(speakers,),},}

    RETURN_TYPES = ("AUDIO", "STRING", )
    RETURN_NAMES = ("audio", "npy_file", )
    FUNCTION = "preview"
    CATEGORY = "🎤MW/MW-MegaTTS3"

    def preview(self, speaker):
        wav_path = os.path.join(speakers_dir, speaker)
        latent_file = wav_path.rsplit('.', 1)[0] + '.npy'
        if not os.path.exists(latent_file):
            latent_file = ""

        waveform, sample_rate = torchaudio.load(wav_path)
        waveform = waveform.unsqueeze(0)
        output_audio = {
            "waveform": waveform,
            "sample_rate": sample_rate
        }
        return (output_audio, latent_file)


def cache_audio_tensor(
    cache_dir,
    audio_tensor: torch.Tensor,
    sample_rate: int,
    filename_prefix: str = "cached_audio_",
    audio_format: Optional[str] = ".wav"
) -> str:
    import tempfile
    try:
        with tempfile.NamedTemporaryFile(
            prefix=filename_prefix,
            suffix=audio_format,
            dir=cache_dir,
            delete=False 
        ) as tmp_file:
            temp_filepath = tmp_file.name
        
        torchaudio.save(temp_filepath, audio_tensor, sample_rate)

        return temp_filepath
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback
        raise RuntimeError(f"Error caching audio tensor: {e}") from e
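

# --- Illustrative sketch, not part of the original node ---
# `cache_audio_tensor` passes `delete=False`, so the temporary file only reserves a
# unique path: it is closed when the `with` block exits and then stays on disk until
# something removes it. The two helpers below isolate that pattern and show one
# possible cleanup step. Both function names are hypothetical, and the node itself
# does not perform this cleanup.
import os
import tempfile


def reserve_temp_path_sketch(directory, prefix="cached_audio_", suffix=".wav"):
    """Create an empty, uniquely named file in `directory` and return its path."""
    with tempfile.NamedTemporaryFile(
        prefix=prefix, suffix=suffix, dir=directory, delete=False
    ) as tmp_file:
        # The file is closed on leaving the block but not deleted
        return tmp_file.name


def remove_if_present_sketch(path):
    """Delete a cached file, tolerating the case where it is already gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass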

def statistical_compare(tensor1, tensor2):
    """Quickly compare two tensors via summary statistics (mean, std, max, min)."""
    stats1 = {
        'mean': tensor1.mean(),
        'std': tensor1.std(),
        'max': tensor1.max(),
        'min': tensor1.min()
    }
    stats2 = {
        'mean': tensor2.mean(),
        'std': tensor2.std(),
        'max': tensor2.max(),
        'min': tensor2.min()
    }
    return all(torch.allclose(stats1[k], stats2[k], rtol=1e-3) for k in stats1)
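

# --- Illustrative sketch, not part of the original node ---
# Comparing only mean/std/max/min is fast, but distinct signals can share all four
# values: a time-reversed copy of a clip is one example. If that matters, a content
# hash is a stricter alternative. The helper below hashes a plain list of floats;
# `fingerprint_samples_sketch` is a hypothetical name. For a torch tensor, one could
# hash `tensor.detach().cpu().numpy().tobytes()` in the same way (an assumption about
# how a caller might adapt this, not something the node does).
import hashlib
import struct


def fingerprint_samples_sketch(samples):
    """Return a SHA-256 hex digest of the samples packed as little-endian float32."""
    packed = struct.pack(f"<{len(samples)}f", *samples)
    return hashlib.sha256(packed).hexdigest()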


INFER_INS_CACHE = None
class MegaTTS3Run:
    def __init__(self):
        self.resource_context = None
        self.audio_tensor = None
        self.audio_prompt = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "audio": ("AUDIO",),
                "text": ("STRING", {"forceInput": True}),
                "time_step": ("INT", {"default": 32, "min": 1,}),
                "p_w": ("FLOAT", {"default":1.6, "min": 0.1,}),
                "t_w": ("FLOAT", {"default": 2.5, "min": 0.1,}),
                "unload_model": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "dialogue_audio_s2":("AUDIO",),
                "audio_npy_file": ("STRING",  {"forceInput": True, "tooltip": "No `npy_file` will use VAE to encode audio. 不提供 .npy 文件, 将使用 WaveVAE 编码音频"}),
                "audio_s2_npy_file": ("STRING",  {"forceInput": True, "tooltip": "No `npy_file` will use VAE to encode audio. 不提供 .npy 文件, 将使用 WaveVAE 编码音频"}),
            }
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "clone"
    CATEGORY = "🎤MW/MW-MegaTTS3"

    def clone(self, audio, text, time_step, p_w, t_w, unload_model, audio_npy_file=None, dialogue_audio_s2=None, audio_s2_npy_file=None):
        if not os.path.exists(os.path.join(model_path, "MegaTTS3", 'wavvae', 'model_only_last.ckpt')):
            print("WaveVAE encode model does not exist, an npy file must be provided!!!")
        waveform = audio["waveform"].squeeze(0)

        global INFER_INS_CACHE
        if INFER_INS_CACHE is None:
            INFER_INS_CACHE = MegaTTS3DiTInfer()
            
        latent_file = audio_npy_file if audio_npy_file else None
        try:
            import gc
            if dialogue_audio_s2 is None:
                # Re-cache the prompt audio only when the input audio has changed
                if self.audio_tensor is None or self.audio_prompt is None or not statistical_compare(self.audio_tensor, waveform):
                    self.audio_tensor = waveform
                    self.audio_prompt = cache_audio_tensor(cache_dir, waveform, audio["sample_rate"])

                texts = [i.strip() for i in re.split(r'\n\s*\n', text.strip()) if i.strip()]
                with open(self.audio_prompt, 'rb') as file:
                    file_content = file.read()
                INFER_INS_CACHE.preprocess(file_content, latent_file=latent_file)

                del file_content
                gc.collect()
                torch.cuda.empty_cache()

                waveform, sr = INFER_INS_CACHE.forward(texts=texts, time_step=time_step, p_w=p_w, t_w=t_w)
                gc.collect()
                torch.cuda.empty_cache()
            else:
                latent_file_2 = audio_s2_npy_file if audio_s2_npy_file else None
                audio_1 = cache_audio_tensor(cache_dir, waveform, audio["sample_rate"])
                audio_2 = cache_audio_tensor(cache_dir, dialogue_audio_s2["waveform"].squeeze(0), dialogue_audio_s2["sample_rate"])
                with open(audio_1, 'rb') as file:
                    file_content_1 = file.read()
                with open(audio_2, 'rb') as file:
                    file_content_2 = file.read()

                gc.collect()
                torch.cuda.empty_cache()
        
                ress = []
                for t, a, n in self.get_speaker_text_audio(text, audio_1, audio_2):
                    texts = [i.strip() for i in re.split(r'\n\s*\n', t.strip()) if i.strip()]
                    if a == audio_1:
                        INFER_INS_CACHE.preprocess(file_content_1, latent_file=latent_file)
                        res_sub, sr = INFER_INS_CACHE.forward(texts=texts, time_step=time_step, p_w=p_w, t_w=t_w)
                        ress.append([res_sub, n])
                    else:
                        INFER_INS_CACHE.preprocess(file_content_2, latent_file=latent_file_2)
                        res_sub, sr = INFER_INS_CACHE.forward(texts=texts, time_step=time_step, p_w=p_w, t_w=t_w)
                        ress.append([res_sub, n])

                del file_content_1
                del file_content_2
                gc.collect()
                torch.cuda.empty_cache()
                waveform = torch.cat(list(zip(*sorted(ress, key=lambda x: x[1])))[0], dim=0)

        except Exception as e:
            if unload_model:
                import gc
                INFER_INS_CACHE.clean()
                INFER_INS_CACHE = None
                self.resource_context = None
                gc.collect()
                torch.cuda.empty_cache()
            # Bare `raise` re-raises the active exception with its original traceback
            raise

        if unload_model:
            import gc
            INFER_INS_CACHE.clean()
            INFER_INS_CACHE = None
            self.resource_context = None
            gc.collect()
            torch.cuda.empty_cache()

        return ({"waveform": waveform.unsqueeze(0).unsqueeze(0), "sample_rate": sr},)

    def get_speaker_text_audio(self, text, audio_1, audio_2):
        pattern = r'(\[s?S?1\]|\[s?S?2\])\s*([\s\S]*?)(?=\[s?S?[12]\]|$)'
        matches = re.findall(pattern, text)
        if len(matches) == 0:
            raise ValueError("No speaker tags found in the text: [S1]... [S2]...")
        labels = []
        contents = []
        audios = []

        for label, content in matches:
            labels.append(label)
            contents.append(content)
    
        audios = [
            audio_1 if i.lower() == '[s1]' else audio_2 for i in labels
        ]

        return sorted(zip(contents, audios, range(len(contents))), key=lambda x: x[1])
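

# --- Illustrative sketch, not part of the original node ---
# `get_speaker_text_audio` splits a dialogue on [S1]/[S2] tags, then sorts the turns by
# prompt audio so each speaker's turns are processed together; `clone` later restores
# the original order from the stored index. The standalone helpers below mirror that
# flow with speaker numbers instead of audio paths. All names ending in `_sketch` are
# hypothetical, and unlike the node this version strips whitespace from each utterance.
import re

_SPEAKER_TAG_SKETCH = re.compile(r'(\[s?S?1\]|\[s?S?2\])\s*([\s\S]*?)(?=\[s?S?[12]\]|$)')


def split_dialogue_sketch(text):
    """Return (speaker, utterance, position) triples in their original order."""
    turns = []
    for position, (label, content) in enumerate(_SPEAKER_TAG_SKETCH.findall(text)):
        speaker = 1 if label.lower() == '[s1]' else 2
        turns.append((speaker, content.strip(), position))
    return turns


def group_then_restore_sketch(turns):
    """Group turns by speaker, then recover the original order from the position index."""
    grouped = sorted(turns, key=lambda turn: turn[0])      # stable sort keeps per-speaker order
    restored = sorted(grouped, key=lambda turn: turn[2])   # position index restores dialogue order
    return grouped, restored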
    

class MultiLinePromptMG:
    @classmethod
    def INPUT_TYPES(cls):
               
        return {
            "required": {
                "multi_line_prompt": ("STRING", {
                    "multiline": True, 
                    "default": ""}),
                },
        }

    CATEGORY = "🎤MW/MW-MegaTTS3"
    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("text",)
    FUNCTION = "promptgen"
    
    def promptgen(self, multi_line_prompt: str):
        return (multi_line_prompt.strip(),)


NODE_CLASS_MAPPINGS = {
    "MegaTTS3SpeakersPreview": MegaTTS3SpeakersPreview,
    "MegaTTS3Run": MegaTTS3Run,
    "MultiLinePromptMG": MultiLinePromptMG,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "MegaTTS3SpeakersPreview": "MegaTTS3 Speakers Preview",
    "MegaTTS3Run": "MegaTTS3 Run",
    "MultiLinePromptMG": "Multi Line Text",
}

================================================
FILE: pyproject.toml
================================================
[project]
name = "megatts3-mw"
description = "Lightweight and Efficient, 🎧Ultra High-Quality Voice Cloning, Chinese and English."
version = "2.0.0"
license = {file = "LICENSE"}
dependencies = ["setproctitle", "attrdict", "librosa", "pydub", "pyloudnorm", "x-transformers", "torchdiffeq", "openai-whisper>=20240930"]

[project.urls]
Repository = "https://github.com/billwuhao/ComfyUI_MegaTTS3"
#  Used by Comfy Registry https://comfyregistry.org

[tool.comfy]
PublisherId = "mw"
DisplayName = "MW-ComfyUI_MegaTTS3"
Icon = ""


================================================
FILE: requirements.txt
================================================
setproctitle
attrdict
librosa
pyloudnorm
x-transformers
torchdiffeq
openai-whisper>=20240930
langdetect
pynini==2.1.6; platform_system!="Windows"
WeTextProcessing>=1.0.3; platform_system!="Windows"


================================================
FILE: tts/frontend_function.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
import whisper
import librosa
from copy import deepcopy
from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
from tts.utils.audio_utils.align import mel2token_to_dur

''' Grapheme-to-phoneme function '''
def g2p(self, text_inp):
    # prepare inputs
    txt_token = self.g2p_tokenizer('<BOT>' + text_inp + '<BOS>')['input_ids']
    input_ids = torch.LongTensor([txt_token+[145+self.speech_start_idx]]).to(self.device)

    # model forward
    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
        outputs = self.g2p_model.generate(input_ids, max_new_tokens=256, do_sample=True, top_k=1, eos_token_id=800+1+self.speech_start_idx)
    
    # process outputs
    ph_tokens = outputs[:, len(txt_token):-1]-self.speech_start_idx
    ph_pred, tone_pred = split_ph(ph_tokens[0])
    ph_pred, tone_pred = ph_pred[None, :].to(self.device), tone_pred[None, :].to(self.device)
    return ph_pred, tone_pred

''' Get the phoneme-to-mel alignment of the prompt speech '''
def align(self, wav):
    with torch.inference_mode():
        whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
        mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2)
        prompt_max_frame = mel.size(2) // self.fm * self.fm
        mel = mel[:, :, :prompt_max_frame]
        token = torch.LongTensor([[798]]).to(self.device)
        audio_features = self.aligner_lm.embed_audio(mel)
        for i in range(768):
            with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
                logits = self.aligner_lm.logits(token, audio_features, None)
                token_pred = torch.argmax(F.softmax(logits[:, -1], dim=-1), 1)[None]
                token = torch.cat([token, token_pred], dim=1)
                if token_pred[0] == 799:
                    break
        alignment_tokens = token
    
    ph_ref, tone_ref, dur_ref, _ = split_ph_timestamp(deepcopy(alignment_tokens)[0, 1:-1])
    ph_ref = torch.Tensor(ph_ref)[None].to(self.device)
    tone_ref = torch.Tensor(tone_ref)[None].to(self.device)
    if dur_ref.sum() < prompt_max_frame:
        dur_ref[-1] += prompt_max_frame - dur_ref.sum()
    elif dur_ref.sum() > prompt_max_frame:
        len_diff = dur_ref.sum() - prompt_max_frame
        while True:
            for i in range(len(dur_ref)):
                dur_ref[i] -= 1
                len_diff -= 1
                if len_diff == 0:
                    break
            if len_diff == 0:
                break
    mel2ph_ref = self.length_regulator(dur_ref[None]).to(self.device)
    mel2ph_ref = mel2ph_ref[:, :mel2ph_ref.size(1)//self.fm*self.fm]
    return ph_ref, tone_ref, mel2ph_ref
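

# --- Illustrative sketch, not part of the original module ---
# `align` reconciles the predicted phone durations with the prompt's frame count: a
# shortfall is added to the last phone, and an excess is removed one frame at a time
# across the phones in round-robin order. The helper below restates that on a plain
# list. `fit_durations_sketch` is a hypothetical name, and skipping phones that are
# already at zero is an added safeguard of this sketch, not something `align` does.
def fit_durations_sketch(durations, target_frames):
    """Return a copy of `durations` whose sum equals `target_frames`."""
    if not durations or target_frames < 0:
        raise ValueError("need at least one duration and a non-negative target")
    durs = list(durations)
    diff = sum(durs) - target_frames
    if diff < 0:
        # Too short: give the missing frames to the last phone
        durs[-1] += -diff
        return durs
    while diff > 0:
        # Too long: take one frame from each phone in turn until the excess is gone
        for i in range(len(durs)):
            if diff == 0:
                break
            if durs[i] > 0:
                durs[i] -= 1
                diff -= 1
    return durs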

''' Duration Prompting '''
def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref):
    dur_tokens_2d_ = mel2token_to_dur(mel2ph_ref, ph_ref.shape[1]).clamp(
                    max=self.hp_dur_model['dur_code_size'] - 1) + 1

    ctx_dur_tokens = dur_tokens_2d_.clone().flatten(0, 1).to(self.device)
    txt_tokens_flat_ = ph_ref.flatten(0, 1)
    ctx_dur_tokens = ctx_dur_tokens[txt_tokens_flat_ > 0][None]

    last_dur_pos_prompt = ctx_dur_tokens.shape[1]
    dur_spk_pos_ids_flat = range(0, last_dur_pos_prompt)
    dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
        _, incremental_state_dur_prompt = self.dur_model.infer(
            ph_ref, {'tone': tone_ref}, None, None, None,
            ctx_vqcodes=ctx_dur_tokens, spk_pos_ids_flat=dur_spk_pos_ids_flat, return_state=True)
    return incremental_state_dur_prompt, ctx_dur_tokens

''' Duration Prediction '''
def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first, is_final):
    last_dur_token = ctx_dur_tokens[:, -1:]
    last_dur_pos_prompt = ctx_dur_tokens.shape[1]
    incremental_state_dur = deepcopy(incremental_state_dur_prompt)
    txt_len = ph_pred.shape[1]
    dur_spk_pos_ids_flat = range(last_dur_pos_prompt, last_dur_pos_prompt + txt_len)
    dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
    last_dur_pos_prompt = last_dur_pos_prompt + txt_len

    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
        dur_pred = self.dur_model.infer(
            ph_pred, {'tone': tone_pred}, None, None, None,
            incremental_state=incremental_state_dur,
            first_decoder_inp=last_dur_token,
            spk_pos_ids_flat=dur_spk_pos_ids_flat,
        )

    dur_pred = dur_pred - 1
    dur_pred = dur_pred.clamp(0, self.hp_dur_model['dur_code_size'] - 1)
    # if is_final:
    #     dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
    # else:
    #     dur_pred[:, -1] = dur_pred[:, -1].clamp(48, 128)
    # if seg_i > 0:
        # dur_pred[:, 0] = 0
    # ['。', '!', '?', 'sil']
    # for sil_token in [148, 153, 166, 145]:
    #     dur_pred[ph_pred==sil_token].clamp_min(32)
    # # [',', ';'] 
    # for sil_token in [163, 165]:
    #     dur_pred[ph_pred==sil_token].clamp_min(16)
    if not is_final:
        # add 32 extra frames to the last phone to leave room for crossfading
        dur_pred[:, -1] =  dur_pred[:, -1] + 32
    else:
        dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)

    ''' DiT target speech generation '''
    dur_disturb_choice = (torch.rand_like(dur_pred.float()) > 0.5).float()
    dur_disturb_r = 1 + torch.rand_like(dur_pred.float()) * dur_disturb
    dur_pred = dur_pred * dur_disturb_r * dur_disturb_choice + \
            dur_pred / dur_disturb_r * (1 - dur_disturb_choice)
    dur_pred = torch.round(dur_pred * dur_alpha).clamp(0, 127)
    # ['。', '!', '?', 'sil']
    for sil_token in [148, 153, 166, 145]:
        dur_pred[ph_pred==sil_token] = dur_pred[ph_pred==sil_token].clamp_min(64)
    # [',', ';'] 
    for sil_token in [163, 165]:
        dur_pred[ph_pred==sil_token] = dur_pred[ph_pred==sil_token].clamp_min(32)
    if is_first:
        dur_pred[:, 0] = 8
    
    dur_sum = dur_pred.sum()
    npad = self.fm - dur_sum % self.fm
    if npad < self.fm:
        dur_pred[:, -1] += npad
    mel2ph_pred = self.length_regulator(dur_pred).to(self.device)
    return mel2ph_pred
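

# --- Illustrative sketch, not part of the original module ---
# The last step of `dur_pred` pads the final phone so that the total number of frames
# is a multiple of `self.fm`. The arithmetic is easy to get wrong by one, so the helper
# below isolates it on a plain list. `pad_last_to_multiple_sketch` and its parameter
# names are hypothetical.
def pad_last_to_multiple_sketch(durations, frames_multiple):
    """Extend the last duration so the total is divisible by `frames_multiple`."""
    durs = list(durations)
    npad = frames_multiple - sum(durs) % frames_multiple
    # When the total is already aligned, npad equals frames_multiple and nothing is added
    if npad < frames_multiple:
        durs[-1] += npad
    return durs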

def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent):
    # Prepare duration token 
    mel2ph_pred = torch.cat((mel2ph_ref, mel2ph_pred+ph_ref.size(1)), dim=1)
    mel2ph_pred = mel2ph_pred[:, :mel2ph_pred.size(1)//self.fm*self.fm].repeat(3, 1)
    # Prepare phone and tone token
    ph_pred = torch.cat((ph_ref, ph_pred), dim=1)
    tone_pred = torch.cat((tone_ref, tone_pred), dim=1)
    # Disable the English tones (set them to 3)
    en_tone_idx = ~((tone_pred == 4) | ( (11 <= tone_pred) & (tone_pred <= 15)) | (tone_pred == 0))
    tone_pred[en_tone_idx] = 3
    
    # Prepare cfg inputs
    ph_seq = torch.cat([ph_pred, ph_pred, torch.full(ph_pred.size(), self.cfg_mask_token_phone, device=self.device)], 0)
    tone_seq = torch.cat([tone_pred, tone_pred, torch.full(tone_pred.size(), self.cfg_mask_token_tone, device=self.device)], 0)
    target_size = mel2ph_pred.size(1)//self.vae_stride
    vae_latent_ = vae_latent.repeat(3, 1, 1)
    ctx_mask = torch.ones_like(vae_latent_[:, :, 0:1])
    vae_latent_ = F.pad(vae_latent_, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
    vae_latent_[1:] = 0.0
    ctx_mask = F.pad(ctx_mask, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)

    return {
        'phone': ph_seq,
        'tone': tone_seq,
        "lat_ctx": vae_latent_ * ctx_mask,
        "ctx_mask": ctx_mask,
        "dur": mel2ph_pred,
    }

================================================
FILE: tts/modules/aligner/whisper_small.py
================================================
# MIT License

# Copyright (c) 2022 OpenAI

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Copyright (c) [2022] [OpenAI] 
# Copyright (c) [2025] [Ziyue Jiang] 
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text
# available at https://github.com/openai/whisper/blob/v20240930/LICENSE.
# This modified file is released under the same license.

from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from torch.nn.functional import scaled_dot_product_attention
SDPA_AVAILABLE = True


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
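

# --- Illustrative sketch, not part of the original module ---
# `sinusoids` builds a (length, channels) table whose first half holds sines and second
# half holds cosines of position times geometrically spaced inverse timescales. The
# pure-Python version below reproduces the same formula without torch so individual
# entries can be inspected. `sinusoids_sketch` is a hypothetical name. Like the
# original, the formula divides by (channels // 2 - 1), so this sketch requires
# channels >= 4.
import math


def sinusoids_sketch(length, channels, max_timescale=10000):
    """Return a list of `length` rows, each with `channels` positional values."""
    assert channels % 2 == 0 and channels >= 4
    half = channels // 2
    log_timescale_increment = math.log(max_timescale) / (half - 1)
    inv_timescales = [math.exp(-log_timescale_increment * i) for i in range(half)]
    return [
        [math.sin(pos * scale) for scale in inv_timescales]
        + [math.cos(pos * scale) for scale in inv_timescales]
        for pos in range(length)
    ]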


@contextmanager
def disable_sdpa():
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev_state


class MultiHeadAttention(nn.Module):
    use_sdpa = True

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        casual: Optional[bool] = None
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv = self.qkv_attention(q, k, v, mask, casual)
        return self.out(wv)

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, casual: Optional[bool] = None
    ) -> Tensor:
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        a = scaled_dot_product_attention(
            q, k, v, is_causal=casual and n_ctx > 1, attn_mask=mask[:, None, None, :] if mask is not None else None
        )
        out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
        return out


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        casual: Optional[bool] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, casual=casual)
        if self.cross_attn:
            # TODO: Cross attention mask
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, casual=False)
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

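    def forward(self, x: Tensor, attn_mask: Optional[Tensor]):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        attn_mask : Optional[torch.Tensor]
            mask forwarded to every attention block; may be None (see `Whisper.embed_audio`)
        """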
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding[:x.size(1)]).to(x.dtype)

        for block in self.blocks:
            x = block(x, mask=attn_mask, casual=False)

        x = self.ln_post(x)
        return x


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        self.out_proj = nn.Linear(n_state, n_vocab)

    def forward(self, x: Tensor, attn_mask: Optional[Tensor], xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        attn_mask : Optional[torch.Tensor]
            mask forwarded to every attention block; may be None (see `Whisper.logits`)
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=attn_mask, kv_cache=kv_cache, casual=True)

        x = self.ln(x)
        # logits = (
        #     x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        # ).float()
        logits = self.out_proj(x)

        return logits


class Whisper(nn.Module):
    def __init__(self):
        super().__init__()
        self.n_vocab = 6800
        self.n_text_layer = 6
        self.n_text_head = 8
        self.n_text_ctx = 2048

        self.encoder = AudioEncoder(
            n_mels=80, n_ctx=3000, n_state=512, n_head=8, n_layer=6,
        )
        self.decoder = TextDecoder(
            n_vocab=6800, n_ctx=2048, n_state=512, n_head=8, n_layer=6,
        )

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel, None)

    def logits(self, tokens, audio_features, kv_cache=None):
        return self.decoder(tokens, None, audio_features, kv_cache=kv_cache)

    def forward(
        self, mel, mel_len, token, token_len
    ) -> Dict[str, torch.Tensor]:
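        # conv2 in AudioEncoder uses stride 2, so the encoder output has roughly half as many frames as the mel input
        attn_mask_enc = self.sequence_mask(mel_len // 2, device=mel.device) > 0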
        attn_mask_dec = self.sequence_mask(token_len, device=mel.device) > 0
        return self.decoder(token, attn_mask_dec, self.encoder(mel, attn_mask_enc))

    @property
    def device(self):
        return next(self.parameters()).device

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary mapping each key/value projection module to its cached output
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects; call `.remove()` on each to uninstall the hooks
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
    
    def sequence_mask(self, seq_lens, max_len=None, device='cpu'):
        # Build a [b, t] float mask: 1.0 for valid positions, 0.0 for padding.
        if max_len is None:
            max_len = seq_lens.max()
        mask = torch.arange(max_len).unsqueeze(0).to(device)  # [1, t]
        mask = mask < (seq_lens.unsqueeze(1))  # [1, t] < [b, 1] -> [b, t]
        mask = mask.float()
        return mask
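

# Illustrative sketch (not part of the original module and not used by the model):
# a pure-Python version of the length mask built by `Whisper.sequence_mask`, assuming
# plain integer lengths instead of tensors. The helper name `_sketch_sequence_mask`
# is hypothetical.
def _sketch_sequence_mask(seq_lens, max_len=None):
    """Return one row per sequence with 1.0 for valid positions and 0.0 for padding."""
    if max_len is None:
        max_len = max(seq_lens)
    return [[1.0 if pos < length else 0.0 for pos in range(max_len)] for length in seq_lens]


# Example: _sketch_sequence_mask([3, 1]) gives [[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]].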


================================================
FILE: tts/modules/ar_dur/ar_dur_predictor.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear
from tqdm import tqdm

from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer
from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder

FS_ENCODERS = {
    'rel_fft': lambda hp, dict_size: RelTransformerEncoder(
        dict_size, hp['hidden_size'], hp['hidden_size'],
        hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
        hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln']),
}

def fill_with_neg_inf2(t):
    """FP16-compatible stand-in for -inf: fills a tensor with -1e8 (in float32, then casts back)."""
    return t.float().fill_(-1e8).type_as(t)

def expand_states(h, mel2token):
    h = F.pad(h, [0, 0, 1, 0])
    mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
    h = torch.gather(h, 1, mel2token_)  # [B, T, H]
    return h
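

# Illustrative sketch (not part of the original module and not used by the model):
# what `expand_states` does for a single, unbatched sequence, written with plain
# lists. Index 0 in `mel2token` selects the zero row that is padded in front, so
# padded positions expand to all-zero vectors. The name `_sketch_expand_states`
# is hypothetical.
def _sketch_expand_states(h, mel2token):
    """Gather rows of `h` (1-based indices) with index 0 mapped to a zero vector."""
    hidden = len(h[0]) if h else 0
    padded = [[0.0] * hidden] + [list(row) for row in h]
    return [padded[idx] for idx in mel2token]


# Example: h = [[1.0, 2.0], [3.0, 4.0]] and mel2token = [1, 1, 2, 0]
# expand to [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [0.0, 0.0]].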


class CodePredictor(nn.Module):
    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size):
        super().__init__()
        self.hparams = deepcopy(hparams)
        self.hparams['hidden_size'] = hidden_size
        self.hidden_size = hidden_size
        char_dict_size = hparams.get('char_dict_size', 4000)
        if not hparams.get('lm_use_enc'):
            self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0)
            if hparams.get('mega_use_char', True):
                self.char_encoder = nn.Embedding(char_dict_size,
                                                 self.hidden_size, padding_idx=0)
        else:
            self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size)
            if hparams.get('mega_use_char', True):
                self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size)
            if hparams['use_ph_pos_embed']:
                self.ph_pos_embed = PosEmb(self.hidden_size)

        self.char_empty_embed = nn.Embedding(1, self.hidden_size)
        if hparams.get('use_bert_input'):
            self.bert_input_proj = nn.Linear(768, self.hidden_size)
        self.ling_label_embed_layers = nn.ModuleDict()
        for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']):
            self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0)

        self.dec_hidden_size = dec_hidden_size
        self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size)
        self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0)
        self.use_pos_embed = hparams.get('use_pos_embed', False)
        if self.use_pos_embed:
            self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024)
        self.use_post_ln = hparams.get('use_post_ln', False)
        self.layers = None
        if not self.use_post_ln:
            self.layer_norm = LayerNorm(dec_hidden_size)
        self.code_size = code_size
        self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True)

    def forward_ling_encoder(
            self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre):
        ph_tokens = txt_tokens
        hparams = self.hparams
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]
        x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre)

        # enc_ph
        if not hparams.get('lm_use_enc'):
            x_ph = self.encoder(ph_tokens)
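            # Parenthesize the conditional so that an empty `ling_labels` list adds 0
            # instead of replacing the whole phoneme embedding with 0.
            x_ph = x_ph + (sum(
                [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']])
                if len(hparams['ling_labels']) > 0 else 0)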
            x_ph = x_ph + x_spk
        else:
            # enc_ph: label, position, and speaker embeddings are passed to the encoder as `other_embeds`
            ph_enc_oembed = sum(
                [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
                if len(hparams['ling_labels']) > 0 else 0
            ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
                torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
            ph_enc_oembed = ph_enc_oembed + x_spk
            ph_enc_oembed = ph_enc_oembed * ph_nonpadding
            x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed)

        # enc_char
        if char_tokens is not None and ph2char is not None:
            char_nonpadding = (char_tokens > 0).float()[:, :, None]
            x_char = self.char_encoder(char_tokens)
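            # ph2char values above 100000 flag phonemes treated as having no character;
            # they receive `char_empty_embed` instead of a gathered character state.
            empty_char = (ph2char > 100000).long()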
            ph2char = ph2char * (1 - empty_char)
            x_char_phlevel = \
                expand_states(x_char * char_nonpadding, ph2char) \
                * (1 - empty_char)[..., None] + \
                self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None]
        else:
            x_char_phlevel = 0
        # x_ling
        x_ling = x_ph + x_char_phlevel
        x_ling = x_ling * ph_nonpadding
        x_ling = self.enc_proj(x_ling)
        return x_ling

    def sample_one_step(self, vq_pred):
        hparams = self.hparams
        if hparams.get('infer_top_k'):
            top_k = hparams.get('infer_top_k')
            temperature = hparams.get('infer_temperature', 1)
            vq_pred = vq_pred[:, -1] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(vq_pred, min(top_k, vq_pred.size(-1)))
                vq_pred[vq_pred < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(vq_pred, dim=-1)
            # sample from the distribution
            vq_pred = torch.multinomial(probs, num_samples=1)
        else:
            vq_pred = torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1)
        return vq_pred

    def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
        # add spk embed
        style_embed = 0
        if self.hparams['use_spk_embed']:
            style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :]
        if self.hparams['use_spk_id']:
            style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :]
        if self.hparams['use_spk_enc']:
            style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :]
        return style_embed

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
                not hasattr(self, '_future_mask')
                or self._future_mask is None
                or self._future_mask.device != tensor.device
                or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]
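

# Illustrative sketch (not part of the original module and not used by the model):
# the top-k sampling step of `CodePredictor.sample_one_step` for one logit vector,
# written with the standard library only. The name `_sketch_top_k_sample` and the
# `rng` parameter are hypothetical; the real method works on batched tensors.
def _sketch_top_k_sample(logits, top_k, temperature=1.0, rng=None):
    """Scale logits, keep the k largest (ties included), and sample one index."""
    import math
    import random

    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    k = min(top_k, len(scaled))
    threshold = sorted(scaled, reverse=True)[k - 1]
    # Everything below the k-th largest logit is masked to -inf, as in the tensor code.
    kept = [x if x >= threshold else -math.inf for x in scaled]
    peak = max(kept)
    weights = [math.exp(x - peak) for x in kept]  # masked entries get weight 0.0
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


# With top_k=1 the sketch reduces to argmax; larger k only ever returns one of the
# k highest-scoring indices.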


class ARDurPredictor(CodePredictor):
    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True,
                 op_version=1):
        super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size)
        self.use_rot_embed = use_rot_embed
        bias = hparams.get('lm_bias', True)
        if self.use_rot_embed:
            self.layers = nn.ModuleList([])
            self.layers.extend([
                RotTransformerDecoderLayer(
                    dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4,
                    post_ln=self.use_post_ln, op_version=op_version, bias=bias)
                for _ in range(lm_num_layers)
            ])
        if hparams['dur_model_type'] == 'ar_mse':
            self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus())
        else:
            self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1)

    def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None,
                incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None,
                prompt_length=None, cache_size=20, streaming=False):
        x = self.code_emb(prev_code)
        if x_ling is None:
            x_ling = self.forward_ling_encoder(
                txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre)
            x_ling = x_ling.flatten(0, 1)
            txt_tokens = txt_tokens.flatten(0, 1)
            x_ling = x_ling[txt_tokens > 0][None]

        # run decoder
        self_attn_padding_mask = None
        if self.use_pos_embed:
            positions = self.embed_positions(
                prev_code,
                incremental_state=incremental_state
            )
        if incremental_state is not None:
            x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]]
            if spk_pos_ids_flat is not None:
                spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]]
            x = x[:, -1:]
            if self.use_pos_embed:
                positions = positions[:, -1:]
            if streaming:
                # Shift Pos: clamp the query position to at most prompt_length + cache_size
                spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device),
                                             spk_pos_ids_flat)

        if self.use_pos_embed:
            x = x + positions
        x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous()
        T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1])
        x_ling = x_ling.reshape(-1, T, x_ling.shape[-1])
        x = x + x_ling
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        for idx, layer in enumerate(self.layers):
            if incremental_state is None:
                self_attn_mask = self.buffered_future_mask(x)
                if attn_mask is not None:
                    self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8
                self_attn_mask = self_attn_mask.clamp_min(-1e8)
            else:
                self_attn_mask = None

            x, attn_weights = layer(
                x,
                incremental_state=incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                spk_pos_ids_flat=spk_pos_ids_flat
            )

        if streaming and incremental_state != {}:
            for k, v in incremental_state.items():
                if 'attn_state' in k:
                    prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value']
                    cur_length = prev_key.shape[2]
                    if cur_length - prompt_length > cache_size:
                        prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2)
                        prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]),
                                               dim=2)
                    incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value

        if not self.use_post_ln:
            x = self.layer_norm(x)
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        x = self.project_out_dim(x)
        return x

    def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
              spk_id=None, spk_embed=None, mels_timbre=None,
              incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
              first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs):
        if incremental_state is None:
            incremental_state = {}
        x_ling = self.forward_ling_encoder(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        x_ling = x_ling.flatten(0, 1)
        txt_tokens_ori = txt_tokens
        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
        x_ling = x_ling[txt_tokens > 0][None]
        txt_tokens = txt_tokens[txt_tokens > 0][None]

        decoded = torch.zeros_like(txt_tokens)
        decoded = F.pad(decoded, [1, 0], value=self.code_size + 1)
        if incremental_state != {}:
            if first_decoder_inp is None:
                assert ctx_vqcodes is not None
                decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
                ctx_vqcodes = None
            else:
                decoded[:, :1] = first_decoder_inp
        probs = []
        for step in range(decoded.shape[1] - 1):
            vq_pred = self(txt_tokens, None, None, None, None,
                           decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, **kwargs)
            probs.append(vq_pred.cpu())
            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
                if self.hparams['dur_model_type'] == 'ar_mse':
                    d = vq_pred[:, -1, 0]
                    if dur_disturb > 0 and step >= 1:
                        if random.random() > 0.5:
                            d = d * (1 + random.random() * dur_disturb)
                        else:
                            d = d / (1 + random.random() * dur_disturb)
                        d = torch.clamp_max(d, self.code_size - 1)
                    vq_pred = torch.round(d).long()
                else:
                    vq_pred = self.sample_one_step(vq_pred)
                decoded[:, step + 1] = torch.clamp_min(vq_pred, 1)
                if step == 0:
                    decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min)
            else:
                decoded[:, step + 1] = ctx_vqcodes[:, step]
        decoded = decoded[:, 1:]
        decoded_2d = torch.zeros_like(txt_tokens_ori)
        decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded
        if return_state:
            return decoded_2d, incremental_state
        if return_probs:
            return decoded_2d, torch.cat(probs, 1)
        return decoded_2d

    def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                        spk_id=None, spk_embed=None, mels_timbre=None,
                        incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
                        **kwargs):
        if incremental_state is None:
            incremental_state = {}
        x_ling = self.forward_ling_encoder(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        x_ling = x_ling.flatten(0, 1)
        txt_tokens_ori = txt_tokens
        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
        x_ling = x_ling[txt_tokens > 0][None]
        txt_tokens = txt_tokens[txt_tokens > 0][None]

        vq_decoded = torch.zeros_like(txt_tokens)
        vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1)
        if incremental_state != {}:
            assert ctx_vqcodes is not None
            vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
            ctx_vqcodes = None
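        # The prompt length is read from the first cached key tensor, so the cache must already be populated.
        assert incremental_state, 'streaming_infer requires a non-empty incremental_state'
        prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2]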
        for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'):
            vq_pred = self(txt_tokens, None, None, None, None,
                           vq_decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs)
            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
                if self.hparams['dur_model_type'] == 'ar_mse':
                    vq_pred = torch.round(vq_pred[:, -1, 0]).long()
                else:
                    vq_pred = self.sample_one_step(vq_pred)
                vq_decoded[:, step + 1] = vq_pred
            else:
                vq_decoded[:, step + 1] = ctx_vqcodes[:, step]
        vq_decoded = vq_decoded[:, 1:]
        vq_decoded_2d = torch.zeros_like(txt_tokens_ori)
        vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded
        if return_state:
            return vq_decoded_2d, incremental_state
        return vq_decoded_2d
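

# Illustrative sketch (not part of the original module and not used by the model):
# the cache-trimming rule applied in `ARDurPredictor.forward` when `streaming=True`,
# shown on a plain list standing in for the time axis of `prev_key` / `prev_value`.
# The name `_sketch_trim_cache` is hypothetical, and `cache_size` is assumed positive.
def _sketch_trim_cache(cache, prompt_length, cache_size):
    """Keep the prompt entries plus the most recent `cache_size` entries."""
    if len(cache) - prompt_length > cache_size:
        return cache[:prompt_length] + cache[-cache_size:]
    return cache


# Example: ten cached steps with prompt_length=3 and cache_size=4 keep steps
# [0, 1, 2, 6, 7, 8, 9]; shorter caches are returned unchanged.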

================================================
FILE: tts/modules/ar_dur/commons/layers.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.
    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1, eps=1e-5):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.
        :param torch.Tensor x: input tensor
        :return: layer normalized tensor
        :rtype: torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)


class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


class Permute(nn.Module):
    def __init__(self, *args):
        super(Permute, self).__init__()
        self.args = args

    def forward(self, x):
        return x.permute(self.args)


def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m


================================================
FILE: tts/modules/ar_dur/commons/nar_tts_modules.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
from torch import nn

import torch.nn.functional as F


class LengthRegulator(torch.nn.Module):
    def __init__(self, pad_value=0.0):
        super(LengthRegulator, self).__init__()
        self.pad_value = pad_value

    def forward(self, dur, dur_padding=None, alpha=1.0):
        """
        Example (no batch dim version):
            1. dur = [2,2,3]
            2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
            3. token_mask = [[1,1,0,0,0,0,0],
                             [0,0,1,1,0,0,0],
                             [0,0,0,0,1,1,1]]
            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
                                         [0,0,2,2,0,0,0],
                                         [0,0,0,0,3,3,3]]
            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]

        :param dur: Batch of durations (in frames) of each token (B, T_txt)
        :param dur_padding: Batch of padding flags of each token (B, T_txt), 1 marks padding
        :param alpha: duration rescale coefficient, must be > 0
        :return:
            mel2ph (B, T_speech)
        """
        assert alpha > 0
        dur = torch.round(dur.float() * alpha).long()
        if dur_padding is not None:
            dur = dur * (1 - dur_padding.long())
        token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
        dur_cumsum = torch.cumsum(dur, 1)
        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)

        pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        mel2token = (token_idx * token_mask.long()).sum(1)
        return mel2token
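

# Illustrative sketch (not part of the original module and not used by the model):
# a pure-Python reference for the unbatched example in the `LengthRegulator.forward`
# docstring. Each 1-based token index is repeated for as many frames as its
# duration. The name `_sketch_mel2token` is hypothetical.
def _sketch_mel2token(dur):
    """Expand integer durations into a frame-to-token index list."""
    mel2token = []
    for token_idx, n_frames in enumerate(dur, start=1):
        mel2token.extend([token_idx] * n_frames)
    return mel2token


# Example: _sketch_mel2token([2, 2, 3]) gives [1, 1, 2, 2, 3, 3, 3]. In the batched
# tensor version, frames past the end of a shorter item come out as 0.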


class PosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        self.emb = emb  # TODO

    def forward(self, x):
        emb = x[:, :, None] * self.emb[None, None, :].to(x.device)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
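

# Illustrative sketch (not part of the original module and not used by the model):
# the sinusoidal embedding computed by `PosEmb` for one scalar position, using only
# the standard library. It assumes an even `dim` of at least 4 so that
# `half_dim - 1` is non-zero. The name `_sketch_pos_emb` is hypothetical.
def _sketch_pos_emb(position, dim):
    """Return [sin(p * f_0), ..., sin(p * f_n), cos(p * f_0), ..., cos(p * f_n)]."""
    import math

    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    angles = [position * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


# Example: position 0 maps to all-zero sine terms followed by all-one cosine terms.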


================================================
FILE: tts/modules/ar_dur/commons/rel_transformer.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
from torch import nn
from torch.nn import functional as F

from tts.modules.ar_dur.commons.layers import Embedding


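def convert_pad_shape(pad_shape):
    """Flatten per-dimension [before, after] pairs into the last-dim-first list that F.pad expects."""
    reversed_shape = pad_shape[::-1]
    pad_shape = [item for sublist in reversed_shape for item in sublist]
    return pad_shape


def shift_1d(x):
    """Shift x one step to the right along the last dimension, zero-filling the first position."""
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x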


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


class Encoder(nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
                 window_size=None, block_length=None, pre_ln=False, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.pre_ln = pre_ln

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
                                   p_dropout=p_dropout, block_length=block_length))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))
        if pre_ln:
            self.last_ln = LayerNorm(hidden_channels)

    def forward(self, x, x_mask, attn_mask=1):
        if isinstance(attn_mask, torch.Tensor):
            attn_mask = attn_mask[:, None]
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
        for i in range(self.n_layers):
            x = x * x_mask
            x_ = x
            if self.pre_ln:
                x = self.norm_layers_1[i](x)
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = x_ + y
            if not self.pre_ln:
                x = self.norm_layers_1[i](x)

            x_ = x
            if self.pre_ln:
                x = self.norm_layers_2[i](x)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = x_ + y
            if not self.pre_ln:
                x = self.norm_layers_2[i](x)
        if self.pre_ln:
            x = self.last_ln(x)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
                 block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores * block_mask + -1e4 * (1 - block_mask)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so as to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, -1])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        if self.activation == "gelu":
            # Sigmoid approximation of GELU: x * sigmoid(1.702 * x).
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class RelTransformerEncoder(nn.Module):
    def __init__(self,
                 n_vocab,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout=0.0,
                 window_size=4,
                 block_length=None,
                 in_channels=None,
                 prenet=True,
                 pre_ln=True,
                 ):

        super().__init__()

        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.prenet = prenet
        if n_vocab > 0:
            self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)

        if prenet:
            if in_channels is None:
                in_channels = hidden_channels
            self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
                                    kernel_size=5, n_layers=3, p_dropout=0)
        if in_channels is not None and in_channels != hidden_channels:
            self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
        self.encoder = Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
            block_length=block_length,
            pre_ln=pre_ln,
        )

    def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
        if self.n_vocab > 0:
            x_lengths = (x > 0).long().sum(-1)
            x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        else:
            x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
        x = x + other_embeds
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        if self.prenet:
            x = self.pre(x, x_mask)
            self.prenet_out = x.transpose(1, 2)
        if hasattr(self, 'encoder_inp_proj'):
            x = self.encoder_inp_proj(x) * x_mask
        x = self.encoder(x, x_mask, attn_mask)
        return x.transpose(1, 2)


================================================
FILE: tts/modules/ar_dur/commons/rot_transformer.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
from typing import Optional, Tuple
from torch import nn
from torch.nn import Parameter, Linear
from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
from tts.modules.ar_dur.commons.transformer import TransformerFFNLayer, MultiheadAttention
from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
import torch.nn.functional as F

DEFAULT_MAX_SOURCE_POSITIONS = 3000
DEFAULT_MAX_TARGET_POSITIONS = 3000


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions are the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number


class RotaryEmbeddings(nn.Module):
    cos: torch.Tensor
    sin: torch.Tensor
    theta: torch.Tensor

    def __init__(
            self,
            width: int,
            *,
            seq_len: int = 40000,
            base: int = 10000,
            device: Optional[torch.device] = None,
    ):
        """Rotary embeddings (Su et al., 2021) layer. The rotary embedding
        will be precomputed for up to 'seq_len' positions. The embedding
        will be recomputed when a longer sequence is found in the input.

        :param width:
            Rotary embedding dimensionality, must be even.
        :param seq_len:
            Number of positions to initially precompute.
        :param base:
            The base used for Θ_i, determines the cycle length of the
            embeddings.
        :param device: Device on which the module is to be initialized.
        """
        super().__init__()

        if width % 2:
            raise ValueError(f"Width of rotary embeddings must be even, was: {width}")

        # Ignore allocations on the meta device as we don't persist our buffer,
        # i.e., we don't expect the backing tensor to be replaced with pretrained weights.
        if device is not None and device.type == "meta":
            device = None
        # Θ_i = 10000^(-2(i-1)/d)
        theta = torch.pow(
            base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
        )
        self.register_buffer("theta", theta, persistent=False)

        self._create_rotary_embed(width=width, length=seq_len)

    def _create_rotary_embed(self, *, width: int, length: int):
        # mΘ
        position = torch.arange(length, device=self.theta.device).unsqueeze(1)
        m_theta = position * self.theta.unsqueeze(0)

        # We apply both sin and cos twice (see Eq 15, 34), but the ordering
        # is changed for compatibility with most common implementations.
        m_theta = torch.cat([m_theta, m_theta], dim=-1)

        re_cos = m_theta.cos().view([length, width])
        re_sin = m_theta.sin().view([length, width])

        self.register_buffer("cos", re_cos, persistent=False)
        self.register_buffer("sin", re_sin, persistent=False)

    def _rotate(self, input: torch.Tensor):
        """Rotate the input tensor by half of its innermost width.

        input (Tensor): array to rotate.
        RETURNS (Tensor): rotated array.

        Shapes:
            input - (..., width)
            output - (..., width)
        """
        half_idx = input.shape[-1] // 2
        input_1 = -input[..., half_idx:]
        input_2 = input[..., :half_idx]
        return torch.cat([input_1, input_2], dim=-1)

    def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
        """
        Apply rotary embeddings to an array.

        :param input: Array to apply the rotary embeddings to.
        :param positions: positions of the inputs. If no positions are
            provided, they are assumed to be [0, seq_len).
        :return: Array with the rotary embeddings applied.

        Shapes:
            input - (batch_size, num_heads, seq_len, width_per_head)
            positions - (batch_size, seq_len)
            output - (batch_size, num_heads, seq_len, width_per_head)
        """
        batch_size, _, seq_len, width = input.shape

        if positions is None:
            # Fastpath: positions from [0..seq_len), avoid indexing.
            if self.cos.size(-2) < seq_len:
                self._create_rotary_embed(width=width, length=seq_len)
            rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width)
            rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width)
        else:
            max_len = int(positions.max()) + 1
            if self.cos.size(-2) < max_len:
                self._create_rotary_embed(width=width, length=max_len)

            # Flatten positions to index cos/sin arrays, then unflatten.
            #
            # Example shapes:
            #
            #   positions_flat - (batch_size * seq_len)
            #   self.cos - (max_len, width)
            #   rot_cos - (batch_size, seq_len, width)
            positions_flat = positions.view(-1)
            rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width)
            rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width)

        # Eq 34 with ordering changed for compatibility.
        return rot_cos * input + rot_sin * self._rotate(input)


class RotMultiheadAttention(MultiheadAttention):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
                         add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
                         encoder_decoder_attention=encoder_decoder_attention)
        self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)

    def forward(
            self,
            query, key, value,
            spk_pos_ids_flat=None,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q = q * self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        # Apply rot embedding and store incremental_state
        q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
                bsz, self.num_heads, -1, self.head_dim)
            self._set_input_buffer(incremental_state, saved_state)
        if incremental_state is not None:
            key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
        else:
            key_pos = spk_pos_ids_flat
        k = self.rotary_embeds(k[None, :], positions=key_pos)[0]

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask

        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        if reset_attn_weight is not None:
            if reset_attn_weight:
                self.last_attn_probs = attn_probs.detach()
            else:
                assert self.last_attn_probs is not None
                attn_probs = self.last_attn_probs
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, (attn_weights, attn_logits)


class RotMultiheadAttention2(MultiheadAttention):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
                         add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
                         encoder_decoder_attention=encoder_decoder_attention)
        self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)

    def forward(
            self,
            query, key, value,
            spk_pos_ids_flat=None,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        # Apply rot embedding and store incremental_state
        q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
            self._set_input_buffer(incremental_state, saved_state)
        key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
        k = self.rotary_embeds(k[None, :], positions=key_pos)[0]

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
        attn = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
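        # scaled_dot_product_attention does not expose attention weights or
        # logits, so this implementation always returns None for both.
        attn_logits = None
        attn_weights = None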
        return attn, (attn_weights, attn_logits)


class RotDecSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
                 kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False, bias=True):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        self.self_attn = RotMultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )
        self.layer_norm2 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size,
            dropout=relu_dropout, act=act, bias=bias)
        self.post_ln = post_ln

    def forward(
            self,
            x,
            encoder_out=None,
            encoder_padding_mask=None,
            incremental_state=None,
            self_attn_mask=None,
            self_attn_padding_mask=None,
            attn_out=None,
            reset_attn_weight=None,
            spk_pos_ids_flat=None,
            **kwargs,
    ):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        residual = x
        if not self.post_ln:
            x = self.layer_norm1(x)

        x, (attn_weights, _) = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask,
            spk_pos_ids_flat=spk_pos_ids_flat
        )
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.layer_norm1(x)

        residual = x
        if not self.post_ln:
            x = self.layer_norm2(x)
        x = self.ffn(x, incremental_state=incremental_state)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.layer_norm2(x)
        return x, attn_weights

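    def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
        # This layer defines no encoder attention, so only clear it if a subclass adds one.
        encoder_attn = getattr(self, 'encoder_attn', None)
        if encoder_attn is not None:
            encoder_attn.clear_buffer(incremental_state)
        self.ffn.clear_buffer(incremental_state)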

    def set_buffer(self, name, tensor, incremental_state):
        return set_incremental_state(self, incremental_state, name, tensor)


class RotDecSALayer2(RotDecSALayer):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9,
                 ffn_hidden_size=1024, act='gelu', post_ln=False):
        super().__init__(c, num_heads, dropout, attention_dropout, relu_dropout, kernel_size, ffn_hidden_size, act,
                         post_ln)
        self.self_attn = RotMultiheadAttention2(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )


class RotTransformerDecoderLayer(nn.Module):
    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
                 op_version=1, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        if op_version == 1:
            self.op = RotDecSALayer(
                hidden_size, num_heads, dropout=dropout,
                attention_dropout=0.0, relu_dropout=dropout,
                kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
                post_ln=post_ln, bias=bias)
        else:
            self.op = RotDecSALayer2(
                hidden_size, num_heads, dropout=dropout,
                attention_dropout=0.0, relu_dropout=dropout,
                kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
                post_ln=post_ln)

    def forward(self, x, **kwargs):
        return self.op(x, **kwargs)

    def clear_buffer(self, *args):
        return self.op.clear_buffer(*args)

    def set_buffer(self, *args):
        return self.op.set_buffer(*args)


================================================
FILE: tts/modules/ar_dur/commons/seq_utils.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import torch
import torch.nn.functional as F


def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    return (
                   torch.cumsum(mask, dim=1).type_as(mask) * mask
           ).long() + padding_idx


def softmax(x, dim):
    return F.softmax(x, dim=dim, dtype=torch.float32)


def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    if maxlen is None:
        maxlen = lengths.max()
    mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
    # Tensor.type() is not in-place, so return the converted mask.
    return mask.type(dtype)


def weights_nonzero_speech(target):
    # target: [B, T, mel]
    # Assign weight 1.0 to every frame except padding frames (all-zero mel values).
    dim = target.size(-1)
    return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)


INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)


def _get_full_incremental_state_key(module_instance, key):
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module."""
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)


def fill_with_neg_inf2(t):
    """FP16-compatible function that fills a tensor with a large negative value (-1e8) instead of -inf."""
    return t.float().fill_(-1e8).type_as(t)


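def select_attn(attn_logits, type='best'):
    """Select one attention map per batch item from all layers and heads.

    :param attn_logits: sequence of n_layers tensors, each [B, n_head, T_sp, T_txt]
    :param type: 'best' picks, per batch item, the layer/head whose attention is
        most sharply peaked; 'mean' averages over all layers and heads.
        Any other value returns None.
    :return: attention probabilities [B, T_sp, T_txt]
    """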
    encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
    # [n_layers * n_head, B, T_sp, T_txt]
    encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
    if type == 'best':
        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
        encdec_attn = encdec_attn.gather(
            0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
        return encdec_attn
    elif type == 'mean':
        return encdec_attn.mean(0)


def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.
    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.
    Returns:
        Tensor: Mask tensor containing indices of padded part.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
    Examples:
        With only lengths.
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
        With the reference tensor.
        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
        With the reference tensor and dimension indicator.
        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.
    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.
    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
                    dtype=torch.uint8 in PyTorch 1.2-
                    dtype=torch.bool in PyTorch 1.2+ (including 1.2)
    Examples:
        With only lengths.
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
        With the reference tensor.
        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
        With the reference tensor and dimension indicator.
        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
    """
    return ~make_pad_mask(lengths, xs, length_dim)


def get_mask_from_lengths(lengths):
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len).to(lengths.device)
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask
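

# Illustrative sketch, not part of the original module: a pure-Python view of
# the boolean mask that get_mask_from_lengths builds with tensors. The helper
# name `_length_mask_rows_sketch` is hypothetical. Entry [b][t] is True when
# frame t lies inside sequence b, i.e. t < lengths[b].
def _length_mask_rows_sketch(lengths):
    max_len = max(lengths)
    return [[t < n for t in range(max_len)] for n in lengths]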


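def group_hidden_by_segs(h, seg_ids, max_len):
    """Average hidden states within each segment.

    :param h: [B, T, H]
    :param seg_ids: [B, T], 1-based segment index per frame; index 0 is dropped (padding)
    :param max_len: maximum number of segments T_ph
    :return: h_ph: [B, T_ph, H] mean hidden state per segment,
             cnt: [B, T_ph] number of frames per segment
    """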
    B, T, H = h.shape
    h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
    all_ones = h.new_ones(h.shape[:2])
    cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
    h_gby_segs = h_gby_segs[:, 1:]
    cnt_gby_segs = cnt_gby_segs[:, 1:]
    h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
    return h_gby_segs, cnt_gby_segs

def expand_by_repeat_times(source_encoding, lengths):
    """
    source_encoding: [T, C]
    lengths, list of int, [T,], how many times each token should repeat
    return:
        expanded_encoding: [T_expand, C]
    """
    hid_dim = source_encoding.shape[1]
    out2source = []
    for i, length in enumerate(lengths):
        out2source += [i for _ in range(length)]
    out2source = torch.LongTensor(out2source).to(source_encoding.device)
    out2source_ = out2source[:, None].repeat([1, hid_dim])
    expanded_encoding = torch.gather(source_encoding, 0, out2source_)  # [T_expand, C]
    return expanded_encoding
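

# Illustrative sketch, not part of the original module: the index list that the
# loop in expand_by_repeat_times builds, written as one comprehension. The
# helper name `_repeat_indices_sketch` is hypothetical. Output position j reads
# source row out[j]; a length of 0 drops that token.
def _repeat_indices_sketch(lengths):
    return [i for i, n in enumerate(lengths) for _ in range(n)]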


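def expand_word2ph(word_encoding, ph2word):
    """Expand word-level encodings [B, T_word, H] to phoneme level using 1-based ph2word [B, T_ph]."""
    # Prepend a zero row so that index 0 (padding) maps to zeros.
    word_encoding = F.pad(word_encoding, [0, 0, 1, 0])
    ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
    out = torch.gather(word_encoding, 1, ph2word_)  # [B, T_ph, H]
    return out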


================================================
FILE: tts/modules/ar_dur/commons/transformer.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
from torch import nn
from torch.nn import Parameter, Linear
from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
import torch.nn.functional as F

DEFAULT_MAX_SOURCE_POSITIONS = 3000
DEFAULT_MAX_TARGET_POSITIONS = 3000


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
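

# Illustrative sketch, not part of the original module: one row of the table
# built by SinusoidalPositionalEmbedding.get_embedding, in plain Python. The
# helper name `_sinusoid_row_sketch` is hypothetical; it assumes
# embedding_dim >= 4 (so half_dim - 1 > 0) and ignores the padding_idx zeroing.
def _sinusoid_row_sketch(pos, embedding_dim):
    import math  # local import keeps the sketch self-contained
    half_dim = embedding_dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-i * scale) for i in range(half_dim)]
    # All sines first, then all cosines (not interleaved).
    row = [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]
    if embedding_dim % 2 == 1:
        row.append(0.0)  # zero pad odd dimensions
    return row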


class TransformerFFNLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu', bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.act = act
        if padding == 'SAME':
            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size,
                                   padding=kernel_size // 2, bias=bias)
        elif padding == 'LEFT':
            self.ffn_1 = nn.Sequential(
                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
                nn.Conv1d(hidden_size, filter_size, kernel_size, bias=bias)
            )
        self.ffn_2 = Linear(filter_size, hidden_size, bias=bias)

    def forward(self, x, incremental_state=None):
        # x: T x B x C
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_input' in saved_state:
                prev_input = saved_state['prev_input']
                x = torch.cat((prev_input, x), dim=0)
            x = x[-self.kernel_size:]
            saved_state['prev_input'] = x
            self._set_input_buffer(incremental_state, saved_state)

        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
        x = x * self.kernel_size ** -0.5

        if incremental_state is not None:
            x = x[-1:]
        if self.act == 'gelu':
            x = F.gelu(x)
        if self.act == 'relu':
            x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.ffn_2(x)
        return x

    def _get_input_buffer(self, incremental_state):
        return get_incremental_state(
            self,
            incremental_state,
            'f',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        set_incremental_state(
            self,
            incremental_state,
            'f',
            buffer,
        )

    def clear_buffer(self, incremental_state):
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_input' in saved_state:
                del saved_state['prev_input']
            self._set_input_buffer(incremental_state, saved_state)


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                             'value to be of the same size'

        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.enable_torch_version = False
        self.last_attn_probs = None

    def reset_parameters(self):
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      self.in_proj_weight,
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask)
            else:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      torch.empty([0]),
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask, use_separate_proj_weight=True,
                                                      q_proj_weight=self.q_proj_weight,
                                                      k_proj_weight=self.k_proj_weight,
                                                      v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)

        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q = q * self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
                prev_key_padding_mask = saved_state['prev_key_padding_mask']
                if static_kv:
                    key_padding_mask = prev_key_padding_mask
                else:
                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)

            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_key_padding_mask'] = key_padding_mask

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask

        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        if reset_attn_weight is not None:
            if reset_attn_weight:
                self.last_attn_probs = attn_probs.detach()
            else:
                assert self.last_attn_probs is not None
                attn_probs = self.last_attn_probs
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, (attn_weights, attn_logits)

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def _get_input_buffer(self, incremental_state):
        return get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        return attn_weights

    def clear_buffer(self, incremental_state=None):
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                del saved_state['prev_key']
            if 'prev_value' in saved_state:
                del saved_state['prev_value']
            self._set_input_buffer(incremental_state, saved_state)


class EncSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
                 relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu',
                 ffn_hidden_size=1024):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.num_heads = num_heads
        if num_heads > 0:
            self.layer_norm1 = LayerNorm(c)
            self.self_attn = MultiheadAttention(
                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
        self.layer_norm2 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            # layer_norm1 only exists when self-attention is enabled (num_heads > 0)
            if self.num_heads > 0:
                self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        if self.num_heads > 0:
            residual = x
            x = self.layer_norm1(x)
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask
            )
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
            x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
        return x


class DecSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
                 kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        self.self_attn = MultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )
        self.layer_norm2 = LayerNorm(c)
        self.encoder_attn = MultiheadAttention(
            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
        )
        self.layer_norm3 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
        self.post_ln = post_ln

    def forward(
            self,
            x,
            encoder_out=None,
            encoder_padding_mask=None,
            incremental_state=None,
            self_attn_mask=None,
            self_attn_padding_mask=None,
            attn_out=None,
            reset_attn_weight=None,
            **kwargs,
    ):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
            self.layer_norm3.training = layer_norm_training
        residual = x
        if not self.post_ln:
            x = self.layer_norm1(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask
        )
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.layer_norm1(x)

        attn_logits = None
        if encoder_out is not None or attn_out is not None:
            residual = x
            if not self.post_ln:
                x = self.layer_norm2(x)
        if encoder_out is not None:
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
                                                                   'enc_dec_attn_constraint_mask'),
                reset_attn_weight=reset_attn_weight
            )
            attn_logits = attn[1]
        elif attn_out is not None:
            x = self.encoder_attn.in_proj_v(attn_out)
        if encoder_out is not None or attn_out is not None:
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
        if self.post_ln:
            x = self.layer_norm2(x)

        residual = x
        if not self.post_ln:
            x = self.layer_norm3(x)
        x = self.ffn(x, incremental_state=incremental_state)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.layer_norm3(x)
        return x, attn_logits

    def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
        self.encoder_attn.clear_buffer(incremental_state)
        self.ffn.clear_buffer(incremental_state)

    def set_buffer(self, name, tensor, incremental_state):
        return set_incremental_state(self, incremental_state, name, tensor)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        self.op = EncSALayer(
            hidden_size, num_heads, dropout=dropout,
            attention_dropout=0.0, relu_dropout=dropout,
            kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size)

    def forward(self, x, **kwargs):
        return self.op(x, **kwargs)


class TransformerDecoderLayer(nn.Module):
    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        self.op = DecSALayer(
            hidden_size, num_heads, dropout=dropout,
            attention_dropout=0.0, relu_dropout=dropout,
            kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
            post_ln=post_ln)

    def forward(self, x, **kwargs):
        return self.op(x, **kwargs)

    def clear_buffer(self, *args):
        return self.op.clear_buffer(*args)

    def set_buffer(self, *args):
        return self.op.set_buffer(*args)


class FFTBlocks(nn.Module):
    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
                 num_heads=2, use_pos_embed=True, use_last_norm=True,
                 use_pos_embed_alpha=True, ffn_hidden_size=1024):
        super().__init__()
        self.num_layers = num_layers
        embed_dim = self.hidden_size = hidden_size
        self.dropout = dropout
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm
        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(self.hidden_size, self.dropout,
                                    kernel_size=ffn_kernel_size, num_heads=num_heads,
                                    ffn_hidden_size=ffn_hidden_size)
            for _ in range(self.num_layers)
        ])
        if self.use_last_norm:
            self.layer_norm = nn.LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
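        """
        :param x: [B, T, C]
        :param padding_mask: [B, T], True at padded positions; if None, it is inferred
            from the frames of x that are all zeros
        :param attn_mask: forwarded to every layer through kwargs; EncSALayer.forward does
            not pass it on to self-attention, so it currently has no effect
        :param return_hiddens: if True, return the stacked per-layer outputs
            (taken before the final layer norm)
        :return: [B, T, C], or [L, B, T, C] if return_hiddens is True
        """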
        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = x + positions
            x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1) * nonpadding_mask_TB
        hiddens = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
            hiddens.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * nonpadding_mask_TB
        if return_hiddens:
            x = torch.stack(hiddens, 0)  # [L, T, B, C]
            x = x.transpose(1, 2)  # [L, B, T, C]
        else:
            x = x.transpose(0, 1)  # [B, T, C]
        return x


class FastSpeechEncoder(FFTBlocks):
    def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9,
                 dropout=0.0, num_heads=2, ffn_hidden_size=1024):
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
                         use_pos_embed=False, dropout=dropout, ffn_hidden_size=ffn_hidden_size)
        self.embed_tokens = Embedding(dict_size, hidden_size, 0)
        self.embed_scale = math.sqrt(hidden_size)
        self.padding_idx = 0
        self.embed_positions = SinusoidalPositionalEmbedding(
            hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
        )

    def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
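        """
        :param txt_tokens: [B, T] token ids; entries equal to padding_idx (0) are padding
        :param attn_mask: forwarded to FFTBlocks.forward
        :param other_embeds: optional embeddings added to the token embeddings
        :return: encoder output tensor, [B, T, C]
        """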
        encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
        x = self.forward_embedding(txt_tokens) + other_embeds  # [B, T, H]
        if self.num_layers > 0:
            x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
        return x

    def forward_embedding(self, txt_tokens):
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(txt_tokens)
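        # FastSpeechEncoder.__init__ passes use_pos_embed=False to FFTBlocks, so this branch
        # only runs if the attribute is changed after construction.
        if self.use_pos_embed: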
            positions = self.embed_positions(txt_tokens)
            x = x + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x


================================================
FILE: tts/modules/llm_dit/cfm.py
================================================
# MIT License

# Copyright (c) 2023 Alexander Tong

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Copyright (c) [2023] [Alexander Tong] 
# Copyright (c) [2025] [Ziyue Jiang] 
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
# This modified file is released under the same license.

import math
import torch
from typing import Union
from torch.distributions import LogisticNormal


class LogitNormalTrainingTimesteps:
    def __init__(self, T=1000.0, loc=0.0, scale=1.0):
        assert T > 0
        self.T = T
        self.dist = LogisticNormal(loc, scale)

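    def sample(self, size, device):
        # With scalar loc/scale, LogisticNormal samples carry a trailing dimension of size 2.
        # Component 0 equals sigmoid(z) with z ~ N(loc, scale), i.e. a logit-normal value in (0, 1).
        t = self.dist.sample(size)[..., 0].to(device)
        return t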
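

# Illustrative sketch (not part of the original file): a standard-library reference for
# what LogitNormalTrainingTimesteps.sample is assumed to produce, namely
# t = sigmoid(z) with z ~ N(loc, scale). The helper name `_logit_normal_reference` and
# its `seed` argument are hypothetical and are not used by the model code.
def _logit_normal_reference(n, loc=0.0, scale=1.0, seed=0):
    import math
    import random

    rng = random.Random(seed)
    # The sigmoid maps every Gaussian draw into the open interval (0, 1).
    return [1.0 / (1.0 + math.exp(-rng.gauss(loc, scale))) for _ in range(n)]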
    

def pad_t_like_x(t, x):
    """Reshape the time vector t so that it broadcasts against x.

    Parameters
    ----------
    x : Tensor, shape (bs, *dim)
        represents the source minibatch
    t : FloatTensor, shape (bs), or a Python float/int

    Returns
    -------
    t : Tensor, shape (bs, 1, ..., 1), with one singleton dimension per non-batch
        dimension of x; a float or int is returned unchanged

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    if isinstance(t, (float, int)):
        return t
    return t.reshape(-1, *([1] * (x.dim() - 1)))


class ConditionalFlowMatcher:
    r"""Base class for conditional flow matching methods. This class implements the independent
    conditional flow matching methods from [1] and serves as a parent class for all other flow
    matching methods.

    It implements:
    - Drawing samples from the Gaussian probability path N(t * x1 + (1 - t) * x0, sigma)
    - conditional flow matching ut(x1|x0) = x1 - x0
    - score function $\nabla \log p_t(x|x0, x1)$
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

        Parameters
        ----------
        sigma : Union[float, int]
            constant standard deviation of the probability path, as returned by compute_sigma_t
        """
        self.sigma = sigma
        self.time_sampler = LogitNormalTrainingTimesteps()

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        t = pad_t_like_x(t, x0)
        return t * x1 + (1 - t) * x0

    def compute_sigma_t(self, t):
        """
        Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        standard deviation sigma

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """
        Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(t)
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t, xt
        return x1 - x0

    def sample_noise_like(self, x):
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """
        Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
        and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor, shape (bs)
            represents the time levels
            if None, drawn from self.time_sampler (logit-normal by default)
        return_noise : bool
            return the noise sample epsilon


        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : conditional vector field ut(x1|x0) = x1 - x0
        (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        if t is None:
            # t = torch.rand(x0.shape[0]).type_as(x0)
            t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)

        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        if return_noise:
            return t, xt, ut, eps
        else:
            return t, xt, ut

    def compute_lambda(self, t):
        """Compute the lambda function, see Eq.(23) [4].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        lambda : score weighting function

        References
        ----------
        [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
        """
        sigma_t = self.compute_sigma_t(t)
        return 2 * sigma_t / (self.sigma**2 + 1e-8)


class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
    """Albergo et al. 2023 trigonometric interpolants class. This class inherits from
    ConditionalFlowMatcher and overrides the compute_mu_t and compute_conditional_flow functions in
    order to compute [3]'s trigonometric interpolants.

    [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
    """

    def compute_mu_t(self, x0, x1, t):
        r"""Compute the mean of the probability path (Eq.5) from [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        t = pad_t_like_x(t, x0)
        return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1

    def compute_conditional_flow(self, x0, x1, t, xt):
        r"""Compute the conditional vector field similar to [3].

        ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
        see Eq.(21) [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt

        Returns
        -------
        ut : conditional vector field
        ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0)

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        del xt
        t = pad_t_like_x(t, x0)
        return math.pi / 2 * (torch.cos(math.pi / 2 * t) * x1 - torch.sin(math.pi / 2 * t) * x0)


================================================
FILE: tts/modules/llm_dit/dit.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
from tts.modules.ar_dur.commons.layers import Embedding
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
from tts.modules.ar_dur.ar_dur_predictor import expand_states
from tts.modules.llm_dit.transformer import Transformer
from tts.modules.llm_dit.time_embedding import TimestepEmbedding


class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        # Hparams
        # cond dim
        self.local_cond_dim = 512
        self.ctx_mask_dim = 16
        self.in_channels = 32
        self.out_channels = 32
        # LLM
        self.encoder_dim = 1024
        self.encoder_n_layers = 24
        self.encoder_n_heads = 16
        self.max_seq_len = 16384
        self.multiple_of = 256

        self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
        self.local_cond_project = nn.Linear(
            self.out_channels + self.ctx_mask_dim, self.local_cond_dim)

        self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)

        self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
        self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
        self.postnet = nn.Linear(self.encoder_dim, self.out_channels)
  
        self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
        # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS), 
        # which is licensed under the MIT License.
        self.f5_time_embed = TimestepEmbedding(self.encoder_dim)

        # text encoder
        self.ph_encoder = RelTransformerEncoder(
            302, self.encoder_dim, self.encoder_dim,
            self.encoder_dim * 2, 4, 6,
            3, 0.0, prenet=True, pre_ln=True)
        self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
        self.ph_pos_embed = PosEmb(self.encoder_dim)
        self.ling_pre_net = torch.nn.Sequential(*[
            torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
            for i, s in enumerate([2, 2])
        ])
    
    def forward(self, inputs, sigmas=None, x_noisy=None):
        ctx_mask = inputs['ctx_mask']
        ctx_feature = inputs['lat_ctx'] * ctx_mask

        # local conditioning (prompt_latent + spk_embed)
        ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
        # ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None])
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        # diffusion target latent
        x = inputs['lat']
    
        # Here, x is x1 in CFM
        x0 = torch.randn_like(x)
        t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)
        
        # define noisy_input and target
        t = t.bfloat16()
        x_noisy = (xt * (1 - ctx_mask)).bfloat16()
        target = ut

        # concat condition.
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
        x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
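        # Transformer.forward is listed in the symbol index as (x, t, attn_mask, start_pos=0),
        # so no do_checkpoint keyword is passed here.
        encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"])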
        pred = self.postnet(encoder_out)

        return pred, target
    
    def forward_ling_encoder(self, txt_tokens, tone_tokens):
        ph_tokens = txt_tokens
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]

        # enc_ph
        ph_enc_oembed = self.tone_embed(tone_tokens)
        ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
            torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
        ph_enc_oembed = ph_enc_oembed * ph_nonpadding
        x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
        return x_ling

    def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=[1.0,1.0]):
        """ When we use torchdiffeq, we need to include the CFG process inside _forward() """
        x = x * (1 - ctx_mask)
        x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
        pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
        pred = self.postnet(pred_v)

        # Perform multi-cond CFG: the batch holds three chunks (speaker+text, text only, unconditional).
        cond_spk_txt, cond_txt, uncond = pred.chunk(3)
        pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
        return pred

    @torch.no_grad()
    def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwargs):
        # txt embedding
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)

        # speaker embedding
        ctx_feature = inputs['lat_ctx']
        ctx_feature[1:, :, :] = 0 # prefix spk cfg
        ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])

        # local conditioning.
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)
        
        # ODE solver setup (the per-step update is amo_sampling, defined below)
        bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
        # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS), 
        # which is licensed under the MIT License.
        sway_sampling_coef = -1.0
        t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
        if sway_sampling_coef is not None:
            t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)
        
        # AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
        def amo_sampling(z_t, t, t_next, v):
            # Upcast to avoid precision issues when computing prev_sample
            z_t = z_t.to(torch.float32)

            # Constant definition in Algorithm 1
            s = t_next
            c = 3

            # Line 7 in Algorithm 1
            o = min(t_next + c * (t_next - t), 1)
            pred_z_o = z_t + (o - t) * v

            # Line 11 in Algorithm 1
            a = s / o
            b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
            noise_i = torch.randn(size=z_t.shape, device=z_t.device)
            z_t_next = a * pred_z_o + b * noise_i
            return z_t_next.to(v.dtype)

        x = torch.randn([1, frm_len, self.out_channels], device=device)
        for step_index in range(timesteps):
            x = x.to(torch.float32)
            sigma = t_schedule[step_index].to(x_ling.dtype)
            sigma_next = t_schedule[step_index + 1]
            model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
            x = amo_sampling(x, sigma, sigma_next, model_out)
            # Cast sample back to model compatible dtype
            x = x.to(model_out.dtype)
        
        return x
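
# Editorial sketch, not part of the repository: the timestep schedule and the
# per-step AMO coefficients used in inference(), reproduced with plain Python
# floats so the numbers can be inspected without torch. The function names
# sway_schedule and amo_coefficients are hypothetical, and the max(..., 0.0)
# clamp before the square root is an addition of this sketch (the source code
# takes the power directly).

```python
import math

def sway_schedule(timesteps, coef=-1.0):
    """Uniform grid on [0, 1] warped as t + coef * (cos(pi/2 * t) - 1 + t)."""
    grid = [i / timesteps for i in range(timesteps + 1)]
    return [t + coef * (math.cos(math.pi / 2 * t) - 1 + t) for t in grid]

def amo_coefficients(t, t_next, c=3):
    """Return (o, a, b) as computed inside amo_sampling for one step."""
    s = t_next
    o = min(t_next + c * (t_next - t), 1)  # overshoot time, capped at 1
    a = s / o                              # weight on the overshot sample
    # Clamp tiny negative round-off before the square root (sketch-only guard).
    b = math.sqrt(max((1 - s) ** 2 - (a * (1 - o)) ** 2, 0.0))  # noise weight
    return o, a, b

schedule = sway_schedule(20)
steps = [amo_coefficients(t, t_next) for t, t_next in zip(schedule, schedule[1:])]
```

# With coef = -1.0 the warp reduces to 1 - cos(pi/2 * t), which places more
# steps near t = 0. Once the overshoot time o reaches 1, the coefficients reduce
# to a = s and b = 1 - s, so the injected noise vanishes as s approaches 1.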

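# Editorial sketch, not part of the repository: the multi-condition
# classifier-free guidance expression from _forward(), applied element-wise to
# plain Python lists. The helper name multi_cond_cfg, the weight names w_txt and
# w_spk, and the sample values are hypothetical labels chosen for illustration.

```python
def multi_cond_cfg(cond_spk_txt, cond_txt, uncond, seq_cfg_w=(1.0, 1.0)):
    """Combine three predictions element-wise, mirroring the expression in _forward."""
    w_txt, w_spk = seq_cfg_w
    return [
        u + w_txt * (ct - u) + w_spk * (cst - ct)
        for cst, ct, u in zip(cond_spk_txt, cond_txt, uncond)
    ]

# Hypothetical per-element predictions for the three batch chunks.
cond_spk_txt = [1.0, 2.0]
cond_txt = [0.5, 1.0]
uncond = [0.0, 0.5]

plain = multi_cond_cfg(cond_spk_txt, cond_txt, uncond)                # weights (1, 1)
boosted = multi_cond_cfg(cond_spk_txt, cond_txt, uncond, (2.0, 3.0))  # stronger guidance
```

# With both weights equal to 1 the terms telescope and the result is exactly the
# fully conditioned prediction; larger weights extrapolate away from the
# unconditional and text-only predictions.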

================================================
FILE: tts/modules/llm_dit/time_embedding.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
from torch import nn


class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class TimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep):  # noqa: F821
        time_hidden = self.time_embed(timestep)
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time
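
# Editorial sketch, not part of the repository: the sinusoidal features computed
# by SinusPositionEmbedding.forward for a single scalar timestep, written with
# the math module instead of torch. The function name sinus_position_embedding
# and the sample timestep 0.5 are hypothetical.

```python
import math

def sinus_position_embedding(t, dim, scale=1000):
    """Return sin features followed by cos features, dim // 2 of each."""
    half_dim = dim // 2
    step = math.log(10000) / (half_dim - 1)
    # Geometric frequency ladder from 1 down to 1/10000.
    freqs = [math.exp(-step * i) for i in range(half_dim)]
    angles = [scale * t * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

emb = sinus_position_embedding(0.5, dim=256)
```

# In the module above these features are then passed through time_mlp
# (Linear, SiLU, Linear) to reach the encoder dimension.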

================================================
FILE: tts/modules/llm_dit/transformer.py
================================================
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permiss
SYMBOL INDEX (373 symbols across 29 files)

FILE: megatts3node.py
  function get_all_files (line 35) | def get_all_files(
  function get_speakers (line 89) | def get_speakers():
  class MegaTTS3DiTInfer (line 97) | class MegaTTS3DiTInfer():
    method __init__ (line 98) | def __init__(
    method clean (line 141) | def clean(self):
    method build_model (line 161) | def build_model(self, device):
    method preprocess (line 232) | def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwar...
    method forward (line 277) | def forward(self, texts, time_step, p_w, t_w, dur_disturb=0.1, dur_alp...
  class MegaTTS3SpeakersPreview (line 334) | class MegaTTS3SpeakersPreview:
    method INPUT_TYPES (line 336) | def INPUT_TYPES(s):
    method preview (line 346) | def preview(self, speaker):
  function cache_audio_tensor (line 361) | def cache_audio_tensor(
  function statistical_compare (line 384) | def statistical_compare(tensor1, tensor2):
  class MegaTTS3Run (line 402) | class MegaTTS3Run:
    method __init__ (line 403) | def __init__(self):
    method INPUT_TYPES (line 409) | def INPUT_TYPES(s):
    method clone (line 431) | def clone(self, audio, text, time_step, p_w, t_w, unload_model, audio_...
    method get_speaker_text_audio (line 511) | def get_speaker_text_audio(self, text, audio_1, audio_2):
  class MultiLinePromptMG (line 531) | class MultiLinePromptMG:
    method INPUT_TYPES (line 533) | def INPUT_TYPES(cls):
    method promptgen (line 548) | def promptgen(self, multi_line_prompt: str):

FILE: tts/frontend_function.py
  function g2p (line 24) | def g2p(self, text_inp):
  function align (line 40) | def align(self, wav):
  function make_dur_prompt (line 77) | def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref):
  function dur_pred (line 95) | def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred...
  function prepare_inputs_for_dit (line 154) | def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_r...

FILE: tts/modules/aligner/whisper_small.py
  class LayerNorm (line 42) | class LayerNorm(nn.LayerNorm):
    method forward (line 43) | def forward(self, x: Tensor) -> Tensor:
  class Linear (line 47) | class Linear(nn.Linear):
    method forward (line 48) | def forward(self, x: Tensor) -> Tensor:
  class Conv1d (line 56) | class Conv1d(nn.Conv1d):
    method _conv_forward (line 57) | def _conv_forward(
  function sinusoids (line 65) | def sinusoids(length, channels, max_timescale=10000):
  function disable_sdpa (line 75) | def disable_sdpa():
  class MultiHeadAttention (line 84) | class MultiHeadAttention(nn.Module):
    method __init__ (line 87) | def __init__(self, n_state: int, n_head: int):
    method forward (line 95) | def forward(
    method qkv_attention (line 118) | def qkv_attention(
  class ResidualAttentionBlock (line 134) | class ResidualAttentionBlock(nn.Module):
    method __init__ (line 135) | def __init__(self, n_state: int, n_head: int, cross_attention: bool = ...
    method forward (line 152) | def forward(
  class AudioEncoder (line 168) | class AudioEncoder(nn.Module):
    method __init__ (line 169) | def __init__(
    method forward (line 182) | def forward(self, x: Tensor, attn_mask: Tensor):
  class TextDecoder (line 201) | class TextDecoder(nn.Module):
    method __init__ (line 202) | def __init__(
    method forward (line 220) | def forward(self, x: Tensor, attn_mask: Tensor, xa: Tensor, kv_cache: ...
  class Whisper (line 246) | class Whisper(nn.Module):
    method __init__ (line 247) | def __init__(self):
    method embed_audio (line 261) | def embed_audio(self, mel: torch.Tensor):
    method logits (line 264) | def logits(self, tokens, audio_features, kv_cache=None):
    method forward (line 267) | def forward(
    method device (line 275) | def device(self):
    method install_kv_cache_hooks (line 278) | def install_kv_cache_hooks(self, cache: Optional[dict] = None):
    method sequence_mask (line 311) | def sequence_mask(self, seq_lens, max_len=None, device='cpu'):

FILE: tts/modules/ar_dur/ar_dur_predictor.py
  function fill_with_neg_inf2 (line 37) | def fill_with_neg_inf2(t):
  function expand_states (line 41) | def expand_states(h, mel2token):
  class CodePredictor (line 48) | class CodePredictor(nn.Module):
    method __init__ (line 49) | def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layer...
    method forward_ling_encoder (line 87) | def forward_ling_encoder(
    method sample_one_step (line 130) | def sample_one_step(self, vq_pred):
    method forward_style_embed (line 148) | def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
    method buffered_future_mask (line 159) | def buffered_future_mask(self, tensor):
  class ARDurPredictor (line 171) | class ARDurPredictor(CodePredictor):
    method __init__ (line 172) | def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layer...
    method forward (line 190) | def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_em...
    method infer (line 265) | def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
    method streaming_infer (line 322) | def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char,...

FILE: tts/modules/ar_dur/commons/layers.py
  class LayerNorm (line 19) | class LayerNorm(torch.nn.LayerNorm):
    method __init__ (line 25) | def __init__(self, nout, dim=-1, eps=1e-5):
    method forward (line 30) | def forward(self, x):
  class Reshape (line 41) | class Reshape(nn.Module):
    method __init__ (line 42) | def __init__(self, *args):
    method forward (line 46) | def forward(self, x):
  class Permute (line 50) | class Permute(nn.Module):
    method __init__ (line 51) | def __init__(self, *args):
    method forward (line 55) | def forward(self, x):
  function Embedding (line 59) | def Embedding(num_embeddings, embedding_dim, padding_idx=None):

FILE: tts/modules/ar_dur/commons/nar_tts_modules.py
  class LengthRegulator (line 23) | class LengthRegulator(torch.nn.Module):
    method __init__ (line 24) | def __init__(self, pad_value=0.0):
    method forward (line 28) | def forward(self, dur, dur_padding=None, alpha=1.0):
  class PosEmb (line 61) | class PosEmb(nn.Module):
    method __init__ (line 62) | def __init__(self, dim):
    method forward (line 70) | def forward(self, x):

FILE: tts/modules/ar_dur/commons/rel_transformer.py
  function convert_pad_shape (line 23) | def convert_pad_shape(pad_shape):
  function shift_1d (line 29) | def shift_1d(x):
  function sequence_mask (line 34) | def sequence_mask(length, max_length=None):
  class Encoder (line 41) | class Encoder(nn.Module):
    method __init__ (line 42) | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers...
    method forward (line 71) | def forward(self, x, x_mask, attn_mask=1):
  class MultiHeadAttention (line 100) | class MultiHeadAttention(nn.Module):
    method __init__ (line 101) | def __init__(self, channels, out_channels, n_heads, window_size=None, ...
    method forward (line 135) | def forward(self, x, c, attn_mask=None):
    method attention (line 145) | def attention(self, query, key, value, mask=None):
    method _matmul_with_relative_values (line 178) | def _matmul_with_relative_values(self, x, y):
    method _matmul_with_relative_keys (line 187) | def _matmul_with_relative_keys(self, x, y):
    method _get_relative_embeddings (line 196) | def _get_relative_embeddings(self, relative_embeddings, length):
    method _relative_position_to_absolute_position (line 211) | def _relative_position_to_absolute_position(self, x):
    method _absolute_position_to_relative_position (line 228) | def _absolute_position_to_relative_position(self, x):
    method _attention_bias_proximal (line 242) | def _attention_bias_proximal(self, length):
  class FFN (line 254) | class FFN(nn.Module):
    method __init__ (line 255) | def __init__(self, in_channels, out_channels, filter_channels, kernel_...
    method forward (line 268) | def forward(self, x, x_mask):
  class LayerNorm (line 279) | class LayerNorm(nn.Module):
    method __init__ (line 280) | def __init__(self, channels, eps=1e-4):
    method forward (line 288) | def forward(self, x):
  class ConvReluNorm (line 300) | class ConvReluNorm(nn.Module):
    method __init__ (line 301) | def __init__(self, in_channels, hidden_channels, out_channels, kernel_...
    method forward (line 325) | def forward(self, x, x_mask):
  class RelTransformerEncoder (line 335) | class RelTransformerEncoder(nn.Module):
    method __init__ (line 336) | def __init__(self,
    method forward (line 387) | def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):

FILE: tts/modules/ar_dur/commons/rot_transformer.py
  class SinusoidalPositionalEmbedding (line 29) | class SinusoidalPositionalEmbedding(nn.Module):
    method __init__ (line 35) | def __init__(self, embedding_dim, padding_idx, init_size=1024):
    method get_embedding (line 47) | def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
    method forward (line 65) | def forward(self, input, incremental_state=None, timestep=None, positi...
    method max_positions (line 86) | def max_positions(self):
  class RotaryEmbeddings (line 91) | class RotaryEmbeddings(nn.Module):
    method __init__ (line 96) | def __init__(
    method _create_rotary_embed (line 134) | def _create_rotary_embed(self, *, width: int, length: int):
    method _rotate (line 149) | def _rotate(self, input: torch.Tensor):
    method forward (line 164) | def forward(self, input: torch.Tensor, *, positions: Optional[torch.Te...
  class RotMultiheadAttention (line 206) | class RotMultiheadAttention(MultiheadAttention):
    method __init__ (line 207) | def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout...
    method forward (line 215) | def forward(
  class RotMultiheadAttention2 (line 404) | class RotMultiheadAttention2(MultiheadAttention):
    method __init__ (line 405) | def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout...
    method forward (line 413) | def forward(
  class RotDecSALayer (line 543) | class RotDecSALayer(nn.Module):
    method __init__ (line 544) | def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_...
    method forward (line 559) | def forward(
    method clear_buffer (line 604) | def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=N...
    method set_buffer (line 608) | def set_buffer(self, name, tensor, incremental_state):
  class RotDecSALayer2 (line 612) | class RotDecSALayer2(RotDecSALayer):
    method __init__ (line 613) | def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_...
  class RotTransformerDecoderLayer (line 622) | class RotTransformerDecoderLayer(nn.Module):
    method __init__ (line 623) | def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, f...
    method forward (line 642) | def forward(self, x, **kwargs):
    method clear_buffer (line 645) | def clear_buffer(self, *args):
    method set_buffer (line 648) | def set_buffer(self, *args):

FILE: tts/modules/ar_dur/commons/seq_utils.py
  function make_positions (line 20) | def make_positions(tensor, padding_idx):
  function softmax (line 35) | def softmax(x, dim):
  function sequence_mask (line 39) | def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
  function weights_nonzero_speech (line 47) | def weights_nonzero_speech(target):
  function _get_full_incremental_state_key (line 57) | def _get_full_incremental_state_key(module_instance, key):
  function get_incremental_state (line 69) | def get_incremental_state(module, incremental_state, key):
  function set_incremental_state (line 77) | def set_incremental_state(module, incremental_state, key, value):
  function fill_with_neg_inf (line 84) | def fill_with_neg_inf(t):
  function fill_with_neg_inf2 (line 89) | def fill_with_neg_inf2(t):
  function select_attn (line 94) | def select_attn(attn_logits, type='best'):
  function make_pad_mask (line 112) | def make_pad_mask(lengths, xs=None, length_dim=-1):
  function make_non_pad_mask (line 218) | def make_non_pad_mask(lengths, xs=None, length_dim=-1):
  function get_mask_from_lengths (line 298) | def get_mask_from_lengths(lengths):
  function group_hidden_by_segs (line 305) | def group_hidden_by_segs(h, seg_ids, max_len):
  function expand_by_repeat_times (line 321) | def expand_by_repeat_times(source_encoding, lengths):
  function expand_word2ph (line 338) | def expand_word2ph(word_encoding, ph2word):

FILE: tts/modules/ar_dur/commons/transformer.py
  class SinusoidalPositionalEmbedding (line 27) | class SinusoidalPositionalEmbedding(nn.Module):
    method __init__ (line 33) | def __init__(self, embedding_dim, padding_idx, init_size=1024):
    method get_embedding (line 45) | def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
    method forward (line 63) | def forward(self, input, incremental_state=None, timestep=None, positi...
    method max_positions (line 84) | def max_positions(self):
  class TransformerFFNLayer (line 89) | class TransformerFFNLayer(nn.Module):
    method __init__ (line 90) | def __init__(self, hidden_size, filter_size, padding="SAME", kernel_si...
    method forward (line 105) | def forward(self, x, incremental_state=None):
    method _get_input_buffer (line 129) | def _get_input_buffer(self, incremental_state):
    method _set_input_buffer (line 136) | def _set_input_buffer(self, incremental_state, buffer):
    method clear_buffer (line 144) | def clear_buffer(self, incremental_state):
  class MultiheadAttention (line 152) | class MultiheadAttention(nn.Module):
    method __init__ (line 153) | def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout...
    method reset_parameters (line 201) | def reset_parameters(self):
    method forward (line 218) | def forward(
    method in_proj_qkv (line 432) | def in_proj_qkv(self, query):
    method in_proj_q (line 435) | def in_proj_q(self, query):
    method in_proj_k (line 444) | def in_proj_k(self, key):
    method in_proj_v (line 454) | def in_proj_v(self, value):
    method _in_proj (line 464) | def _in_proj(self, input, start=0, end=None):
    method _get_input_buffer (line 472) | def _get_input_buffer(self, incremental_state):
    method _set_input_buffer (line 479) | def _set_input_buffer(self, incremental_state, buffer):
    method apply_sparse_mask (line 487) | def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
    method clear_buffer (line 490) | def clear_buffer(self, incremental_state=None):
  class EncSALayer (line 500) | class EncSALayer(nn.Module):
    method __init__ (line 501) | def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
    method forward (line 516) | def forward(self, x, encoder_padding_mask=None, **kwargs):
  class DecSALayer (line 543) | class DecSALayer(nn.Module):
    method __init__ (line 544) | def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_...
    method forward (line 562) | def forward(
    method clear_buffer (line 631) | def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=N...
    method set_buffer (line 635) | def set_buffer(self, name, tensor, incremental_state):
  class TransformerEncoderLayer (line 639) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 640) | def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, f...
    method forward (line 650) | def forward(self, x, **kwargs):
  class TransformerDecoderLayer (line 654) | class TransformerDecoderLayer(nn.Module):
    method __init__ (line 655) | def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, f...
    method forward (line 666) | def forward(self, x, **kwargs):
    method clear_buffer (line 669) | def clear_buffer(self, *args):
    method set_buffer (line 672) | def set_buffer(self, *args):
  class FFTBlocks (line 676) | class FFTBlocks(nn.Module):
    method __init__ (line 677) | def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout...
    method forward (line 706) | def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens...
  class FastSpeechEncoder (line 734) | class FastSpeechEncoder(FFTBlocks):
    method __init__ (line 735) | def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_si...
    method forward (line 746) | def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
    method forward_embedding (line 760) | def forward_embedding(self, txt_tokens):

FILE: tts/modules/llm_dit/cfm.py
  class LogitNormalTrainingTimesteps (line 36) | class LogitNormalTrainingTimesteps:
    method __init__ (line 37) | def __init__(self, T=1000.0, loc=0.0, scale=1.0):
    method sample (line 42) | def sample(self, size, device):
  function pad_t_like_x (line 47) | def pad_t_like_x(t, x):
  class ConditionalFlowMatcher (line 71) | class ConditionalFlowMatcher:
    method __init__ (line 82) | def __init__(self, sigma: Union[float, int] = 0.0):
    method compute_mu_t (line 92) | def compute_mu_t(self, x0, x1, t):
    method compute_sigma_t (line 115) | def compute_sigma_t(self, t):
    method sample_xt (line 134) | def sample_xt(self, x0, x1, t, epsilon):
    method compute_conditional_flow (line 161) | def compute_conditional_flow(self, x0, x1, t, xt):
    method sample_noise_like (line 186) | def sample_noise_like(self, x):
    method sample_location_and_conditional_flow (line 189) | def sample_location_and_conditional_flow(self, x0, x1, t=None, return_...
    method compute_lambda (line 233) | def compute_lambda(self, t):
  class VariancePreservingConditionalFlowMatcher (line 252) | class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
    method compute_mu_t (line 260) | def compute_mu_t(self, x0, x1, t):
    method compute_conditional_flow (line 282) | def compute_conditional_flow(self, x0, x1, t, xt):

FILE: tts/modules/llm_dit/dit.py
  class Diffusion (line 27) | class Diffusion(nn.Module):
    method __init__ (line 28) | def __init__(self):
    method forward (line 70) | def forward(self, inputs, sigmas=None, x_noisy=None):
    method forward_ling_encoder (line 101) | def forward_ling_encoder(self, txt_tokens, tone_tokens):
    method _forward (line 114) | def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=Non...
    method inference (line 127) | def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwar...

FILE: tts/modules/llm_dit/time_embedding.py
  class SinusPositionEmbedding (line 20) | class SinusPositionEmbedding(nn.Module):
    method __init__ (line 21) | def __init__(self, dim):
    method forward (line 25) | def forward(self, x, scale=1000):
  class TimestepEmbedding (line 34) | class TimestepEmbedding(nn.Module):
    method __init__ (line 35) | def __init__(self, dim, freq_embed_dim=256):
    method forward (line 40) | def forward(self, timestep):  # noqa: F821

FILE: tts/modules/llm_dit/transformer.py
  function precompute_freqs_cis (line 23) | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
  function reshape_for_broadcast (line 31) | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
  function apply_rotary_emb (line 39) | def apply_rotary_emb(
  class AdaLNZero (line 52) | class AdaLNZero(nn.Module):
    method __init__ (line 53) | def __init__(self, dim):
    method forward (line 59) | def forward(self, x, emb=None):
  class AdaLNZero_Out (line 66) | class AdaLNZero_Out(nn.Module):
    method __init__ (line 67) | def __init__(self, dim):
    method forward (line 73) | def forward(self, x, emb):
  class Attention (line 80) | class Attention(nn.Module):
    method __init__ (line 81) | def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
    method forward (line 107) | def forward(
  class FeedForward (line 130) | class FeedForward(nn.Module):
    method __init__ (line 131) | def __init__(
    method forward (line 150) | def forward(self, x):
  class TransformerBlock (line 154) | class TransformerBlock(nn.Module):
    method __init__ (line 155) | def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
    method forward (line 170) | def forward(
  class Transformer (line 207) | class Transformer(nn.Module):
    method __init__ (line 208) | def __init__(self, encoder_n_layers, encoder_dim, encoder_n_heads, max...
    method forward (line 224) | def forward(self, x, t, attn_mask, start_pos=0):

FILE: tts/modules/wavvae/decoder/diag_gaussian.py
  class DiagonalGaussianDistribution (line 18) | class DiagonalGaussianDistribution(object):
    method __init__ (line 19) | def __init__(self, parameters: torch.Tensor, deterministic: bool = Fal...
    method sample (line 31) | def sample(self, generator=None) -> torch.Tensor:
    method kl (line 42) | def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Te...
    method nll (line 57) | def nll(self, sample, dims) -> torch.Tensor:
    method mode (line 66) | def mode(self) -> torch.Tensor:

FILE: tts/modules/wavvae/decoder/hifigan_modules.py
  function init_weights (line 25) | def init_weights(m, mean=0.0, std=0.01):
  function get_padding (line 31) | def get_padding(kernel_size, dilation=1):
  class Upsample (line 35) | class Upsample(nn.Module):
    method __init__ (line 36) | def __init__(self, mult, r):
    method forward (line 52) | def forward(self, x):
  class Downsample (line 59) | class Downsample(nn.Module):
    method __init__ (line 60) | def __init__(self, mult, r):
    method forward (line 70) | def forward(self, x):
  function weights_init (line 75) | def weights_init(m):
  function weights_zero_init (line 84) | def weights_zero_init(m):
  function WNConv1d (line 91) | def WNConv1d(*args, **kwargs):
  function WNConvTranspose1d (line 95) | def WNConvTranspose1d(*args, **kwargs):
  class Audio2Mel (line 99) | class Audio2Mel(nn.Module):
    method __init__ (line 100) | def __init__(
    method forward (line 129) | def forward(self, audio):
  class ResnetBlock (line 149) | class ResnetBlock(nn.Module):
    method __init__ (line 150) | def __init__(self, dim, dilation=1, dim_in=None):
    method forward (line 164) | def forward(self, x):
  class ResBlockMRFV2 (line 174) | class ResBlockMRFV2(torch.nn.Module):
    method __init__ (line 175) | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
    method forward (line 197) | def forward(self, x):
    method remove_weight_norm (line 206) | def remove_weight_norm(self):
  class ResBlockMRFV2Inter (line 213) | class ResBlockMRFV2Inter(torch.nn.Module):
    method __init__ (line 214) | def __init__(self, channels, kernel_size=3):
    method forward (line 220) | def forward(self, x):
  class Generator (line 228) | class Generator(nn.Module):
    method __init__ (line 229) | def __init__(self, input_size_, ngf, n_residual_layers, num_band, args...
    method forward (line 265) | def forward(self, mel, step=None):

FILE: tts/modules/wavvae/decoder/seanet_encoder.py
  class Encoder (line 21) | class Encoder(nn.Module):
    method __init__ (line 22) | def __init__(
    method forward (line 35) | def forward(self, audio: torch.Tensor):

FILE: tts/modules/wavvae/decoder/wavvae_v3.py
  class WavVAE_V3 (line 25) | class WavVAE_V3(nn.Module):
    method __init__ (line 26) | def __init__(self, hparams=None):
    method encode_latent (line 41) | def encode_latent(self, audio):
    method encode (line 46) | def encode(self, audio):
    method decode (line 52) | def decode(self, latent):
    method forward (line 56) | def forward(self, audio):

FILE: tts/modules/wavvae/encoder/common_modules/conv.py
  function apply_parametrization_norm (line 47) | def apply_parametrization_norm(module: nn.Module, norm: str = 'none') ->...
  function get_norm_module (line 57) | def get_norm_module(module: nn.Module, causal: bool = False, norm: str =...
  function get_extra_padding_for_conv1d (line 71) | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stri...
  function pad1d (line 79) | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'ze...
  class ConvLayerNorm (line 96) | class ConvLayerNorm(nn.LayerNorm):
    method __init__ (line 97) | def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch...
    method forward (line 100) | def forward(self, x):
  class NormConv1d (line 107) | class NormConv1d(nn.Module):
    method __init__ (line 108) | def __init__(self, *args, causal: bool = False, norm: str = 'none',
    method forward (line 115) | def forward(self, x):
  class SConv1d (line 121) | class SConv1d(nn.Module):
    method __init__ (line 122) | def __init__(self, in_channels: int, out_channels: int,
    method forward (line 138) | def forward(self, x):

FILE: tts/modules/wavvae/encoder/common_modules/lstm.py
  class SLSTM (line 34) | class SLSTM(nn.Module):
    method __init__ (line 39) | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = T...
    method forward (line 45) | def forward(self, x):

FILE: tts/modules/wavvae/encoder/common_modules/seanet.py
  class SEANetResnetBlock (line 41) | class SEANetResnetBlock(nn.Module):
    method __init__ (line 42) | def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dila...
    method forward (line 68) | def forward(self, x):
  class SEANetEncoder (line 72) | class SEANetEncoder(nn.Module):
    method __init__ (line 73) | def __init__(self, channels: int = 1, dimension: int = 128, n_filters:...
    method forward (line 125) | def forward(self, x):

FILE: tts/utils/audio_utils/align.py
  function mel2token_to_dur (line 17) | def mel2token_to_dur(mel2token, T_txt=None, max_dur=None):

FILE: tts/utils/audio_utils/io.py
  function to_wav_bytes (line 25) | def to_wav_bytes(wav, sr, norm=False):
  function save_wav (line 39) | def save_wav(wav_bytes, path):
  function to_mp3 (line 46) | def to_mp3(out_path):
  function convert_to_wav (line 55) | def convert_to_wav(wav_path):
  function convert_to_wav_bytes (line 73) | def convert_to_wav_bytes(audio_binary):
  function combine_audio_segments (line 83) | def combine_audio_segments(segments, crossfade_duration=0.16, sr=24000):

FILE: tts/utils/audio_utils/plot.py
  function spec_to_figure (line 25) | def spec_to_figure(spec, vmin=None, vmax=None, title='', f0s=None, dur_i...
  function align_to_figure (line 73) | def align_to_figure(align, dur_info):

FILE: tts/utils/commons/ckpt_utils.py
  function dist_load (line 28) | def dist_load(path):
  function torch_load_dist (line 48) | def torch_load_dist(path, map_location='cpu'):
  function get_last_checkpoint (line 54) | def get_last_checkpoint(work_dir, steps=None):
  function get_all_ckpts (line 64) | def get_all_ckpts(work_dir, steps=None):
  function load_ckpt (line 73) | def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, ...
  function load_with_size_mismatch (line 161) | def load_with_size_mismatch(model, state_dict, prefix=""):

FILE: tts/utils/commons/hparams.py
  class Args (line 26) | class Args:
    method __init__ (line 27) | def __init__(self, **kwargs):
  function override_config (line 32) | def override_config(old_config: dict, new_config: dict):
  function traverse_dict (line 42) | def traverse_dict(d, func, ctx):
  function parse_config (line 51) | def parse_config(v, context=None):
  function remove_meta_key (line 66) | def remove_meta_key(d):
  function load_config (line 76) | def load_config(config_fn, config_chains, loaded_configs):
  function set_hparams (line 103) | def set_hparams(config='', exp_name='', hparams_str='', print_hparams=Tr...

FILE: tts/utils/text_utils/ph_tone_convert.py
  function map_phone_to_tokendict (line 18) | def map_phone_to_tokendict(item, pad_bos_eos=True):
  function split_ph_timestamp (line 39) | def split_ph_timestamp(ph_timestamp):
  function split_ph (line 72) | def split_ph(ph_seq):

FILE: tts/utils/text_utils/split_text.py
  function chunk_text_chinese (line 18) | def chunk_text_chinese(text, limit=60):
  function chunk_text_english (line 60) | def chunk_text_english(text, max_chars=130):
  function chunk_text_chinesev2 (line 90) | def chunk_text_chinesev2(text, limit=60, look_ahead_limit=30):

FILE: tts/utils/text_utils/text_encoder.py
  function strip_ids (line 44) | def strip_ids(ids, ids_to_strip):
  class TextEncoder (line 52) | class TextEncoder(object):
    method __init__ (line 55) | def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
    method num_reserved_ids (line 59) | def num_reserved_ids(self):
    method encode (line 62) | def encode(self, s):
    method decode (line 78) | def decode(self, ids, strip_extraneous=False):
    method decode_list (line 95) | def decode_list(self, ids):
    method vocab_size (line 117) | def vocab_size(self):
  class TokenTextEncoder (line 121) | class TokenTextEncoder(TextEncoder):
    method __init__ (line 124) | def __init__(self,
    method encode (line 161) | def encode(self, s):
    method decode (line 174) | def decode(self, ids, strip_eos=False, strip_padding=False):
    method decode_list (line 183) | def decode_list(self, ids):
    method vocab_size (line 188) | def vocab_size(self):
    method __len__ (line 191) | def __len__(self):
    method _safe_id_to_token (line 194) | def _safe_id_to_token(self, idx):
    method _init_vocab_from_file (line 197) | def _init_vocab_from_file(self, filename):
    method _init_vocab_from_list (line 212) | def _init_vocab_from_list(self, vocab_list):
    method _init_vocab (line 229) | def _init_vocab(self, token_generator, add_reserved_tokens=True):
    method pad (line 245) | def pad(self):
    method eos (line 248) | def eos(self):
    method unk (line 251) | def unk(self):
    method seg (line 254) | def seg(self):
    method store_to_file (line 257) | def store_to_file(self, filename):
    method sil_phonemes (line 270) | def sil_phonemes(self):
  function build_token_encoder (line 274) | def build_token_encoder(token_list_file):
  function is_sil_phoneme (line 279) | def is_sil_phoneme(p):
Condensed preview: 40 files, each showing path, character count, and a content snippet (the full structured content is 309K chars).
[
  {
    "path": ".github/workflows/publish_action.yml",
    "chars": 581,
    "preview": "name: Publish to Comfy registry\non:\n  workflow_dispatch:\n  push:\n    branches:\n      - master\n      - main\n    paths:\n  "
  },
  {
    "path": ".gitignore",
    "chars": 3443,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 11341,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "README-CN.md",
    "chars": 3080,
    "preview": "[中文](README-CN.md) | [English](README.md) \r\n\r\n# ComfyUI 的 MegaTTS3 声音克隆节点\r\n\r\n声音克隆质量非常高, 支持中英文, 并可跨语言克隆. **支持自定义音色!!! 超长文"
  },
  {
    "path": "README.md",
    "chars": 3935,
    "preview": "[中文](README-CN.md) | [English](README.md)\n\n# MegaTTS3 Voice Cloning Nodes for ComfyUI\n\nHigh-quality voice cloning, suppo"
  },
  {
    "path": "__init__.py",
    "chars": 140,
    "preview": "from .megatts3node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS\r\n\r\n__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DIS"
  },
  {
    "path": "megatts3node.py",
    "chars": 22364,
    "preview": "import json\nimport os\nimport librosa\nimport numpy as np\nimport torch\nimport torchaudio\nfrom typing import List, Union, O"
  },
  {
    "path": "pyproject.toml",
    "chars": 539,
    "preview": "[project]\r\nname = \"megatts3-mw\"\r\ndescription = \"Lightweight and Efficient, 🎧Ultra High-Quality Voice Cloning, Chinese an"
  },
  {
    "path": "requirements.txt",
    "chars": 198,
    "preview": "setproctitle\nattrdict\nlibrosa\npyloudnorm\nx-transformers\ntorchdiffeq\nopenai-whisper>=20240930\nlangdetect\npynini==2.1.6; p"
  },
  {
    "path": "tts/frontend_function.py",
    "chars": 8308,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/aligner/whisper_small.py",
    "chars": 11771,
    "preview": "# MIT License\n\n# Copyright (c) 2022 OpenAI\n\n# Permission is hereby granted, free of charge, to any person obtaining a co"
  },
  {
    "path": "tts/modules/ar_dur/ar_dur_predictor.py",
    "chars": 17289,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/ar_dur/commons/layers.py",
    "chars": 2000,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/ar_dur/commons/nar_tts_modules.py",
    "chars": 2766,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/ar_dur/commons/rel_transformer.py",
    "chars": 16002,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/ar_dur/commons/rot_transformer.py",
    "chars": 27950,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/ar_dur/commons/seq_utils.py",
    "chars": 12667,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/ar_dur/commons/transformer.py",
    "chars": 32062,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/llm_dit/cfm.py",
    "chars": 10404,
    "preview": "# MIT License\n\n# Copyright (c) 2023 Alexander Tong\n\n# Permission is hereby granted, free of charge, to any person obtain"
  },
  {
    "path": "tts/modules/llm_dit/dit.py",
    "chars": 8036,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/llm_dit/time_embedding.py",
    "chars": 1616,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/llm_dit/transformer.py",
    "chars": 8361,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/wavvae/decoder/diag_gaussian.py",
    "chars": 2538,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/wavvae/decoder/hifigan_modules.py",
    "chars": 9995,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/wavvae/decoder/seanet_encoder.py",
    "chars": 1506,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/wavvae/decoder/wavvae_v3.py",
    "chars": 2287,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/modules/wavvae/encoder/common_modules/conv.py",
    "chars": 6498,
    "preview": "# MIT License\n\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n\n# Permission is hereby granted, free of charge, to "
  },
  {
    "path": "tts/modules/wavvae/encoder/common_modules/lstm.py",
    "chars": 2125,
    "preview": "# MIT License\n\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n\n# Permission is hereby granted, free of charge, to "
  },
  {
    "path": "tts/modules/wavvae/encoder/common_modules/seanet.py",
    "chars": 5808,
    "preview": "# MIT License\n\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n\n# Permission is hereby granted, free of charge, to "
  },
  {
    "path": "tts/utils/audio_utils/align.py",
    "chars": 1294,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/utils/audio_utils/io.py",
    "chars": 3277,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/utils/audio_utils/plot.py",
    "chars": 3334,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/utils/commons/ckpt_utils.py",
    "chars": 7919,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/utils/commons/hparams.py",
    "chars": 8311,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/utils/text_utils/dict.json",
    "chars": 1686,
    "preview": "{\"phone\": [\"C0a\", \"C0ai\", \"C0air\", \"C0an\", \"C0ang\", \"C0angr\", \"C0anr\", \"C0ao\", \"C0aor\", \"C0ar\", \"C0b\", \"C0c\", \"C0ch\", \"C"
  },
  {
    "path": "tts/utils/text_utils/ph_tone_convert.py",
    "chars": 3568,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "tts/utils/text_utils/split_text.py",
    "chars": 9192,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version"
  },
  {
    "path": "tts/utils/text_utils/text_encoder.py",
    "chars": 9371,
    "preview": "# Copyright 2025 ByteDance and/or its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
  },
  {
    "path": "workflow-examples/单人语音.json",
    "chars": 3852,
    "preview": "{\"id\":\"f4285961-b2b3-477d-8027-364fea281241\",\"revision\":0,\"last_node_id\":27,\"last_link_id\":74,\"nodes\":[{\"id\":15,\"type\":\""
  },
  {
    "path": "workflow-examples/双人会话.json",
    "chars": 4331,
    "preview": "{\"id\":\"eb4d7439-617f-4e18-beba-af841ebc9d1a\",\"revision\":0,\"last_node_id\":22,\"last_link_id\":54,\"nodes\":[{\"id\":11,\"type\":\""
  }
]
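
The preview array above is a list of records with `path`, `chars`, and `preview` keys. As a minimal sketch of how such a listing could be consumed, the snippet below parses a small sample and summarizes it. The helper name `summarize_preview` is hypothetical and not part of the repository, and the `"..."` preview strings are placeholders rather than real file content.

```python
import json


def summarize_preview(entries):
    """Return (file_count, total_chars, path_of_largest_file) for preview records.

    Hypothetical helper: assumes each record has "path" and "chars" keys,
    matching the shape of the condensed preview array.
    """
    total_chars = sum(entry["chars"] for entry in entries)
    largest = max(entries, key=lambda entry: entry["chars"])
    return len(entries), total_chars, largest["path"]


# Three records copied from the listing; the preview text is a placeholder.
sample = json.loads("""
[
  {"path": "__init__.py", "chars": 140, "preview": "..."},
  {"path": "requirements.txt", "chars": 198, "preview": "..."},
  {"path": "pyproject.toml", "chars": 539, "preview": "..."}
]
""")

file_count, total_chars, largest_path = summarize_preview(sample)
print(file_count, total_chars, largest_path)
```

The same function should work on the full array once it is loaded with `json.load`, provided every record keeps the `path` and `chars` keys.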

About this extraction

This page contains the full source code of the billwuhao/ComfyUI_MegaTTS3 GitHub repository, extracted and formatted as plain text. The extraction includes 40 files (284.9 KB), approximately 76.0k tokens, and a symbol index with 373 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
