Repository: SparkAudio/Spark-TTS
Branch: main
Commit: 2f1ea9082400
Files: 44
Total size: 280.4 KB

Directory structure:
gitextract_kozm2q6w/

├── .gitignore
├── LICENSE
├── README.md
├── cli/
│   ├── SparkTTS.py
│   └── inference.py
├── example/
│   └── infer.sh
├── requirements.txt
├── runtime/
│   └── triton_trtllm/
│       ├── Dockerfile.server
│       ├── README.md
│       ├── client_grpc.py
│       ├── client_http.py
│       ├── docker-compose.yml
│       ├── model_repo/
│       │   ├── audio_tokenizer/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── spark_tts/
│       │   │   ├── 1/
│       │   │   │   └── model.py
│       │   │   └── config.pbtxt
│       │   ├── tensorrt_llm/
│       │   │   ├── 1/
│       │   │   │   └── .gitkeep
│       │   │   └── config.pbtxt
│       │   └── vocoder/
│       │       ├── 1/
│       │       │   └── model.py
│       │       └── config.pbtxt
│       ├── run.sh
│       └── scripts/
│           ├── convert_checkpoint.py
│           └── fill_template.py
├── sparktts/
│   ├── models/
│   │   ├── audio_tokenizer.py
│   │   └── bicodec.py
│   ├── modules/
│   │   ├── blocks/
│   │   │   ├── layers.py
│   │   │   ├── samper.py
│   │   │   └── vocos.py
│   │   ├── encoder_decoder/
│   │   │   ├── feat_decoder.py
│   │   │   ├── feat_encoder.py
│   │   │   └── wave_generator.py
│   │   ├── fsq/
│   │   │   ├── finite_scalar_quantization.py
│   │   │   └── residual_fsq.py
│   │   ├── speaker/
│   │   │   ├── ecapa_tdnn.py
│   │   │   ├── perceiver_encoder.py
│   │   │   ├── pooling_layers.py
│   │   │   └── speaker_encoder.py
│   │   └── vq/
│   │       └── factorized_vector_quantize.py
│   └── utils/
│       ├── __init__.py
│       ├── audio.py
│       ├── file.py
│       ├── parse_options.sh
│       └── token_parser.py
└── webui.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
pretrained_models/
results/
demo/
# C extensions
*.so
.gradio/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
webui_test.py

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align="center">
    <h1>
    Spark-TTS
    </h1>
    <p>
    Official PyTorch code for inference of <br>
    <b><em>Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens</em></b>
    </p>
    <p>
    <img src="src/logo/SparkTTS.jpg" alt="Spark-TTS Logo" style="width: 200px; height: 200px;">
    </p>
        <p>
        <img src="src/logo/HKUST.jpg" alt="Institution 1" style="width: 200px; height: 60px;">
        <img src="src/logo/mobvoi.jpg" alt="Institution 2" style="width: 200px; height: 60px;">
        <img src="src/logo/SJU.jpg" alt="Institution 3" style="width: 200px; height: 60px;">
    </p>
    <p>
        <img src="src/logo/NTU.jpg" alt="Institution 4" style="width: 200px; height: 60px;">
        <img src="src/logo/NPU.jpg" alt="Institution 5" style="width: 200px; height: 60px;">
        <img src="src/logo/SparkAudio2.jpg" alt="Institution 6" style="width: 200px; height: 60px;">
    </p>
    <p>
    </p>
    <a href="https://arxiv.org/pdf/2503.01710"><img src="https://img.shields.io/badge/Paper-ArXiv-red" alt="paper"></a>
    <a href="https://sparkaudio.github.io/spark-tts/"><img src="https://img.shields.io/badge/Demo-Page-lightgrey" alt="version"></a>
    <a href="https://huggingface.co/SparkAudio/Spark-TTS-0.5B"><img src="https://img.shields.io/badge/Hugging%20Face-Model%20Page-yellow" alt="Hugging Face"></a>
    <a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/Platform-linux-lightgrey" alt="version"></a>
    <a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/Python-3.12+-orange" alt="version"></a>
    <a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/PyTorch-2.5+-brightgreen" alt="python"></a>
    <a href="https://github.com/SparkAudio/Spark-TTS"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="mit"></a>
</div>


## Spark-TTS 🔥

### Overview

Spark-TTS is a text-to-speech system that uses a large language model (LLM) to produce accurate, natural-sounding speech. It is designed to be efficient and flexible for both research and production use.

### Key Features

- **Simplicity and Efficiency**: Built entirely on Qwen2.5, Spark-TTS needs no additional generation model such as flow matching. Instead of relying on a separate model to generate acoustic features, it reconstructs audio directly from the codes predicted by the LLM, which streamlines the pipeline, improves efficiency, and reduces complexity.
- **High-Quality Voice Cloning**: Supports zero-shot voice cloning, meaning it can replicate a speaker's voice from a reference recording without any training data for that speaker. Because no per-voice training is needed, it suits cross-lingual and code-switching scenarios where the output moves between languages and voices.
- **Bilingual Support**: Supports both Chinese and English, including zero-shot voice cloning across the two languages and in code-switched text, with high naturalness and accuracy.
- **Controllable Speech Generation**: Supports creating virtual speakers by adjusting parameters such as gender, pitch, and speaking rate.
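
Voice cloning and controllable generation differ mainly in the prompt string handed to the LLM, which is assembled by `SparkTTS.process_prompt` and `SparkTTS.process_prompt_control` in `cli/SparkTTS.py`. The sketch below rebuilds those strings without loading a model. The structural tokens (`<|start_content|>`, `<|bicodec_global_N|>`, `<|gender_N|>` and so on) are copied from that file; the task tokens and the `GENDER_MAP`/`LEVELS_MAP` values come from `sparktts/utils/token_parser.py`, which is not shown here, so the values below are hypothetical stand-ins, as are the helper names.

```python
import re
from typing import List, Optional

# Hypothetical stand-ins: the real values live in sparktts/utils/token_parser.py.
TASK_TTS = "<|task_tts|>"
TASK_CONTROLLABLE_TTS = "<|task_controllable_tts|>"
GENDER_MAP = {"female": 0, "male": 1}
LEVELS_MAP = {"very_low": 0, "low": 1, "moderate": 2, "high": 3, "very_high": 4}


def build_clone_prompt(
    text: str,
    global_ids: List[int],
    prompt_text: Optional[str] = None,
    semantic_ids: Optional[List[int]] = None,
) -> str:
    """Mirror the layout produced by SparkTTS.process_prompt (voice cloning)."""
    global_tokens = "".join(f"<|bicodec_global_{i}|>" for i in global_ids)
    parts = [TASK_TTS, "<|start_content|>"]
    if prompt_text is not None:
        # The transcript of the prompt audio precedes the target text.
        parts.append(prompt_text)
    parts += [text, "<|end_content|>",
              "<|start_global_token|>", global_tokens, "<|end_global_token|>"]
    if prompt_text is not None:
        if semantic_ids is None:
            raise ValueError("semantic_ids are required when prompt_text is given")
        # The prompt audio's semantic tokens are appended; the section is not closed.
        parts.append("<|start_semantic_token|>")
        parts += [f"<|bicodec_semantic_{i}|>" for i in semantic_ids]
    return "".join(parts)


def build_control_prompt(gender: str, pitch: str, speed: str, text: str) -> str:
    """Mirror the layout produced by SparkTTS.process_prompt_control (voice creation)."""
    if gender not in GENDER_MAP or pitch not in LEVELS_MAP or speed not in LEVELS_MAP:
        raise ValueError("unknown gender, pitch or speed label")
    attributes = (
        f"<|gender_{GENDER_MAP[gender]}|>"
        f"<|pitch_label_{LEVELS_MAP[pitch]}|>"
        f"<|speed_label_{LEVELS_MAP[speed]}|>"
    )
    return "".join([
        TASK_CONTROLLABLE_TTS, "<|start_content|>", text, "<|end_content|>",
        "<|start_style_label|>", attributes, "<|end_style_label|>",
    ])


def extract_ids(prompt: str, kind: str) -> List[int]:
    """Recover the integers in <|bicodec_{kind}_N|> tokens (kind: 'global' or 'semantic')."""
    return [int(n) for n in re.findall(rf"<\|bicodec_{kind}_(\d+)\|>", prompt)]
```

For example, `build_clone_prompt("Hello.", [3, 7])` returns `<|task_tts|><|start_content|>Hello.<|end_content|><|start_global_token|><|bicodec_global_3|><|bicodec_global_7|><|end_global_token|>`. When a transcript is supplied, the prompt ends inside an open semantic-token section, presumably so the LLM continues it with semantic tokens for the target text.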

---

<table align="center">
  <tr>
    <td align="center"><b>Inference Overview of Voice Cloning</b><br><img src="src/figures/infer_voice_cloning.png" width="80%" /></td>
  </tr>
  <tr>
    <td align="center"><b>Inference Overview of Controlled Generation</b><br><img src="src/figures/infer_control.png" width="80%" /></td>
  </tr>
</table>


## 🚀 News

- **[2025-03-04]** Our paper on this project has been published! You can read it here: [Spark-TTS](https://arxiv.org/pdf/2503.01710). 

- **[2025-03-12]** Nvidia Triton Inference Serving is now supported. See the Runtime section below for more details.


## Install
**Clone and Install**

  Here are instructions for installing on Linux. If you're on Windows, please refer to the [Windows Installation Guide](https://github.com/SparkAudio/Spark-TTS/issues/5).  
*(Thanks to [@AcTePuKc](https://github.com/AcTePuKc) for the detailed Windows instructions!)*


- Clone the repo
``` sh
git clone https://github.com/SparkAudio/Spark-TTS.git
cd Spark-TTS
```

- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
- Create Conda env:

``` sh
conda create -n sparktts -y python=3.12
conda activate sparktts
pip install -r requirements.txt
# If you are in mainland China, you can set the mirror as follows:
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
```

**Model Download**

Download via python:
```python
from huggingface_hub import snapshot_download

snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
```

Download via git clone:
```sh
mkdir -p pretrained_models

# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install

git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B
```
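
After either download method, it can help to confirm that the files landed where `cli/SparkTTS.py` expects them: `SparkTTS.__init__` reads `{model_dir}/config.yaml`, and `_initialize_inference` loads the tokenizer and causal LM from `{model_dir}/LLM`. The helper below (a hypothetical name, not part of the repository) checks only those two paths. `BiCodecTokenizer` also reads from `model_dir`, but which files it needs is not shown here, so they are not checked.

```python
from pathlib import Path
from typing import List


def missing_model_paths(model_dir: str) -> List[str]:
    """List the paths read by cli/SparkTTS.py that are absent under model_dir."""
    root = Path(model_dir)
    missing = []
    # SparkTTS.__init__ loads the top-level config from this file.
    if not (root / "config.yaml").is_file():
        missing.append("config.yaml")
    # SparkTTS._initialize_inference loads the tokenizer and LLM from this folder.
    if not (root / "LLM").is_dir():
        missing.append("LLM/")
    return missing


if __name__ == "__main__":
    problems = missing_model_paths("pretrained_models/Spark-TTS-0.5B")
    print("Expected paths found." if not problems else f"Missing: {problems}")
```

Run it from the repository root, since the default path above is relative.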

**Basic Usage**

You can simply run the demo with the following commands:
``` sh
cd example
bash infer.sh
```

Alternatively, you can directly execute the following command in the command line to perform inference:

``` sh
python -m cli.inference \
    --text "text to synthesize." \
    --device 0 \
    --save_dir "path/to/save/audio" \
    --model_dir pretrained_models/Spark-TTS-0.5B \
    --prompt_text "transcript of the prompt audio" \
    --prompt_speech_path "path/to/prompt_audio"
```
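
To synthesize several sentences with the command above, one option is to build the argument list in Python and invoke it once per sentence. This sketch relies only on the flags shown in that command; the helper names are hypothetical, and whether `--prompt_text` may be omitted depends on `cli/inference.py`, which is not shown here (the underlying `SparkTTS.process_prompt` does accept `prompt_text=None`).

```python
import subprocess
import sys
from typing import Iterable, List, Optional


def build_inference_cmd(
    text: str,
    save_dir: str,
    prompt_speech_path: str,
    prompt_text: Optional[str] = None,
    model_dir: str = "pretrained_models/Spark-TTS-0.5B",
    device: int = 0,
) -> List[str]:
    """Assemble the cli.inference invocation shown above as an argv list."""
    cmd = [
        sys.executable, "-m", "cli.inference",  # same interpreter as the caller
        "--text", text,
        "--device", str(device),
        "--save_dir", save_dir,
        "--model_dir", model_dir,
        "--prompt_speech_path", prompt_speech_path,
    ]
    # Assumption: the CLI tolerates a missing transcript (not verified here).
    if prompt_text is not None:
        cmd += ["--prompt_text", prompt_text]
    return cmd


def synthesize_all(
    texts: Iterable[str],
    save_root: str,
    prompt_speech_path: str,
    prompt_text: Optional[str] = None,
) -> None:
    """Run one CLI call per text, each writing into its own subdirectory."""
    for i, text in enumerate(texts):
        cmd = build_inference_cmd(text, f"{save_root}/{i:03d}",
                                  prompt_speech_path, prompt_text)
        # Passing a list (no shell) avoids quoting problems with punctuation in text.
        subprocess.run(cmd, check=True)
```

Call `synthesize_all` from the repository root so that `-m cli.inference` resolves. Each call starts a new Python process and therefore reloads the model, so this approach suits small batches; for larger ones, constructing a `SparkTTS` object once is likely cheaper.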

**Web UI Usage**

Start the web UI by running `python webui.py --device 0`. It supports Voice Cloning and Voice Creation; for Voice Cloning, you can either upload reference audio or record it directly.


| **Voice Cloning** | **Voice Creation** |
|:-------------------:|:-------------------:|
| ![Image 1](src/figures/gradio_TTS.png) | ![Image 2](src/figures/gradio_control.png) |


**Optional Methods**

For additional CLI and Web UI methods, including alternative implementations and extended functionalities, you can refer to:

- [CLI and UI by AcTePuKc](https://github.com/SparkAudio/Spark-TTS/issues/10)


## Runtime

**Nvidia Triton Inference Serving**

We now provide a reference for deploying Spark-TTS with NVIDIA Triton and TensorRT-LLM. The table below presents benchmark results on a single L20 GPU, using 26 different prompt_audio/target_text pairs (totalling 169 seconds of audio). RTF is the real-time factor; lower is faster:

| Model | Note   | Concurrency | Avg Latency     | RTF | 
|-------|-----------|-----------------------|---------|--|
| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1                   | 876.24 ms | 0.1362|
| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2                   | 920.97 ms | 0.0737|
| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4                   | 1611.51 ms | 0.0704|
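
Assuming RTF follows the usual definition (total wall-clock processing time divided by the total duration of the generated audio, so values below 1 are faster than real time) and treating the 169 seconds as that duration, the sketch below converts between the two quantities and back-computes the processing time each row implies. The function names are illustrative, not part of the Spark-TTS runtime.

```python
def real_time_factor(processing_seconds: float, audio_seconds: float) -> float:
    """RTF: wall-clock synthesis time divided by the duration of the audio produced."""
    if audio_seconds <= 0:
        raise ValueError("audio_seconds must be positive")
    return processing_seconds / audio_seconds


def implied_processing_seconds(rtf: float, audio_seconds: float) -> float:
    """Invert the definition: total processing time implied by a reported RTF."""
    return rtf * audio_seconds


BENCHMARK_AUDIO_SECONDS = 169.0  # total audio for the 26 pairs in the table
REPORTED_RTF = [(1, 0.1362), (2, 0.0737), (4, 0.0704)]  # (concurrency, RTF)

for concurrency, rtf in REPORTED_RTF:
    seconds = implied_processing_seconds(rtf, BENCHMARK_AUDIO_SECONDS)
    print(f"concurrency={concurrency}: ~{seconds:.1f} s to produce 169 s of audio")
```

Under that assumption the three rows correspond to about 23.0, 12.5 and 11.9 seconds of total processing, so moving from concurrency 2 to 4 improves throughput only slightly while average latency rises.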


For detailed instructions, see [runtime/triton_trtllm/README.md](runtime/triton_trtllm/README.md).


## **Demos**

Here are some demos generated by Spark-TTS using zero-shot voice cloning. For more demos, visit our [demo page](https://sparkaudio.github.io/spark-tts/).

---

<table>
<tr>
<td align="center">
    
**Donald Trump**
</td>
<td align="center">
    
**Zhongli (Genshin Impact)**
</td>
</tr>

<tr>
<td align="center">

[Donald Trump](https://github.com/user-attachments/assets/fb225780-d9fe-44b2-9b2e-54390cb3d8fd)

</td>
<td align="center">
    
[Zhongli](https://github.com/user-attachments/assets/80eeb9c7-0443-4758-a1ce-55ac59e64bd6)

</td>
</tr>
</table>

---

<table>

<tr>
<td align="center">
    
**陈鲁豫 Chen Luyu**
</td>
<td align="center">
    
**杨澜 Yang Lan**
</td>
</tr>

<tr>
<td align="center">
    
[陈鲁豫Chen_Luyu.webm](https://github.com/user-attachments/assets/5c6585ae-830d-47b1-992d-ee3691f48cf4)
</td>
<td align="center">
    
[Yang_Lan.webm](https://github.com/user-attachments/assets/2fb3d00c-abc3-410e-932f-46ba204fb1d7)
</td>
</tr>
</table>

---


<table>
<tr>
<td align="center">
    
**余承东 Richard Yu**
</td>
<td align="center">
    
**马云 Jack Ma**
</td>
</tr>

<tr>
<td align="center">

[Yu_Chengdong.webm](https://github.com/user-attachments/assets/78feca02-84bb-4d3a-a770-0cfd02f1a8da)

</td>
<td align="center">
    
[Ma_Yun.webm](https://github.com/user-attachments/assets/2d54e2eb-cec4-4c2f-8c84-8fe587da321b)

</td>
</tr>
</table>

---


<table>
<tr>
<td align="center">
    
**刘德华 Andy Lau**
</td>
<td align="center">

**徐志胜 Xu Zhisheng**
</td>
</tr>

<tr>
<td align="center">

[Liu_Dehua.webm](https://github.com/user-attachments/assets/195b5e97-1fee-4955-b954-6d10fa04f1d7)

</td>
<td align="center">
    
[Xu_Zhisheng.webm](https://github.com/user-attachments/assets/dd812af9-76bd-4e26-9988-9cdb9ccbb87b)

</td>
</tr>
</table>


---

<table>
<tr>
<td align="center">
    
**哪吒 Nezha**
</td>
<td align="center">
    
**李靖 Li Jing**
</td>
</tr>

<tr>
<td align="center">

[Ne_Zha.webm](https://github.com/user-attachments/assets/8c608037-a17a-46d4-8588-4db34b49ed1d)
</td>
<td align="center">

[Li_Jing.webm](https://github.com/user-attachments/assets/aa8ba091-097c-4156-b4e3-6445da5ea101)

</td>
</tr>
</table>


## To-Do List

- [x] Release the Spark-TTS paper.
- [ ] Release the training code.
- [ ] Release the training dataset, VoxBox.


## Citation

```
@misc{wang2025sparktts,
      title={Spark-TTS: An Efficient LLM-Based Text-to-Speech Model with Single-Stream Decoupled Speech Tokens}, 
      author={Xinsheng Wang and Mingqi Jiang and Ziyang Ma and Ziyu Zhang and Songxiang Liu and Linqin Li and Zheng Liang and Qixi Zheng and Rui Wang and Xiaoqin Feng and Weizhen Bian and Zhen Ye and Sitong Cheng and Ruibin Yuan and Zhixian Zhao and Xinfa Zhu and Jiahao Pan and Liumeng Xue and Pengcheng Zhu and Yunlin Chen and Zhifei Li and Xie Chen and Lei Xie and Yike Guo and Wei Xue},
      year={2025},
      eprint={2503.01710},
      archivePrefix={arXiv},
      primaryClass={cs.SD},
      url={https://arxiv.org/abs/2503.01710}, 
}
```


## ⚠️ Usage Disclaimer

This project provides a zero-shot voice cloning TTS model intended for academic research, educational purposes, and legitimate applications, such as personalized speech synthesis, assistive technologies, and linguistic research.

Please note:

- Do not use this model for unauthorized voice cloning, impersonation, fraud, scams, deepfakes, or any illegal activities.

- Ensure compliance with local laws and regulations when using this model and uphold ethical standards.

- The developers assume no liability for any misuse of this model.

We advocate for the responsible development and use of AI and encourage the community to uphold safety and ethical principles in AI research and applications. If you have any concerns regarding ethics or misuse, please contact us.

================================================
FILE: cli/SparkTTS.py
================================================
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import torch
from typing import Tuple
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM

from sparktts.utils.file import load_config
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP


class SparkTTS:
    """
    Spark-TTS for text-to-speech generation.
    """

    def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
        """
        Initializes the SparkTTS model with the provided configurations and device.

        Args:
            model_dir (Path): Directory containing the model and config files.
            device (torch.device): The device (CPU/GPU) to run the model on.
        """
        self.device = device
        self.model_dir = model_dir
        self.configs = load_config(f"{model_dir}/config.yaml")
        self.sample_rate = self.configs["sample_rate"]
        self._initialize_inference()

    def _initialize_inference(self):
        """Initializes the tokenizer, model, and audio tokenizer for inference."""
        self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
        self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
        self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
        self.model.to(self.device)

    def process_prompt(
        self,
        text: str,
        prompt_speech_path: Path,
        prompt_text: str = None,
    ) -> Tuple[str, torch.Tensor]:
        """
        Process input for voice cloning.

        Args:
            text (str): The text input to be converted to speech.
            prompt_speech_path (Path): Path to the audio file used as a prompt.
            prompt_text (str, optional): Transcript of the prompt audio.

        Returns:
            Tuple[str, torch.Tensor]: The input prompt and the global token IDs.
        """

        global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
            prompt_speech_path
        )
        global_tokens = "".join(
            [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
        )

        # Prepare the input tokens for the model
        if prompt_text is not None:
            semantic_tokens = "".join(
                [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
            )
            inputs = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                prompt_text,
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens,
                "<|end_global_token|>",
                "<|start_semantic_token|>",
                semantic_tokens,
            ]
        else:
            inputs = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens,
                "<|end_global_token|>",
            ]

        inputs = "".join(inputs)

        return inputs, global_token_ids

    def process_prompt_control(
        self,
        gender: str,
        pitch: str,
        speed: str,
        text: str,
    ):
        """
        Process input for voice creation.

        Args:
            gender (str): female | male.
            pitch (str): very_low | low | moderate | high | very_high
            speed (str): very_low | low | moderate | high | very_high
            text (str): The text input to be converted to speech.

        Returns:
            str: The input prompt string.
        """
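        assert gender in GENDER_MAP.keys(), f"Unsupported gender: {gender}"
        assert pitch in LEVELS_MAP.keys(), f"Unsupported pitch level: {pitch}"
        assert speed in LEVELS_MAP.keys(), f"Unsupported speed level: {speed}"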

        gender_id = GENDER_MAP[gender]
        pitch_level_id = LEVELS_MAP[pitch]
        speed_level_id = LEVELS_MAP[speed]

        pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
        speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
        gender_tokens = f"<|gender_{gender_id}|>"

        attribute_tokens = "".join(
            [gender_tokens, pitch_label_tokens, speed_label_tokens]
        )

        control_tts_inputs = [
            TASK_TOKEN_MAP["controllable_tts"],
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_style_label|>",
            attribute_tokens,
            "<|end_style_label|>",
        ]

        return "".join(control_tts_inputs)

    @torch.no_grad()
    def inference(
        self,
        text: str,
        prompt_speech_path: Path = None,
        prompt_text: str = None,
        gender: str = None,
        pitch: str = None,
        speed: str = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> torch.Tensor:
        """
        Performs inference to generate speech from text, incorporating prompt audio and/or text.

        Args:
            text (str): The text input to be converted to speech.
            prompt_speech_path (Path, optional): Path to the audio file used as a prompt. Required for voice cloning (when gender is None).
            prompt_text (str, optional): Transcript of the prompt audio.
            gender (str, optional): female | male. If set, controllable voice creation is used and the prompt audio is ignored.
            pitch (str, optional): very_low | low | moderate | high | very_high. Used only when gender is set.
            speed (str, optional): very_low | low | moderate | high | very_high. Used only when gender is set.
            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
            top_k (int, optional): Top-k sampling parameter. Default is 50.
            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.

        Returns:
            torch.Tensor: Generated waveform as a tensor.
        """
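        if gender is not None:
            prompt = self.process_prompt_control(gender, pitch, speed, text)
        else:
            if prompt_speech_path is None:
                raise ValueError(
                    "prompt_speech_path is required for voice cloning when gender is not set"
                )
            prompt, global_token_ids = self.process_prompt(
                text, prompt_speech_path, prompt_text
            )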
        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

        # Generate speech using the model
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=3000,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

        # Trim the output tokens to remove the input tokens
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        # Decode the generated tokens into text
        predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Extract semantic token IDs from the generated text
        pred_semantic_ids = (
            torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
            .long()
            .unsqueeze(0)
        )

        if gender is not None:
            global_token_ids = (
                torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
                .long()
                .unsqueeze(0)
                .unsqueeze(0)
            )

        # Convert global and semantic tokens back to a waveform
        wav = self.audio_tokenizer.detokenize(
            global_token_ids.to(self.device).squeeze(0),
            pred_semantic_ids.to(self.device),
        )

        return wav

================================================
FILE: cli/inference.py
================================================
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import argparse
import torch
import soundfile as sf
import logging
from datetime import datetime
import platform

from cli.SparkTTS import SparkTTS


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Run TTS inference.")

    parser.add_argument(
        "--model_dir",
        type=str,
        default="pretrained_models/Spark-TTS-0.5B",
        help="Path to the model directory",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="example/results",
        help="Directory to save generated audio files",
    )
    parser.add_argument("--device", type=int, default=0, help="CUDA device number")
    parser.add_argument(
        "--text", type=str, required=True, help="Text for TTS generation"
    )
    parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
    parser.add_argument(
        "--prompt_speech_path",
        type=str,
        help="Path to the prompt audio file",
    )
    parser.add_argument("--gender", choices=["male", "female"])
    parser.add_argument(
        "--pitch", choices=["very_low", "low", "moderate", "high", "very_high"]
    )
    parser.add_argument(
        "--speed", choices=["very_low", "low", "moderate", "high", "very_high"]
    )
    return parser.parse_args()


def run_tts(args):
    """Perform TTS inference and save the generated audio."""
    logging.info(f"Using model from: {args.model_dir}")
    logging.info(f"Saving audio to: {args.save_dir}")

    # Ensure the save directory exists
    os.makedirs(args.save_dir, exist_ok=True)

    # Convert device argument to torch.device
    if platform.system() == "Darwin" and torch.backends.mps.is_available():
        # macOS with MPS support (Apple Silicon)
        device = torch.device(f"mps:{args.device}")
        logging.info(f"Using MPS device: {device}")
    elif torch.cuda.is_available():
        # System with CUDA support
        device = torch.device(f"cuda:{args.device}")
        logging.info(f"Using CUDA device: {device}")
    else:
        # Fall back to CPU
        device = torch.device("cpu")
        logging.info("GPU acceleration not available, using CPU")

    # Initialize the model
    model = SparkTTS(args.model_dir, device)

    # Generate unique filename using timestamp
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(args.save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Perform inference and save the output audio
    with torch.no_grad():
        wav = model.inference(
            args.text,
            args.prompt_speech_path,
            prompt_text=args.prompt_text,
            gender=args.gender,
            pitch=args.pitch,
            speed=args.speed,
        )
        sf.write(save_path, wav, samplerate=model.sample_rate)

    logging.info(f"Audio saved at: {save_path}")


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    args = parse_args()
    run_tts(args)


================================================
FILE: example/infer.sh
================================================
#!/bin/bash

# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Get the absolute path of the script's directory
script_dir=$(dirname "$(realpath "$0")")

# Get the root directory
root_dir=$(dirname "$script_dir")

# Set default parameters
device=0
save_dir='example/results'
model_dir="pretrained_models/Spark-TTS-0.5B"
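# Target text: "Immersive, a brand-new experience. Shaping a new paradigm for open-source speech synthesis, making intelligent speech more natural."
text="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。"
# Transcript of the prompt audio: "For bird's nest, choose Yan Zhi Wu; this program is sponsored by Yan Zhi Wu, focused on high-quality bird's nest for 26 years. Alternate soy milk and cow's milk for more balanced nutrition; this program is specially sponsored by Dou Ben Dou soy milk."
prompt_text="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。"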
prompt_speech_path="example/prompt_audio.wav"

# Change directory to the root directory
cd "$root_dir" || exit

source sparktts/utils/parse_options.sh

# Run inference
python -m cli.inference \
    --text "${text}" \
    --device "${device}" \
    --save_dir "${save_dir}" \
    --model_dir "${model_dir}" \
    --prompt_text "${prompt_text}" \
    --prompt_speech_path "${prompt_speech_path}"
    
    

================================================
FILE: requirements.txt
================================================
einops==0.8.1
einx==0.3.0
numpy==2.2.3
omegaconf==2.3.0
packaging==24.2
safetensors==0.5.2
soundfile==0.12.1
soxr==0.5.0.post1
torch==2.5.1
torchaudio==2.5.1
tqdm==4.66.5
transformers==4.46.2
gradio==5.18.0

================================================
FILE: runtime/triton_trtllm/Dockerfile.server
================================================
FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3
RUN apt-get update && apt-get install -y cmake
RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
RUN pip install einx==0.3.0 omegaconf==2.3.0 soundfile==0.12.1 soxr==0.5.0.post1 gradio tritonclient librosa
WORKDIR /workspace

================================================
FILE: runtime/triton_trtllm/README.md
================================================
## NVIDIA Triton Inference Serving Best Practices for Spark-TTS

### Quick Start
Launch the service directly with Docker Compose:
```sh
docker compose up
```

### Build Image
Build the Docker image from scratch:
```sh
docker build . -f Dockerfile.server -t soar97/triton-spark-tts:25.02
```

### Create Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02
```

### Understanding `run.sh`

The `run.sh` script automates various steps using stages. You can run specific stages using:
```sh
bash run.sh <start_stage> <stop_stage> [service_type]
```
- `<start_stage>`: The stage to begin execution from (0-5).
- `<stop_stage>`: The stage to end execution at (0-5).
- `[service_type]`: The service type, either `streaming` or `offline`. Required for stages 4 and 5; for the other stages it is optional and the script may apply a default.

Stages:
- **Stage 0**: Download Spark-TTS-0.5B model from HuggingFace.
- **Stage 1**: Convert HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
- **Stage 2**: Create the Triton model repository structure and configure model files (adjusts for streaming/offline).
- **Stage 3**: Launch the Triton Inference Server.
- **Stage 4**: Run the gRPC benchmark client.
- **Stage 5**: Run the single utterance client (gRPC for streaming, HTTP for offline).

### Export Models to TensorRT-LLM and Launch Server
Inside the Docker container, prepare the models and launch the Triton server by running stages 0 through 3. These stages download the original model, convert it to TensorRT-LLM format, build the optimized TensorRT engines, create the Triton model repository structure, and start the server.
```sh
# This runs stages 0, 1, 2, and 3
bash run.sh 0 3
```
*Note: Stage 2 prepares the model repository differently based on whether you intend to run streaming or offline inference later. You might need to re-run stage 2 if switching service types.*


### Single Utterance Client
Run a single inference request. Specify `streaming` or `offline` as the third argument.

**Streaming Mode (gRPC):**
```sh
bash run.sh 5 5 streaming
```
This executes the `client_grpc.py` script with predefined example text and prompt audio in streaming mode.

**Offline Mode (HTTP):**
```sh
bash run.sh 5 5 offline
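```
This executes the `client_http.py` script with predefined example text and prompt audio in offline mode.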
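An offline request can also be sent from Python directly. The sketch below is illustrative and is not the repository's `client_http.py`. It assumes that the `spark_tts` model accepts over HTTP the same inputs that `client_grpc.py` sends (`reference_wav`, `reference_wav_len`, `reference_text`, `target_text`, returning `waveform`), that the server listens on Triton's default HTTP port 8000, and that the reference audio is already 16 kHz mono. Unlike `client_grpc.py`, it does not pad the reference audio. The helper names `build_request_arrays` and `synthesize_http` are hypothetical.

```python
import numpy as np


def build_request_arrays(waveform, reference_text: str, target_text: str) -> dict:
    """Build the input arrays, using the names and shapes that client_grpc.py uses."""
    samples = np.asarray(waveform, dtype=np.float32).reshape(1, -1)  # [1, num_samples]
    lengths = np.array([[samples.shape[1]]], dtype=np.int32)  # [1, 1]
    return {
        "reference_wav": samples,
        "reference_wav_len": lengths,
        "reference_text": np.array([[reference_text]], dtype=object),  # BYTES, [1, 1]
        "target_text": np.array([[target_text]], dtype=object),  # BYTES, [1, 1]
    }


def synthesize_http(waveform, reference_text, target_text,
                    url="localhost:8000", model_name="spark_tts"):
    """Send one offline request (needs a running server and tritonclient[http])."""
    # Imported lazily so build_request_arrays works without tritonclient installed
    import tritonclient.http as httpclient
    from tritonclient.utils import np_to_triton_dtype

    client = httpclient.InferenceServerClient(url=url)
    inputs = []
    for name, array in build_request_arrays(waveform, reference_text, target_text).items():
        infer_input = httpclient.InferInput(
            name, list(array.shape), np_to_triton_dtype(array.dtype)
        )
        infer_input.set_data_from_numpy(array)
        inputs.append(infer_input)
    outputs = [httpclient.InferRequestedOutput("waveform")]
    result = client.infer(model_name, inputs, outputs=outputs)
    return result.as_numpy("waveform").reshape(-1)
```

Keeping array construction separate from the network call makes the request layout easy to check without a running server.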

### Benchmark using Dataset
Run the benchmark client against the running Triton server. Specify `streaming` or `offline` as the third argument.
```sh
# Run benchmark in streaming mode
bash run.sh 4 4 streaming

# Run benchmark in offline mode
bash run.sh 4 4 offline

# Parameters such as the number of concurrent tasks can also be set via client_grpc.py arguments (e.g. --num-tasks)
# Example from run.sh (streaming):
# python3 client_grpc.py \
#     --server-addr localhost \
#     --model-name spark_tts \
#     --num-tasks 2 \
#     --mode streaming \
#     --log-dir ./log_concurrent_tasks_2_streaming_new

# Example customizing the dataset:
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --mode [streaming|offline]
```
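The `--num-tasks` option controls how many requests are in flight at once. A minimal sketch of that fan-out pattern follows, assuming a round-robin split of the dataset and worker names of the form `task-<i>`. Both are assumptions for illustration; `split_manifest`, `run_concurrently`, and `dummy_worker` are hypothetical helpers, not the client's actual partitioning code.

```python
import asyncio
from typing import Any, Awaitable, Callable, List, Sequence


def split_manifest(items: Sequence[Any], num_tasks: int) -> List[List[Any]]:
    """Round-robin split of the dataset into num_tasks shards (hypothetical)."""
    if num_tasks < 1:
        raise ValueError("num_tasks must be >= 1")
    return [list(items[i::num_tasks]) for i in range(num_tasks)]


async def run_concurrently(
    items: Sequence[Any],
    num_tasks: int,
    worker: Callable[[List[Any], str], Awaitable[Any]],
) -> List[Any]:
    """Start one worker coroutine per shard and wait for all of them."""
    shards = split_manifest(items, num_tasks)
    tasks = [
        asyncio.create_task(worker(shard, f"task-{i}"))
        for i, shard in enumerate(shards)
    ]
    return await asyncio.gather(*tasks)


async def dummy_worker(shard: List[Any], name: str) -> tuple:
    """Stand-in worker that reports how many items it would have sent."""
    await asyncio.sleep(0)  # placeholder for sending requests to the server
    return name, len(shard)


if __name__ == "__main__":
    print(asyncio.run(run_concurrently(list(range(26)), 4, dummy_worker)))
```

Round-robin keeps shard sizes within one item of each other, so no single task is left with a disproportionately long queue.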

### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) with a total audio duration of 169 seconds.

| Mode | Note | Concurrency | Avg Latency | First Chunk Latency (P50) | RTF |
|------|------|-------------|-------------|---------------------------|-----|
| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms | - | 0.1362 |
| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms | - | 0.0737 |
| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms | - | 0.0704 |
| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 1 | 913.28 ms | 210.42 ms | 0.1501 |
| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 2 | 1009.23 ms | 226.08 ms | 0.0862 |
| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 4 | 1793.86 ms | 1017.70 ms | 0.0824 |
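The table does not define RTF (real-time factor). Assuming the standard definition, total processing time divided by the total duration of the synthesized audio (lower is faster), a reported RTF can be converted back into wall-clock processing time for the 169-second benchmark set. The helper names below are illustrative and not part of the repository:

```python
def real_time_factor(processing_seconds: float, audio_seconds: float) -> float:
    """RTF = time spent synthesizing / duration of the audio produced."""
    if audio_seconds <= 0:
        raise ValueError("audio_seconds must be positive")
    return processing_seconds / audio_seconds


def implied_processing_seconds(rtf: float, audio_seconds: float) -> float:
    """Invert the RTF definition to recover the total processing time."""
    return rtf * audio_seconds


# Figures reported in the benchmark table (169 s of audio in total)
TOTAL_AUDIO_S = 169.0
for mode, concurrency, rtf in [("offline", 1, 0.1362), ("offline", 4, 0.0704),
                               ("streaming", 1, 0.1501)]:
    seconds = implied_processing_seconds(rtf, TOTAL_AUDIO_S)
    print(f"{mode:<9} concurrency={concurrency}: ~{seconds:.1f} s of processing")
```

Under this definition, the drop in RTF from concurrency 1 to 4 reflects higher throughput, even though the average per-request latency in the table increases.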

================================================
FILE: runtime/triton_trtllm/client_grpc.py
================================================
#!/usr/bin/env python3
# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
#                2023  Nvidia              (authors: Yuekai Zhang)
#                2023  Recurrent.ai        (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a dataset from Hugging Face and sends it to the server
for decoding in parallel.

Usage:
num_task=2

# For offline F5-TTS
python3 client_grpc.py \
    --server-addr localhost \
    --model-name f5_tts \
    --num-tasks $num_task \
    --huggingface-dataset yuekai/seed_tts \
    --split-name test_zh \
    --log-dir ./log_concurrent_tasks_${num_task}

# For offline Spark-TTS-0.5B
python3 client_grpc.py \
    --server-addr localhost \
    --model-name spark_tts \
    --num-tasks $num_task \
    --huggingface-dataset yuekai/seed_tts \
    --split-name wenetspeech4tts \
    --log-dir ./log_concurrent_tasks_${num_task}
"""

import argparse
import asyncio
import json
import queue
import uuid
import functools

import os
import time
import types
from pathlib import Path

import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient_aio  # asyncio gRPC client
import tritonclient.grpc as grpcclient_sync  # blocking gRPC client, used for streaming
from tritonclient.utils import np_to_triton_dtype, InferenceServerException


class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()
        self._first_chunk_time = None
        self._start_time = None

    def record_start_time(self):
        self._start_time = time.time()

    def get_first_chunk_latency(self):
        if self._first_chunk_time and self._start_time:
            return self._first_chunk_time - self._start_time
        return None

def callback(user_data, result, error):
    if user_data._first_chunk_time is None and not error:
        user_data._first_chunk_time = time.time() # Record time of first successful chunk
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)


def write_triton_stats(stats, summary_file):
    with open(summary_file, "w") as summary_f:
        model_stats = stats["model_stats"]
        # Write a note explaining that this log is parsed from triton_client.get_inference_statistics()
        summary_f.write(
            "This log is parsed from triton_client.get_inference_statistics() for better human readability. \n"
        )
        summary_f.write("To learn more about the log, please refer to: \n")
        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
        summary_f.write(
            "To improve throughput, requests can be held in the queue for a while and then executed with a larger batch size. \n"
        )
        summary_f.write(
            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
        )
        summary_f.write(
            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
        )
        summary_f.write(
            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
        )
        for model_state in model_stats:
            if "last_inference" not in model_state:
                continue
            summary_f.write(f"model name is {model_state['name']} \n")
            model_inference_stats = model_state["inference_stats"]
            total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
            total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
            total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
            total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
            summary_f.write(
                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
            )
            model_batch_stats = model_state["batch_stats"]
            for batch in model_batch_stats:
                batch_size = int(batch["batch_size"])
                compute_input = batch["compute_input"]
                compute_output = batch["compute_output"]
                compute_infer = batch["compute_infer"]
                batch_count = int(compute_infer["count"])
                assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
                compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
                compute_input_time_ms = int(compute_input["ns"]) / 1e6
                compute_output_time_ms = int(compute_output["ns"]) / 1e6
                summary_f.write(
                    f"execute inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
                )
                summary_f.write(
                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
                )
                summary_f.write(
                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
                )


def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=8001,
        help="Grpc port of the triton server, default is 8001",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Path to a single audio file. It can't be specified at the same time as --manifest-path",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="",
        help="Transcript of the reference audio",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="",
        help="Text to synthesize",
    )

    parser.add_argument(
        "--huggingface-dataset",
        type=str,
        default="yuekai/seed_tts",
        help="Dataset name on the Hugging Face Hub",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="Dataset split name",
    )

    parser.add_argument(
        "--manifest-path",
        type=str,
        default=None,
        help="Path to the manifest directory containing wav.scp and trans.txt files.",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        choices=["f5_tts", "spark_tts"],
        help="Name of the Triton model to request",
    )

    parser.add_argument(
        "--num-tasks",
        type=int,
        default=1,
        help="Number of concurrent tasks for sending",
    )

    parser.add_argument(
        "--log-interval",
        type=int,
        default=5,
        help="Controls how frequently we print the log.",
    )

    parser.add_argument(
        "--compute-wer",
        action="store_true",
        default=False,
        help="Compute WER if set.",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        required=False,
        default="./tmp",
        help="log directory",
    )

    parser.add_argument(
        "--mode",
        type=str,
        default="offline",
        choices=["offline", "streaming"],
        help="Select offline or streaming benchmark mode."
    )
    parser.add_argument(
        "--chunk-overlap-duration",
        type=float,
        default=0.1,
        help="Chunk overlap duration for streaming reconstruction (in seconds)."
    )

    return parser.parse_args()


def load_audio(wav_path, target_sample_rate=16000):
    assert target_sample_rate == 16000, "hard coding in server"
    if isinstance(wav_path, dict):
        waveform = wav_path["array"]
        sample_rate = wav_path["sampling_rate"]
    else:
        waveform, sample_rate = sf.read(wav_path)
    if sample_rate != target_sample_rate:
        from scipy.signal import resample

        num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
        waveform = resample(waveform, num_samples)
    return waveform, target_sample_rate

def prepare_request_input_output(
    protocol_client, # Can be grpcclient_aio or grpcclient_sync
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None # Optional padding for offline mode
):
    """Prepares inputs for Triton inference (offline or streaming)."""
    assert len(waveform.shape) == 1, "waveform should be 1D"
    lengths = np.array([[len(waveform)]], dtype=np.int32)

    # Apply padding only if padding_duration is provided (for offline)
    if padding_duration:
        duration = len(waveform) / sample_rate
        # Estimate target duration based on text length ratio (crude estimation)
        # Avoid division by zero if reference_text is empty
        if reference_text:
            estimated_target_duration = duration / len(reference_text) * len(target_text)
        else:
            # No reference text: assume the target is about as long as the reference
            estimated_target_duration = duration

        # Calculate required samples based on estimated total duration
        required_total_samples = padding_duration * sample_rate * (
            (int(estimated_target_duration + duration) // padding_duration) + 1
        )
        samples = np.zeros((1, required_total_samples), dtype=np.float32)
        samples[0, : len(waveform)] = waveform
    else:
        # No padding for streaming or if padding_duration is None
        samples = waveform.reshape(1, -1).astype(np.float32)

    # Common input creation logic
    inputs = [
        protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
        protocol_client.InferInput(
            "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
        ),
        protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
        protocol_client.InferInput("target_text", [1, 1], "BYTES"),
    ]
    inputs[0].set_data_from_numpy(samples)
    inputs[1].set_data_from_numpy(lengths)

    input_data_numpy = np.array([reference_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[2].set_data_from_numpy(input_data_numpy)

    input_data_numpy = np.array([target_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[3].set_data_from_numpy(input_data_numpy)

    outputs = [protocol_client.InferRequestedOutput("waveform")]

    return inputs, outputs

def run_sync_streaming_inference(
    sync_triton_client: tritonclient.grpc.InferenceServerClient,
    model_name: str,
    inputs: list,
    outputs: list,
    request_id: str,
    user_data: UserData,
    chunk_overlap_duration: float,
    save_sample_rate: int,
    audio_save_path: str,
):
    """Helper function to run the blocking sync streaming call."""
    start_time_total = time.time()
    user_data.record_start_time() # Record start time for first chunk latency calculation

    # Establish stream
    sync_triton_client.start_stream(callback=functools.partial(callback, user_data))

    # Send request
    sync_triton_client.async_stream_infer(
        model_name,
        inputs,
        request_id=request_id,
        outputs=outputs,
        enable_empty_final_response=True,
    )

    # Process results
    audios = []
    while True:
        try:
            result = user_data._completed_requests.get(timeout=300)  # seconds; raises queue.Empty if no response arrives
            if isinstance(result, InferenceServerException):
                print(f"Received InferenceServerException: {result}")
                sync_triton_client.stop_stream()
                return None, None, None # Indicate error
            # Get response metadata
            response = result.get_response()
            final = response.parameters["triton_final_response"].bool_param
            if final is True:
                break

            audio_chunk = result.as_numpy("waveform").reshape(-1)
            if audio_chunk.size > 0:  # Only append non-empty chunks
                audios.append(audio_chunk)
            else:
                print("Warning: received empty audio chunk.")

        except queue.Empty:
            print(f"Timeout waiting for response for request id {request_id}")
            sync_triton_client.stop_stream()
            return None, None, None # Indicate error

    sync_triton_client.stop_stream()
    end_time_total = time.time()
    total_request_latency = end_time_total - start_time_total
    first_chunk_latency = user_data.get_first_chunk_latency()

    # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
    actual_duration = 0
    if audios:
        cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
        fade_out = np.linspace(1, 0, cross_fade_samples)
        fade_in = np.linspace(0, 1, cross_fade_samples)
        reconstructed_audio = None

        # Simplified reconstruction based on client_grpc_streaming.py
        if not audios:
            print("Warning: No audio chunks received.")
            reconstructed_audio = np.array([], dtype=np.float32) # Empty array
        elif len(audios) == 1:
            reconstructed_audio = audios[0]
        else:
            reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
            for i in range(1, len(audios)):
                 # Cross-fade section
                 cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
                                        audios[i - 1][-cross_fade_samples:] * fade_out)
                 # Middle section of the current chunk
                 middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
                 # Concatenate
                 reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
            # Add the last part of the final chunk
            reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])

        if reconstructed_audio is not None and reconstructed_audio.size > 0:
            actual_duration = len(reconstructed_audio) / save_sample_rate
            # Save reconstructed audio
            os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
            sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
        else:
            print("Warning: No audio chunks received or reconstructed.")
            actual_duration = 0 # Set duration to 0 if no audio

    else:
        print("Warning: No audio chunks received.")
        actual_duration = 0

    return total_request_latency, first_chunk_latency, actual_duration
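The cross-fade reconstruction above can be isolated into a pure function, which makes the overlap arithmetic easy to test. This is a minimal sketch, assuming every chunk is a 1-D NumPy array at least twice the overlap long; `crossfade_concat` is a hypothetical helper name, not part of the client.

```python
import numpy as np


def crossfade_concat(chunks, overlap):
    """Join 1-D chunks whose edges overlap by `overlap` samples with a linear cross-fade.

    Hypothetical helper: assumes len(chunk) >= 2 * overlap for every chunk.
    """
    if not chunks:
        return np.array([], dtype=np.float32)
    if overlap == 0 or len(chunks) == 1:
        # Nothing to blend; chunks[0][:-0] would be empty, so just concatenate.
        return np.concatenate(chunks)
    fade_in = np.linspace(0.0, 1.0, overlap)
    fade_out = np.linspace(1.0, 0.0, overlap)
    pieces = [chunks[0][:-overlap]]  # first chunk without its trailing overlap
    for prev, cur in zip(chunks[:-1], chunks[1:]):
        pieces.append(prev[-overlap:] * fade_out + cur[:overlap] * fade_in)  # blended seam
        pieces.append(cur[overlap:-overlap])  # untouched middle of the current chunk
    pieces.append(chunks[-1][-overlap:])  # tail of the last chunk has no successor
    return np.concatenate(pieces)
```

The output length is the sum of chunk lengths minus `(len(chunks) - 1) * overlap`, because each seam is counted once.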


async def send_streaming(
    manifest_item_list: list,
    name: str,
    server_url: str,  # gRPC server address; each task creates its own sync client
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    chunk_overlap_duration: float = 0.1,
    padding_duration: int = None,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])
    sync_triton_client = None # Initialize client variable

    try: # Wrap in try...finally to ensure client closing
        print(f"{name}: Initializing sync client for streaming...")
        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here

        print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
        for i, item in enumerate(manifest_item_list):
            if i % log_interval == 0:
                print(f"{name}: Processing item {i}/{len(manifest_item_list)}")

            try:
                waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
                reference_text, target_text = item["reference_text"], item["target_text"]

                inputs, outputs = prepare_request_input_output(
                    protocol_client,
                    waveform,
                    reference_text,
                    target_text,
                    sample_rate,
                    padding_duration=padding_duration
                )
                request_id = str(uuid.uuid4())
                user_data = UserData()

                audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")

                total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
                    run_sync_streaming_inference,
                    sync_triton_client,
                    model_name,
                    inputs,
                    outputs,
                    request_id,
                    user_data,
                    chunk_overlap_duration,
                    save_sample_rate,
                    audio_save_path
                )

                if total_request_latency is not None:
                    print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
                    latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                    total_duration += actual_duration
                else:
                    print(f"{name}: Item {i} failed.")

            except FileNotFoundError:
                print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
            except Exception as e:
                print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
                import traceback
                traceback.print_exc()


    finally: # Ensure client is closed
        if sync_triton_client:
            try:
                print(f"{name}: Closing sync client...")
                sync_triton_client.close()
            except Exception as e:
                print(f"{name}: Error closing sync client: {e}")


    print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
    return total_duration, latency_data
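`send_streaming` keeps the event loop responsive by pushing each blocking `run_sync_streaming_inference` call into a worker thread with `asyncio.to_thread` (Python 3.9+), so the tasks created in `main` overlap in time. The pattern in isolation, as a sketch: `blocking_request` is a hypothetical `time.sleep` stand-in for the gRPC round trip.

```python
import asyncio
import time


def blocking_request(item_id, delay=0.05):
    """Hypothetical stand-in for a blocking gRPC streaming call."""
    time.sleep(delay)
    return item_id * 2


async def worker(name, items):
    results = []
    for item in items:
        # Run the blocking call in the default thread pool so other tasks can proceed.
        results.append(await asyncio.to_thread(blocking_request, item))
    return name, results


async def run_all():
    groups = [[0, 1], [2, 3], [4, 5]]
    start = time.perf_counter()
    out = await asyncio.gather(*(worker(f"task-{i}", g) for i, g in enumerate(groups)))
    return out, time.perf_counter() - start
```

Items inside one task stay sequential, while the three tasks run concurrently, so the total wall time is roughly one task's worth rather than three.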

async def send(
    manifest_item_list: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    padding_duration: int = None,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])

    print(f"manifest_item_list: {manifest_item_list}")
    for i, item in enumerate(manifest_item_list):
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(manifest_item_list)}")
        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
        reference_text, target_text = item["reference_text"], item["target_text"]

        inputs, outputs = prepare_request_input_output(
            protocol_client,
            waveform,
            reference_text,
            target_text,
            sample_rate,
            padding_duration=padding_duration
        )
        sequence_id = 100000000 + i + task_id * 10
        start = time.time()
        response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)

        audio = response.as_numpy("waveform").reshape(-1)
        actual_duration = len(audio) / save_sample_rate

        end = time.time() - start

        audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")

        latency_data.append((end, actual_duration))
        total_duration += actual_duration

    return total_duration, latency_data


def load_manifests(manifest_path):
    with open(manifest_path, "r") as f:
        manifest_list = []
        for line in f:
            assert len(line.strip().split("|")) == 4
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
            utt = Path(utt).stem
            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
            manifest_list.append(
                {
                    "audio_filepath": prompt_wav,
                    "reference_text": prompt_text,
                    "target_text": gt_text,
                    "target_audio_path": utt,
                }
            )
    return manifest_list
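Each manifest line has the form `utt|prompt_text|prompt_wav|target_text`, and relative prompt paths are resolved against the manifest's directory. A per-line sketch of the same parsing; `parse_manifest_line` is a hypothetical name, and it raises `ValueError` instead of using `assert` so the check survives `python -O`.

```python
import os
from pathlib import Path


def parse_manifest_line(line, manifest_dir):
    """Parse one 'utt|prompt_text|prompt_wav|target_text' line into a request item."""
    fields = line.strip().split("|")
    if len(fields) != 4:
        raise ValueError(f"expected 4 '|'-separated fields, got {len(fields)}")
    utt, prompt_text, prompt_wav, target_text = fields
    if not os.path.isabs(prompt_wav):
        # Relative prompt paths are interpreted relative to the manifest file
        prompt_wav = os.path.join(manifest_dir, prompt_wav)
    return {
        "audio_filepath": prompt_wav,
        "reference_text": prompt_text,
        "target_text": target_text,
        "target_audio_path": Path(utt).stem,  # output file name without extension
    }
```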


def split_data(data, k):
    n = len(data)
    if n < k:
        print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
        k = n

    quotient = n // k
    remainder = n % k

    result = []
    start = 0
    for i in range(k):
        if i < remainder:
            end = start + quotient + 1
        else:
            end = start + quotient

        result.append(data[start:end])
        start = end

    return result
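`split_data` gives the first `n % k` chunks one extra item, so chunk sizes differ by at most one. Below is an equivalent sketch using `divmod`. Unlike `split_data`, it also guards the empty case: there, `n = 0` makes `k` become 0, and `n // k` raises `ZeroDivisionError`. `split_evenly` is a hypothetical name.

```python
def split_evenly(data, k):
    """Split `data` into at most k contiguous chunks whose sizes differ by at most one."""
    k = min(k, len(data))
    if k <= 0:
        return []
    quotient, remainder = divmod(len(data), k)
    chunks, start = [], 0
    for i in range(k):
        end = start + quotient + (1 if i < remainder else 0)  # first `remainder` chunks get +1
        chunks.append(data[start:end])
        start = end
    return chunks
```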

async def main():
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"

    # --- Client Initialization based on mode ---
    triton_client = None
    protocol_client = None
    if args.mode == "offline":
        print("Initializing gRPC client for offline mode...")
        # Use the async client for offline tasks
        triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        protocol_client = grpcclient_aio
    elif args.mode == "streaming":
        print("Initializing gRPC client for streaming mode...")
        # Streaming uses the sync client (run via asyncio.to_thread); one client
        # instance is created per task inside send_streaming.
        protocol_client = grpcclient_sync # protocol client for input prep
    else:
        raise ValueError(f"Invalid mode: {args.mode}")
    # --- End Client Initialization ---

    if args.reference_audio:
        args.num_tasks = 1
        args.log_interval = 1
        manifest_item_list = [
            {
                "reference_text": args.reference_text,
                "target_text": args.target_text,
                "audio_filepath": args.reference_audio,
                "target_audio_path": "test",
            }
        ]
    elif args.huggingface_dataset:
        import datasets

        dataset = datasets.load_dataset(
            args.huggingface_dataset,
            split=args.split_name,
            trust_remote_code=True,
        )
        manifest_item_list = []
        for i in range(len(dataset)):
            manifest_item_list.append(
                {
                    "audio_filepath": dataset[i]["prompt_audio"],
                    "reference_text": dataset[i]["prompt_text"],
                    "target_audio_path": dataset[i]["id"],
                    "target_text": dataset[i]["target_text"],
                }
            )
    else:
        manifest_item_list = load_manifests(args.manifest_path)

    num_tasks = min(args.num_tasks, len(manifest_item_list))
    manifest_item_list = split_data(manifest_item_list, num_tasks)

    os.makedirs(args.log_dir, exist_ok=True)
    tasks = []
    start_time = time.time()
    for i in range(num_tasks):
        # --- Task Creation based on mode ---
        if args.mode == "offline":
            task = asyncio.create_task(
                send(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    triton_client=triton_client,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=1,
                    save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
                )
            )
        elif args.mode == "streaming":
            task = asyncio.create_task(
                send_streaming(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    server_url=url, # Pass URL instead of client
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=10,
                    save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
                    chunk_overlap_duration=args.chunk_overlap_duration,
                )
            )
        # --- End Task Creation ---
        tasks.append(task)

    ans_list = await asyncio.gather(*tasks)

    end_time = time.time()
    elapsed = end_time - start_time

    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        if ans:
            total_duration += ans[0]
            latency_data.extend(ans[1])  # ans[1] is a list of per-item latency tuples
        else:
            print("Warning: A task returned None, possibly due to an error.")


    if total_duration == 0:
        print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
        rtf = float('inf')
    else:
        rtf = elapsed / total_duration

    s = f"Mode: {args.mode}\n"
    s += f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration / 3600:.2f} hours)\n"
    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

    # --- Statistics Reporting based on mode ---
    if latency_data:
        if args.mode == "offline":
            # Original offline latency calculation
            latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
            if latency_list:
                latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
                latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
                s += f"latency_variance: {latency_variance:.2f}\n"
                s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
                s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
                s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
                s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_latency_ms: {latency_ms:.2f}\n"
            else:
                s += "No latency data collected for offline mode.\n"

        elif args.mode == "streaming":
            # Calculate stats for total request latency and first chunk latency
            total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
            first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]

            s += "\n--- Total Request Latency ---\n"
            if total_latency_list:
                avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
                variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
                s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
                s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
                s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
                s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
                s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
            else:
                s += "No total request latency data collected.\n"

            s += "\n--- First Chunk Latency ---\n"
            if first_chunk_latency_list:
                avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
                variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
                s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
                s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
            else:
                s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
    else:
        s += "No latency data collected.\n"
    # --- End Statistics Reporting ---

    print(s)
    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    elif args.reference_audio:
        name = Path(args.reference_audio).stem
    else:
        name = "results" # Default name if no manifest/split/audio provided
    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
        f.write(s)

    # --- Statistics Fetching using temporary Async Client ---
    # Use a separate async client for fetching stats regardless of mode
    stats_client = None
    try:
        print("Initializing temporary async client for fetching stats...")
        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        print("Fetching inference statistics...")
        # Fetching for all models, filtering might be needed depending on server setup
        stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
        print("Fetching model config...")
        metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)

        write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")

        with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
            json.dump(metadata, f, indent=4)

    except Exception as e:
        print(f"Could not retrieve statistics or config: {e}")
    finally:
        if stats_client:
            try:
                print("Closing temporary async stats client...")
                await stats_client.close()
            except Exception as e:
                print(f"Error closing async stats client: {e}")
    # --- End Statistics Fetching ---
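The report in `main` rests on two computations. The real-time factor (RTF) is wall-clock processing time divided by seconds of audio produced, so values below 1.0 mean faster than real time. The latency lines are percentiles and means of per-request latencies converted to milliseconds. This is a sketch of both, with hypothetical helper names. One detail differs: the script multiplies the variance of second-valued latencies by 1000, which yields ms²/1000 rather than ms². The `*_variance_ms2` line here reports the variance of the millisecond values instead.

```python
import numpy as np


def real_time_factor(elapsed_s, synthesized_s):
    """Wall-clock seconds spent per second of synthesized audio (inf if nothing was produced)."""
    return float("inf") if synthesized_s == 0 else elapsed_s / synthesized_s


def summarize_latency(latencies_s, prefix):
    """Format variance/percentile/average lines for latencies given in seconds."""
    ms = np.asarray(latencies_s, dtype=np.float64) * 1000.0  # seconds -> milliseconds
    lines = [f"{prefix}_variance_ms2: {ms.var():.2f}"]
    lines += [f"{prefix}_{p}_percentile_ms: {np.percentile(ms, p):.2f}" for p in (50, 90, 95, 99)]
    lines.append(f"average_{prefix}_ms: {ms.mean():.2f}")
    return "\n".join(lines) + "\n"
```

For example, 30 s of processing that produced 120 s of audio gives an RTF of 0.25.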


if __name__ == "__main__":
    # Wrap main() so any unhandled exception is printed with its traceback.
    async def run_main():
        try:
            await main()
        except Exception as e:
            print(f"An error occurred in main: {e}")
            import traceback
            traceback.print_exc()

    asyncio.run(run_main())


================================================
FILE: runtime/triton_trtllm/client_http.py
================================================
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import requests
import soundfile as sf
import json
import numpy as np
import argparse

def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--server-url",
        type=str,
        default="localhost:8000",
        help="Address of the server",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default="../../example/prompt_audio.wav",
        help="Path to the reference (prompt) audio file; must be 16 kHz",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
        help="Transcript of the reference audio",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
        help="Text to synthesize",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="spark_tts",
        choices=[
            "f5_tts", "spark_tts"
        ],
        help="Name of the Triton model in model_repo to send requests to",
    )

    parser.add_argument(
        "--output-audio",
        type=str,
        default="output.wav",
        help="Path to save the output audio",
    )
    return parser.parse_args()

def prepare_request(
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None,
    audio_save_dir: str = "./",
):
    assert len(waveform.shape) == 1, "waveform should be 1D"
    lengths = np.array([[len(waveform)]], dtype=np.int32)
    if padding_duration:
        # Zero-pad up to the next multiple of padding_duration seconds
        duration = len(waveform) / sample_rate
        samples = np.zeros(
            (
                1,
                padding_duration
                * sample_rate
                * ((int(duration) // padding_duration) + 1),
            ),
            dtype=np.float32,
        )

        samples[0, : len(waveform)] = waveform
    else:
        samples = waveform

    samples = samples.reshape(1, -1).astype(np.float32)

    data = {
        "inputs":[
            {
                "name": "reference_wav",
                "shape": samples.shape,
                "datatype": "FP32",
                "data": samples.tolist()
            },
            {
                "name": "reference_wav_len",
                "shape": lengths.shape,
                "datatype": "INT32",
                "data": lengths.tolist(),
            },
            {
                "name": "reference_text",
                "shape": [1, 1],
                "datatype": "BYTES",
                "data": [reference_text]
            },
            {
                "name": "target_text",
                "shape": [1, 1],
                "datatype": "BYTES",
                "data": [target_text]
            }
        ]
    }

    return data
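With `padding_duration` set, the reference audio is zero-padded to a whole number of `padding_duration`-second buckets. The true length travels separately in `reference_wav_len`, so the server can slice the padding off again. A standalone sketch of that arithmetic, with the hypothetical name `pad_to_bucket`. Because the formula uses `int(duration) // padding_duration + 1`, input that is already an exact multiple still gains one more bucket.

```python
import numpy as np


def pad_to_bucket(waveform, sample_rate, padding_duration):
    """Zero-pad a 1-D waveform into a (1, N) array, N a multiple of padding_duration seconds."""
    duration = len(waveform) / sample_rate
    target = padding_duration * sample_rate * (int(duration) // padding_duration + 1)
    padded = np.zeros((1, target), dtype=np.float32)
    padded[0, : len(waveform)] = waveform  # original samples first, zeros after
    return padded
```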

if __name__ == "__main__":
    args = get_args()
    server_url = args.server_url
    if not server_url.startswith(("http://", "https://")):
        server_url = f"http://{server_url}"
    
    url = f"{server_url}/v2/models/{args.model_name}/infer"
    waveform, sr = sf.read(args.reference_audio)
    assert sr == 16000, "sample rate hardcoded in server"
    
    samples = np.array(waveform, dtype=np.float32)
    data = prepare_request(samples, args.reference_text, args.target_text)

    rsp = requests.post(
        url,
        headers={"Content-Type": "application/json"},
        json=data,
        verify=False,
        params={"request_id": '0'}
    )
    rsp.raise_for_status()  # surface HTTP errors instead of a confusing KeyError below
    result = rsp.json()
    audio = result["outputs"][0]["data"]
    audio = np.array(audio, dtype=np.float32)
    sf.write(args.output_audio, audio, 16000, "PCM_16")

================================================
FILE: runtime/triton_trtllm/docker-compose.yml
================================================
services:
  tts:
    image: soar97/triton-spark-tts:25.02
    shm_size: '1gb'
    ports:
      - "8000:8000"
      - "8001:8001"
      - "8002:8002"
    environment:
      - PYTHONIOENCODING=utf-8
      - MODEL_ID=${MODEL_ID}
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0']
              capabilities: [gpu]
    command: >
      /bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3"
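The container clones the repository and runs `run.sh 0 3` before serving, so the server may not be ready for some time after the container starts. The sketch below waits on Triton's standard HTTP readiness endpoint, `/v2/health/ready`, on the 8000 port mapped above. It uses only the standard library, and the timeout and polling values are arbitrary choices.

```python
import time
import urllib.error
import urllib.request


def wait_until_ready(base_url="http://localhost:8000", timeout_s=60.0, poll_s=1.0):
    """Poll the readiness endpoint until it returns HTTP 200 or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(f"{base_url}/v2/health/ready", timeout=poll_s) as rsp:
                if rsp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # not ready yet: connection refused, timeout, or a non-200 status
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```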


================================================
FILE: runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py
================================================
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack

import triton_python_backend_utils as pb_utils

import os
import numpy as np

from sparktts.models.audio_tokenizer import BiCodecTokenizer

class TritonPythonModel:
    """Triton Python model for audio tokenization.
    
    This model takes reference audio input and extracts semantic and global tokens
    using the BiCodec tokenizer.
    """

    def initialize(self, args):
        """Initialize the model.
        
        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        
        # Initialize tokenizer
        self.device = torch.device("cuda")
        self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"], device=self.device)

    def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
        """Extract reference audio clip for speaker embedding.
        
        Args:
            wav: Input waveform array
            
        Returns:
            Reference clip of fixed duration
        """
        SAMPLE_RATE = 16000
        REF_SEGMENT_DURATION = 6  # seconds
        LATENT_HOP_LENGTH = 320

        ref_segment_length = (
            int(SAMPLE_RATE * REF_SEGMENT_DURATION)
            // LATENT_HOP_LENGTH
            * LATENT_HOP_LENGTH
        )
        wav_length = len(wav)

        if ref_segment_length > wav_length:
            # Repeat and truncate if input is too short
            repeat_times = ref_segment_length // wav_length + 1
            wav = np.tile(wav, repeat_times)

        return wav[:ref_segment_length]

    def execute(self, requests):
        """Execute inference on the batched requests.
        
        Args:
            requests: List of inference requests
            
        Returns:
            List of inference responses containing tokenized outputs
        """
        reference_wav_list = []
        reference_wav_ref_clip_list = []

        # Process each request in batch
        for request in requests:
            # Extract input tensors
            wav_array = pb_utils.get_input_tensor_by_name(
                request, "reference_wav").as_numpy()
            wav_len = pb_utils.get_input_tensor_by_name(
                request, "reference_wav_len").as_numpy().item()

            # Prepare inputs
            wav = wav_array[:, :wav_len].squeeze(0)
            reference_wav_list.append(wav)
            
            wav_ref_clip = self.get_ref_clip(wav)
            reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip))

        # Batch process through tokenizer
        ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0)
        wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features(
            reference_wav_list)
        
        audio_tokenizer_input = {
            "ref_wav": ref_wav_clip_tensor.to(self.device),
            "feat": wav2vec2_features.to(self.device),
        }
        semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize(
            audio_tokenizer_input)

        # Prepare responses
        responses = []
        for i in range(len(requests)):
            global_tokens_tensor = pb_utils.Tensor.from_dlpack(
                "global_tokens", to_dlpack(global_tokens[i]))
            semantic_tokens_tensor = pb_utils.Tensor.from_dlpack(
                "semantic_tokens", to_dlpack(semantic_tokens[i]))
            
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[global_tokens_tensor, semantic_tokens_tensor])
            responses.append(inference_response)
                             
        return responses
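`get_ref_clip` turns every reference waveform into a fixed-length clip so the clips can be stacked with `torch.stack` in `execute`. Short audio is tiled and long audio is truncated. With the defaults, 16000 × 6 = 96000 samples is already a multiple of the 320-sample hop, so the rounding has no effect. This is a NumPy-only sketch of the same logic (`ref_clip` is a hypothetical name). It adds an empty-input guard, because the original would divide by zero there.

```python
import numpy as np


def ref_clip(wav, sample_rate=16000, seconds=6, hop=320):
    """Fixed-length reference clip: tile short audio, truncate long audio."""
    if len(wav) == 0:
        raise ValueError("empty waveform")
    seg = int(sample_rate * seconds) // hop * hop  # round down to whole hops
    if len(wav) < seg:
        wav = np.tile(wav, seg // len(wav) + 1)  # repeat enough times to exceed seg
    return wav[:seg]
```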


================================================
FILE: runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt
================================================
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "audio_tokenizer"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
  {
   key: "model_dir", 
   value: {string_value:"${model_dir}"}
  }
]

input [
  {
    name: "reference_wav"
    data_type: TYPE_FP32
    dims: [-1]
  },
  {
    name: "reference_wav_len"
    data_type: TYPE_INT32
    dims: [1]
  }
]
output [
  {
    name: "global_tokens"
    data_type: TYPE_INT32
    dims: [-1]
  },
  {
    name: "semantic_tokens"
    data_type: TYPE_INT32
    dims: [-1]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
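The `${...}` placeholders in this config (`triton_max_batch_size`, `max_queue_delay_microseconds`, `model_dir`) must be filled in before Triton loads the model. The repository provides `scripts/fill_template.py` for this, but its interface is not shown here. The sketch below is an independent illustration using Python's `string.Template`, with hypothetical values, and is not that script.

```python
from string import Template

CONFIG_SNIPPET = """name: "audio_tokenizer"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
  {
   key: "model_dir",
   value: {string_value:"${model_dir}"}
  }
]
"""


def fill(template_text, **values):
    """Substitute ${name} placeholders; raises KeyError if any placeholder is left unset."""
    return Template(template_text).substitute(**values)
```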

================================================
FILE: runtime/triton_trtllm/model_repo/spark_tts/1/model.py
================================================
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import math
import os
import re
from typing import Dict, List, Tuple, Optional, Union

import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer

from sparktts.utils.token_parser import TASK_TOKEN_MAP

def process_prompt(
    text: str,
    prompt_text: Optional[str] = None,
    global_token_ids: Optional[torch.Tensor] = None,
    semantic_token_ids: Optional[torch.Tensor] = None,
) -> Tuple[str, torch.Tensor]:
    """
    Build the LLM prompt string for voice cloning.

    Args:
        text: The text to be converted to speech.
        prompt_text: Transcript of the reference (prompt) audio. When given,
            the reference semantic tokens are appended to the prompt.
        global_token_ids: Global token IDs extracted from the reference audio
            (required).
        semantic_token_ids: Semantic token IDs extracted from the reference
            audio (required when prompt_text is given).

    Returns:
        Tuple of the formatted prompt string and the unchanged global token IDs.
    """
    # Convert global tokens to string format
    global_tokens = "".join(
        [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
    )

    # Prepare the input tokens for the model
    if prompt_text is not None:
        # Include semantic tokens when prompt text is provided
        semantic_tokens = "".join(
            [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
        )

        inputs = [
            TASK_TOKEN_MAP["tts"],
            "<|start_content|>",
            prompt_text,
            text,
            "<|end_content|>",
            "<|start_global_token|>",
            global_tokens,
            "<|end_global_token|>",
            "<|start_semantic_token|>",
            semantic_tokens,
        ]
    else:
        # Without prompt text, exclude semantic tokens
        inputs = [
            TASK_TOKEN_MAP["tts"],
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_global_token|>",
            global_tokens,
            "<|end_global_token|>",
        ]

    # Join all input components into a single string
    inputs = "".join(inputs)
    return inputs, global_token_ids


class TritonPythonModel:
    """Triton Python model for Spark TTS.
    
    This model orchestrates the end-to-end TTS pipeline by coordinating
    between audio tokenizer, LLM, and vocoder components.
    """
    
    def initialize(self, args):
        """Initialize the model.
        
        Args:
            args: Dictionary containing model configuration
        """
        self.logger = pb_utils.Logger
        # Parse model parameters
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.logger.log_info(f"model_params:{model_params}")
        # streaming TTS parameters
        assert (
            float(model_params["audio_chunk_duration"]) >= 0.5
        ), "audio_chunk_duration must be at least 0.5 seconds"
        self.audio_chunk_duration = float(model_params["audio_chunk_duration"])
        self.max_audio_chunk_duration = float(model_params["max_audio_chunk_duration"])
        assert (
            float(model_params["audio_chunk_size_scale_factor"]) >= 1.0
        ), "audio_chunk_size_scale_factor must be at least 1.0; tune it according to your actual RTF"
        self.audio_chunk_size_scale_factor = float(model_params["audio_chunk_size_scale_factor"])  # scale speed
        self.audio_chunk_overlap_duration = float(model_params["audio_chunk_overlap_duration"])
        self.audio_tokenizer_frame_rate = int(model_params["audio_tokenizer_frame_rate"])

        # Initialize tokenizer
        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)

    def forward_llm(self, input_ids):
        """
        Run the `tensorrt_llm` model on the given prompt and yield the generated
        token IDs.

        Builds a `pb_utils.InferenceRequest` from `input_ids` plus fixed
        generation settings (at most 512 new tokens, top-p 0.95, top-k 50,
        temperature 0.8, and the tokenizer's EOS/pad IDs), then executes it.
        For each response from the language model:
            - Raises `pb_utils.TritonModelException` if the response reports an error.
            - Reads the "output_ids" and "sequence_length" output tensors.
            - Yields the "output_ids" of the first batch entry and first beam,
              truncated to the reported sequence length.

        In decoupled mode, streaming is enabled and one array is yielded per
        streamed response; otherwise a single array is yielded.

        Parameters
        ----------
        - input_ids (torch.Tensor): Prompt token IDs with shape [1, sequence_length].

        Yields
        ------
        - np.ndarray: Generated token IDs from one LLM response.
        """
        # convert input_ids to numpy, with shape [1, sequence_length]
        input_ids = input_ids.cpu().numpy()
        max_tokens = 512
        input_dict = {
            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
            "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
            "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
            "streaming": np.array([[self.decoupled]], dtype=np.bool_),
            "runtime_top_p": np.array([[0.95]], dtype=np.float32),
            "runtime_top_k": np.array([[50]], dtype=np.int32),
            "temperature": np.array([[0.8]], dtype=np.float32),
            "input_ids": input_ids,
            "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
        }
        
        # Convert inputs to Triton tensors
        input_tensor_list = [
            pb_utils.Tensor(k, v) for k, v in input_dict.items()
        ]
        
        # Create and execute inference request
        llm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            requested_output_names=["output_ids", "sequence_length"],
            inputs=input_tensor_list,
        )
        
        llm_responses = llm_request.exec(decoupled=self.decoupled)
        if self.decoupled:
            for llm_response in llm_responses:
                if llm_response.has_error():
                    raise pb_utils.TritonModelException(llm_response.error().message())
                
                # Extract and process output
                output_ids = pb_utils.get_output_tensor_by_name(
                    llm_response, "output_ids").as_numpy()
                seq_lens = pb_utils.get_output_tensor_by_name(
                    llm_response, "sequence_length").as_numpy()
                
                # Get actual output IDs up to the sequence length
                actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
                
                yield actual_output_ids
        else:
            llm_response = llm_responses
            if llm_response.has_error():
                raise pb_utils.TritonModelException(llm_response.error().message())
            
            # Extract and process output
            output_ids = pb_utils.get_output_tensor_by_name(
                llm_response, "output_ids").as_numpy()
            seq_lens = pb_utils.get_output_tensor_by_name(
                llm_response, "sequence_length").as_numpy()
            
            # Get actual output IDs up to the sequence length
            actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
            
            yield actual_output_ids
                
    def forward_audio_tokenizer(self, wav, wav_len):
        """Forward pass through the audio tokenizer component.
        
        Args:
            wav: Input waveform tensor
            wav_len: Waveform length tensor
            
        Returns:
            Tuple of global and semantic tokens
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['global_tokens', 'semantic_tokens'],
            inputs=[wav, wav_len]
        )
        
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())
        
        # Extract and convert output tensors
        global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens')
        global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu()
        
        semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens')
        semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu()
        
        return global_tokens, semantic_tokens

    def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor:
        """Forward pass through the vocoder component.
        
        Args:
            global_token_ids: Global token IDs tensor
            pred_semantic_ids: Predicted semantic token IDs tensor
            
        Returns:
            Generated waveform tensor
        """
        # Convert tensors to Triton format
        global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids))
        pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids))
        
        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='vocoder',
            requested_output_names=['waveform'],
            inputs=[global_token_ids_tensor, pred_semantic_ids_tensor]
        )
        
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())
        
        # Extract and convert output waveform
        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
        
        return waveform
    
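    def token2wav(self, generated_token_ids, global_token_ids):
        """Convert generated LLM token IDs into a waveform.

        The token IDs are decoded to text, the integer IDs embedded in
        `bicodec_semantic_<id>` markers are extracted, and the resulting
        semantic IDs are passed to the vocoder together with `global_token_ids`.
        """
        # Decode and extract semantic token IDs from generated text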
        predicted_text = self.tokenizer.batch_decode(
            [generated_token_ids],
            skip_special_tokens=True,
        )[0]
        pred_semantic_ids = (
            torch.tensor(
                [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)]
            )
            .unsqueeze(0)
            .to(torch.int32)
        )

        # Generate audio with vocoder
        audio = self.forward_vocoder(
            global_token_ids.to(self.device),
            pred_semantic_ids.to(self.device),
        )

        return audio

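    def execute(self, requests):
        """Execute inference on a batch of requests.

        Each request is processed independently: the reference audio is
        tokenized, a prompt is built, the LLM generates semantic tokens, and
        the vocoder turns them into audio.

        Args:
            requests: List of inference requests

        Returns:
            In non-decoupled mode, a list of inference responses (one per
            request) containing the generated waveform. In decoupled mode,
            None; audio chunks are streamed through each request's response
            sender instead.
        """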
        responses = []
        
        for request in requests:
            # Extract input tensors
            wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
            
            # Process reference audio through audio tokenizer
            global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len)
            
            # Extract text inputs
            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode('utf-8')
            
            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
            target_text = target_text[0][0].decode('utf-8')
            
            # Prepare prompt for LLM
            prompt, global_token_ids = process_prompt(
                text=target_text,
                prompt_text=reference_text,
                global_token_ids=global_tokens,
                semantic_token_ids=semantic_tokens,
            )
            
            
            # Tokenize prompt for LLM
            model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            input_ids = model_inputs.input_ids.to(torch.int32)
            
            # Generate semantic tokens with LLM
            generated_ids_iter = self.forward_llm(input_ids)

            if self.decoupled:
                response_sender = request.get_response_sender()
                request_id = request.request_id()
                semantic_token_ids_arr = []
                max_chunk_size = math.ceil(self.max_audio_chunk_duration * self.audio_tokenizer_frame_rate)
                chunk_size = math.ceil(self.audio_chunk_duration * self.audio_tokenizer_frame_rate)
                overlap_chunk_size = math.ceil(self.audio_chunk_overlap_duration * self.audio_tokenizer_frame_rate)
                self.logger.log_info(
                    f"[{request_id}] init chunk_size: {chunk_size} max_chunk_size: {max_chunk_size}"
                )
                for generated_ids in generated_ids_iter:
                    if generated_ids is None or len(generated_ids) == 0:
                        break

                    semantic_token_ids_arr.append(generated_ids)
                    if len(semantic_token_ids_arr) >= chunk_size:
                        chunk = semantic_token_ids_arr[:chunk_size]
                        generated_semantic_token_ids = np.hstack(chunk)
                        # Vocode the current chunk (including any overlap carried over)
                        sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
                        # Prepare response to send
                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                        response_sender.send(inference_response)

                        semantic_token_ids_arr = semantic_token_ids_arr[chunk_size - overlap_chunk_size:]
                        # increase chunk size for better speech quality
                        chunk_size = min(max_chunk_size, int(chunk_size * self.audio_chunk_size_scale_factor))
                        self.logger.log_info(f"[{request_id}] increase chunk_size: {chunk_size}")

                if len(semantic_token_ids_arr) > 0:  # flush the remaining tokens
                    generated_semantic_token_ids = np.hstack(semantic_token_ids_arr)
                    # Vocode the final (possibly shorter) chunk
                    sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
                    # Prepare response to send
                    audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                    inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                    response_sender.send(inference_response)
                    self.logger.log_info(f"[{request_id}] last chunk len: {len(semantic_token_ids_arr)}")
            else:
                generated_ids = next(generated_ids_iter)
                if generated_ids is None or len(generated_ids) == 0:
                    raise pb_utils.TritonModelException("Generated IDs is None or empty")

                audio = self.token2wav(generated_ids, global_token_ids)
                
                # Prepare response
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                responses.append(inference_response)
            
            if self.decoupled:
                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                self.logger.log_info(f"[{request_id}] sent TRITONSERVER_RESPONSE_COMPLETE_FINAL")
        
        if not self.decoupled:
            return responses
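
The prompt layout assembled by `process_prompt` and the regex step in `token2wav` can be checked without Triton, torch, or a tokenizer. The following is a minimal pure-Python sketch, not part of the repository. `build_prompt` and `parse_semantic_ids` are hypothetical helpers, plain `int` lists stand in for the token tensors, and `TTS_TASK_TOKEN` is an assumed stand-in for `TASK_TOKEN_MAP["tts"]`, whose real value is defined in `sparktts/utils/token_parser.py`.

```python
import re
from typing import List, Optional

# Assumed stand-in for TASK_TOKEN_MAP["tts"]; the real value is defined in
# sparktts/utils/token_parser.py and may differ.
TTS_TASK_TOKEN = "<|task_tts|>"


def build_prompt(
    text: str,
    global_ids: List[int],
    prompt_text: Optional[str] = None,
    semantic_ids: Optional[List[int]] = None,
) -> str:
    """Hypothetical mirror of process_prompt's string layout, using int lists."""
    global_tokens = "".join(f"<|bicodec_global_{i}|>" for i in global_ids)
    parts = [TTS_TASK_TOKEN, "<|start_content|>"]
    if prompt_text is not None:
        # The reference transcript is placed directly before the target text.
        parts.append(prompt_text)
    parts += [
        text,
        "<|end_content|>",
        "<|start_global_token|>",
        global_tokens,
        "<|end_global_token|>",
    ]
    if prompt_text is not None:
        # Reference semantic tokens are left open so the LLM continues them.
        semantic_tokens = "".join(
            f"<|bicodec_semantic_{i}|>" for i in (semantic_ids or [])
        )
        parts += ["<|start_semantic_token|>", semantic_tokens]
    return "".join(parts)


def parse_semantic_ids(decoded_text: str) -> List[int]:
    """Extract semantic IDs with the same regex that token2wav uses."""
    return [int(m) for m in re.findall(r"bicodec_semantic_(\d+)", decoded_text)]


prompt = build_prompt("Hello there.", [3, 7], prompt_text="Reference words.", semantic_ids=[42])
llm_output = "<|bicodec_semantic_5|><|bicodec_semantic_9|>"  # made-up LLM output
print(parse_semantic_ids(llm_output))  # prints [5, 9]
```

Because `execute` always passes the decoded `reference_text` as `prompt_text`, the server always takes the branch that appends the reference semantic tokens and leaves `<|start_semantic_token|>` open for the LLM to continue.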

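`forward_llm` trims each response with `output_ids[0][0][:seq_lens[0][0]]`, and the streaming branch of `execute` joins the buffered pieces with `np.hstack`. Below is a small NumPy sketch of those two steps. The `[batch, beam, max_len]` layout of `output_ids` and the `[batch, beam]` layout of `sequence_length` are assumptions inferred from that indexing. `trim_first_beam` is a hypothetical helper, and the array values are made up for illustration.

```python
import numpy as np


def trim_first_beam(output_ids: np.ndarray, sequence_length: np.ndarray) -> np.ndarray:
    """Return the valid tokens of beam 0 for batch entry 0, as forward_llm does."""
    # output_ids: [batch, beam, max_len]; sequence_length: [batch, beam] (assumed layout)
    return output_ids[0][0][: sequence_length[0][0]]


# Made-up padded response: 3 valid tokens followed by two pad IDs.
ids = np.array([[[11, 12, 13, 0, 0]]], dtype=np.int32)
lens = np.array([[3]], dtype=np.int32)
valid = trim_first_beam(ids, lens)

# In streaming mode each response is appended to a list, then concatenated.
buffered = [valid, np.array([14], dtype=np.int32), np.array([15], dtype=np.int32)]
joined = np.hstack(buffered)
print(valid.tolist(), joined.tolist())  # prints [11, 12, 13] [11, 12, 13, 14, 15]
```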


================================================
FILE: runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt
================================================
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "spark_tts"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
model_transaction_policy {
  decoupled: ${decoupled_mode}
}
parameters [
  {
   key: "llm_tokenizer_dir", 
   value: {string_value:"${llm_tokenizer_dir}"}
  },
  {
   key: "audio_chunk_duration", 
   value: {string_value:"${audio_chunk_duration}"}
  },
  {
   key: "audio_chunk_size_scale_factor", 
   value: {string_value:"${audio_chunk_size_scale_factor}"}
  },
  {
   key: "max_audio_chunk_duration", 
   value: {string_value:"${max_audio_chunk_duration}"}
  },
  {
   key: "audio_chunk_overlap_duration", 
   value: {string_value:"${audio_chunk_overlap_duration}"}
  },
  {
   key: "audio_tokenizer_frame_rate", 
   value: {string_value:"50"}
  }
]

input [
  {
    name: "reference_wav"
    data_type: TYPE_FP32
    dims: [-1]
  },
  {
    name: "reference_wav_len"
    data_type: TYPE_INT32
    dims: [1]
  },
  {
    name: "reference_text"
    data_type: TYPE_STRING
    dims: [1]
  },
  {
    name: "target_text"
    data_type: TYPE_STRING
    dims: [1]
  }
]
output [
  {
    name: "waveform"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]

instance_group [
  {
    count: ${bls_instance_num}
    kind: KIND_CPU
  }
]
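
The duration parameters above are converted to token counts by multiplying by `audio_tokenizer_frame_rate` (50 tokens per second here) and rounding up. The decoupled branch of `execute` in `spark_tts/1/model.py` then grows the chunk size by `audio_chunk_size_scale_factor` after every chunk, capped at `max_audio_chunk_duration`. The sketch below replays that schedule without Triton. It is an illustration, not repository code. `chunk_schedule` is a hypothetical helper, and it assumes that every streamed LLM response carries exactly one semantic token and that the overlap is smaller than the initial chunk. The parameter values in the example are illustrative only; the real ones are substituted for the `${...}` placeholders when the template is filled.

```python
import math
from typing import List


def chunk_schedule(
    num_tokens: int,
    audio_chunk_duration: float,
    max_audio_chunk_duration: float,
    audio_chunk_size_scale_factor: float,
    audio_chunk_overlap_duration: float,
    audio_tokenizer_frame_rate: int = 50,
) -> List[int]:
    """Return how many semantic tokens are vocoded in each streamed chunk.

    Hypothetical replay of the decoupled branch of TritonPythonModel.execute,
    assuming one semantic token per LLM response and overlap < initial chunk.
    """
    chunk_size = math.ceil(audio_chunk_duration * audio_tokenizer_frame_rate)
    max_chunk_size = math.ceil(max_audio_chunk_duration * audio_tokenizer_frame_rate)
    overlap = math.ceil(audio_chunk_overlap_duration * audio_tokenizer_frame_rate)

    pending = 0  # tokens buffered but not yet vocoded
    sent: List[int] = []
    for _ in range(num_tokens):
        pending += 1
        if pending >= chunk_size:
            sent.append(chunk_size)
            # Keep the last `overlap` tokens as the start of the next chunk.
            pending -= chunk_size - overlap
            chunk_size = min(max_chunk_size, int(chunk_size * audio_chunk_size_scale_factor))
    if pending > 0:
        sent.append(pending)  # final flush of whatever is left
    return sent


# Illustrative values: 1 s first chunk, doubling each time, 0.1 s overlap.
print(chunk_schedule(200, 1.0, 30.0, 2.0, 0.1))  # prints [50, 100, 60]
```

In this example the chunk lengths sum to 210 rather than 200. Each chunk after the first restarts with the 5 overlap tokens carried over from its predecessor, and `execute` vocodes and sends those tokens again rather than trimming them.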

================================================
FILE: runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep
================================================


================================================
FILE: runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt
================================================
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "tensorrt_llm"
backend: "${triton_backend}"
max_batch_size: ${triton_max_batch_size}

model_transaction_policy {
  decoupled: ${decoupled_mode}
}

dynamic_batching {
    preferred_batch_size: [ ${triton_max_batch_size} ]
    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
    default_queue_policy: { max_queue_size: ${max_queue_size} }
}

input [
  {
    name: "input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    allow_ragged_batch: true
    optional: true
  },
  {
    name: "encoder_input_features"
    data_type: ${encoder_input_features_data_type}
    dims: [ -1, -1 ]
    allow_ragged_batch: true
    optional: true
  },
  {
    name: "encoder_output_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "request_output_len"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "num_return_sequences"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "draft_input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "decoder_input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "decoder_input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
    reshape: { shape: [ ] }
  },
  {
    name: "draft_logits"
    data_type: ${logits_datatype}
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "draft_acceptance_threshold"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "end_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "pad_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "bad_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "embedding_bias"
    data_type: TYPE_FP32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "beam_width"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_k"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_min"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_decay"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_reset_ids"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "len_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "early_stopping"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "repetition_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "min_length"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "beam_search_diversity_rate"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "presence_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "frequency_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "random_seed"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_log_probs"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_context_logits"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_generation_logits"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_perf_metrics"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "exclude_input_in_output"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "streaming"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "prompt_embedding_table"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "prompt_table_extra_ids"
    data_type: TYPE_UINT64
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "prompt_vocab_size"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  # cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
  {
    name: "cross_attention_mask"
    data_type: TYPE_BOOL
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  # Mrope param when mrope is used
  {
    name: "mrope_rotary_cos_sin"
    data_type: TYPE_FP32
    dims: [ -1 ]
    optional: true
  },
  {
    name: "mrope_position_deltas"
    data_type: TYPE_INT64
    dims: [ 1 ]
    optional: true
  },
  # the unique task ID for the given LoRA.
  # To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights`, and `lora_config` must all be given.
  # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
  # If the cache is full the oldest LoRA will be evicted to make space for new ones.  An error is returned if `lora_task_id` is not cached.
  {
    name: "lora_task_id"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
  # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
  # each of the in / out tensors are first flattened and then concatenated together in the format above.
  # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
  {
    name: "lora_weights"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  # module identifier (same size as the first dimension of lora_weights)
  # See LoraModule::ModuleType for model id mapping
  #
  # "attn_qkv": 0     # combined qkv adapter
  # "attn_q": 1       # q adapter
  # "attn_k": 2       # k adapter
  # "attn_v": 3       # v adapter
  # "attn_dense": 4   # adapter for the dense layer in attention
  # "mlp_h_to_4h": 5  # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
  # "mlp_4h_to_h": 6  # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
  # "mlp_gate": 7     # for llama2 adapter for gated mlp layer after attention / RMSNorm: gate
  #
  # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
  {
    name: "lora_config"
    data_type: TYPE_INT32
    dims: [ -1, 3 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "context_phase_params"
    data_type: TYPE_UINT8
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  # skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
  {
    name: "skip_cross_attn_blocks"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "retention_token_range_starts"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "retention_token_range_ends"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "retention_token_range_priorities"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "retention_token_range_durations_ms"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "retention_decode_priority"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "retention_decode_duration_ms"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "guided_decoding_guide_type"
    data_type: TYPE_STRING
    dims: [ 1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "guided_decoding_guide"
    data_type: TYPE_STRING
    dims: [ 1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "lookahead_window_size"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "lookahead_ngram_size"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "lookahead_verification_set_size"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
    allow_ragged_batch: true
  }
]
output [
  {
    name: "output_ids"
    data_type: TYPE_INT32
    dims: [ -1, -1 ]
  },
  {
    name: "sequence_length"
    data_type: TYPE_INT32
    dims: [ -1 ]
  },
  {
    name: "cum_log_probs"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "output_log_probs"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  },
  {
    name: "context_logits"
    data_type: ${logits_datatype}
    dims: [ -1, -1 ]
  },
  {
    name: "generation_logits"
    data_type: ${logits_datatype}
    dims: [ -1, -1, -1 ]
  },
  {
    name: "batch_index"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "sequence_index"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "context_phase_params"
    data_type: TYPE_UINT8
    dims: [ -1 ]
  },
  {
    name: "kv_cache_alloc_new_blocks"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "kv_cache_reused_blocks"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "kv_cache_alloc_total_blocks"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "arrival_time_ns"
    data_type: TYPE_INT64
    dims: [ 1 ]
  },
  {
    name: "first_scheduled_time_ns"
    data_type: TYPE_INT64
    dims: [ 1 ]
  },
  {
    name: "first_token_time_ns"
    data_type: TYPE_INT64
    dims: [ 1 ]
  },
  {
    name: "last_token_time_ns"
    data_type: TYPE_INT64
    dims: [ 1 ]
  },
  {
    name: "acceptance_rate"
    data_type: TYPE_FP32
    dims: [ 1 ]
  },
  {
    name: "total_accepted_draft_tokens"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "total_draft_tokens"
    data_type: TYPE_INT32
    dims: [ 1 ]
  }
]
instance_group [
  {
    count: 1
    kind : KIND_CPU
  }
]
parameters: {
  key: "max_beam_width"
  value: {
    string_value: "${max_beam_width}"
  }
}
parameters: {
  key: "FORCE_CPU_ONLY_INPUT_TENSORS"
  value: {
    string_value: "no"
  }
}
parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "${batching_strategy}"
  }
}
parameters: {
  key: "gpt_model_path"
  value: {
    string_value: "${engine_dir}"
  }
}
parameters: {
  key: "encoder_model_path"
  value: {
    string_value: "${encoder_engine_dir}"
  }
}
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "${max_tokens_in_paged_kv_cache}"
  }
}
parameters: {
  key: "max_attention_window_size"
  value: {
    string_value: "${max_attention_window_size}"
  }
}
parameters: {
  key: "sink_token_length"
  value: {
    string_value: "${sink_token_length}"
  }
}
parameters: {
  key: "batch_scheduler_policy"
  value: {
    string_value: "${batch_scheduler_policy}"
  }
}
parameters: {
  key: "kv_cache_free_gpu_mem_fraction"
  value: {
    string_value: "${kv_cache_free_gpu_mem_fraction}"
  }
}
parameters: {
  key: "cross_kv_cache_fraction"
  value: {
    string_value: "${cross_kv_cache_fraction}"
  }
}
parameters: {
  key: "kv_cache_host_memory_bytes"
  value: {
    string_value: "${kv_cache_host_memory_bytes}"
  }
}
# kv_cache_onboard_blocks is for internal implementation.
parameters: {
  key: "kv_cache_onboard_blocks"
  value: {
    string_value: "${kv_cache_onboard_blocks}"
  }
}
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
# parameters: {
#   key: "enable_trt_overlap"
#   value: {
#     string_value: "${enable_trt_overlap}"
#   }
# }
parameters: {
  key: "exclude_input_in_output"
  value: {
    string_value: "${exclude_input_in_output}"
  }
}
parameters: {
  key: "cancellation_check_period_ms"
  value: {
    string_value: "${cancellation_check_period_ms}"
  }
}
parameters: {
  key: "stats_check_period_ms"
  value: {
    string_value: "${stats_check_period_ms}"
  }
}
parameters: {
  key: "iter_stats_max_iterations"
  value: {
    string_value: "${iter_stats_max_iterations}"
  }
}
parameters: {
  key: "request_stats_max_iterations"
  value: {
    string_value: "${request_stats_max_iterations}"
  }
}
parameters: {
  key: "enable_kv_cache_reuse"
  value: {
    string_value: "${enable_kv_cache_reuse}"
  }
}
parameters: {
  key: "normalize_log_probs"
  value: {
    string_value: "${normalize_log_probs}"
  }
}
parameters: {
  key: "enable_chunked_context"
  value: {
    string_value: "${enable_chunked_context}"
  }
}
parameters: {
  key: "gpu_device_ids"
  value: {
    string_value: "${gpu_device_ids}"
  }
}
parameters: {
  key: "participant_ids"
  value: {
    string_value: "${participant_ids}"
  }
}
parameters: {
  key: "lora_cache_optimal_adapter_size"
  value: {
    string_value: "${lora_cache_optimal_adapter_size}"
  }
}
parameters: {
  key: "lora_cache_max_adapter_size"
  value: {
    string_value: "${lora_cache_max_adapter_size}"
  }
}
parameters: {
  key: "lora_cache_gpu_memory_fraction"
  value: {
    string_value: "${lora_cache_gpu_memory_fraction}"
  }
}
parameters: {
  key: "lora_cache_host_memory_bytes"
  value: {
    string_value: "${lora_cache_host_memory_bytes}"
  }
}
parameters: {
  key: "lora_prefetch_dir"
  value: {
    string_value: "${lora_prefetch_dir}"
  }
}
parameters: {
  key: "decoding_mode"
  value: {
    string_value: "${decoding_mode}"
  }
}
parameters: {
  key: "executor_worker_path"
  value: {
    string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
  }
}
parameters: {
  key: "lookahead_window_size"
  value: {
    string_value: "${lookahead_window_size}"
  }
}
parameters: {
  key: "lookahead_ngram_size"
  value: {
    string_value: "${lookahead_ngram_size}"
  }
}
parameters: {
  key: "lookahead_verification_set_size"
  value: {
    string_value: "${lookahead_verification_set_size}"
  }
}
parameters: {
  key: "medusa_choices"
  value: {
    string_value: "${medusa_choices}"
  }
}
parameters: {
  key: "eagle_choices"
  value: {
    string_value: "${eagle_choices}"
  }
}
parameters: {
  key: "gpu_weights_percent"
  value: {
    string_value: "${gpu_weights_percent}"
  }
}
parameters: {
  key: "enable_context_fmha_fp32_acc"
  value: {
    string_value: "${enable_context_fmha_fp32_acc}"
  }
}
parameters: {
  key: "multi_block_mode"
  value: {
    string_value: "${multi_block_mode}"
  }
}
parameters: {
  key: "cuda_graph_mode"
  value: {
    string_value: "${cuda_graph_mode}"
  }
}
parameters: {
  key: "cuda_graph_cache_size"
  value: {
    string_value: "${cuda_graph_cache_size}"
  }
}
parameters: {
  key: "speculative_decoding_fast_logits"
  value: {
    string_value: "${speculative_decoding_fast_logits}"
  }
}
parameters: {
  key: "tokenizer_dir"
  value: {
    string_value: "${tokenizer_dir}"
  }
}
parameters: {
  key: "guided_decoding_backend"
  value: {
    string_value: "${guided_decoding_backend}"
  }
}
parameters: {
  key: "xgrammar_tokenizer_info_path"
  value: {
    string_value: "${xgrammar_tokenizer_info_path}"
  }
}


================================================
FILE: runtime/triton_trtllm/model_repo/vocoder/1/model.py
================================================
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os
import logging
from typing import List, Dict

import torch
from torch.utils.dlpack import to_dlpack

import triton_python_backend_utils as pb_utils

from sparktts.models.bicodec import BiCodec

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class TritonPythonModel:
    """Triton Python model for vocoder.
    
    This model takes global and semantic tokens as input and generates audio waveforms
    using the BiCodec vocoder.
    """

    def initialize(self, args):
        """Initialize the model.
        
        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]
        
        # Initialize device and vocoder
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
        
        self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
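        # Only detokenization is needed here, so drop the encoder and postnet to save memory
        del self.vocoder.encoder, self.vocoder.postnet
        self.vocoder.eval().to(self.device)  # Switch to eval mode and move to the target device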

        logger.info("Vocoder initialized successfully")


    def execute(self, requests):
        """Execute inference on the batched requests.
        
        Args:
            requests: List of inference requests
            
        Returns:
            List of inference responses containing generated waveforms
        """
        global_tokens_list, semantic_tokens_list = [], []

        # Process each request in batch
        for request in requests:
            global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy()
            semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy()
            global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device))
            semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device))

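        # Concatenate tokens for batch processing.
        # NOTE: torch.cat along dim 0 requires every request in the batch to have
        # the same token length; requests with different semantic-token lengths
        # would need padding or per-request decoding.
        global_tokens = torch.cat(global_tokens_list, dim=0)
        semantic_tokens = torch.cat(semantic_tokens_list, dim=0)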

        # Generate waveforms
        with torch.no_grad():
            wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))

        # Prepare responses
        responses = []
        for i in range(len(requests)):
            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
            responses.append(inference_response)

        return responses


================================================
FILE: runtime/triton_trtllm/model_repo/vocoder/config.pbtxt
================================================
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "vocoder"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
  {
   key: "model_dir", 
   value: {string_value:"${model_dir}"}
  }
]

input [
  {
    name: "global_tokens"
    data_type: TYPE_INT32
    dims: [-1]
  },
  {
    name: "semantic_tokens"
    data_type: TYPE_INT32
    dims: [-1]
  }
]
output [
  {
    name: "waveform"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]

================================================
FILE: runtime/triton_trtllm/run.sh
================================================
export PYTHONPATH=../../../Spark-TTS/
export CUDA_VISIBLE_DEVICES=0
stage=$1
stop_stage=$2
service_type=$3
echo "Start stage: $stage, Stop stage: $stop_stage, Service type: $service_type"

huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B
trt_dtype=bfloat16
trt_weights_dir=./tllm_checkpoint_${trt_dtype}
trt_engines_dir=./trt_engines_${trt_dtype}

model_repo=./model_repo_test

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
    echo "Downloading Spark-TTS-0.5B from HuggingFace"
    huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1
fi


if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
    echo "Converting checkpoint to TensorRT weights"
    python scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \
                                --output_dir $trt_weights_dir \
                                --dtype $trt_dtype || exit 1

    echo "Building TensorRT engines"
    trtllm-build --checkpoint_dir $trt_weights_dir \
                --output_dir $trt_engines_dir \
                --max_batch_size 16 \
                --max_num_tokens 32768 \
                --gemm_plugin $trt_dtype || exit 1
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
    echo "Creating model repository"
    rm -rf $model_repo
    mkdir -p $model_repo
    spark_tts_dir="spark_tts"

    cp -r ./model_repo/${spark_tts_dir} $model_repo
    cp -r ./model_repo/audio_tokenizer $model_repo
    cp -r ./model_repo/tensorrt_llm $model_repo
    cp -r ./model_repo/vocoder $model_repo

    ENGINE_PATH=$trt_engines_dir
    MAX_QUEUE_DELAY_MICROSECONDS=0
    MODEL_DIR=$huggingface_model_local_dir
    LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM
    BLS_INSTANCE_NUM=4
    TRITON_MAX_BATCH_SIZE=16
    # streaming TTS parameters
    AUDIO_CHUNK_DURATION=1.0
    MAX_AUDIO_CHUNK_DURATION=30.0
    AUDIO_CHUNK_SIZE_SCALE_FACTOR=8.0
    AUDIO_CHUNK_OVERLAP_DURATION=0.1
    python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
    python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
    if [ "$service_type" == "streaming" ]; then
        DECOUPLED_MODE=True
    else
        DECOUPLED_MODE=False
    fi
    python3 scripts/fill_template.py -i ${model_repo}/${spark_tts_dir}/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},audio_chunk_duration:${AUDIO_CHUNK_DURATION},max_audio_chunk_duration:${MAX_AUDIO_CHUNK_DURATION},audio_chunk_size_scale_factor:${AUDIO_CHUNK_SIZE_SCALE_FACTOR},audio_chunk_overlap_duration:${AUDIO_CHUNK_OVERLAP_DURATION}
    python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32

fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
    echo "Starting Triton server"
    tritonserver --model-repository ${model_repo}
fi


if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
    echo "Running benchmark client"
    num_task=2
    if [ "$service_type" == "streaming" ]; then
        mode="streaming"
    else
        mode="offline"
    fi
    python3 client_grpc.py \
        --server-addr localhost \
        --model-name spark_tts \
        --num-tasks $num_task \
        --mode $mode \
        --log-dir ./log_concurrent_tasks_${num_task}_${mode}_new
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
    echo "Running single utterance client"
    if [ "$service_type" == "streaming" ]; then
        python client_grpc.py \
            --server-addr localhost \
            --reference-audio ../../example/prompt_audio.wav \
            --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
            --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
            --model-name spark_tts \
            --chunk-overlap-duration 0.1 \
            --mode streaming
    else
        python client_http.py \
            --reference-audio ../../example/prompt_audio.wav \
            --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
            --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
            --model-name spark_tts
    fi
fi

================================================
FILE: runtime/triton_trtllm/scripts/convert_checkpoint.py
================================================
import argparse
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

from transformers import AutoConfig

import tensorrt_llm
from tensorrt_llm._utils import release_gc
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import QWenForCausalLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None, required=True)
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'float16', 'bfloat16', 'float32'],
        help=
        "The data type for the model weights and activations if not quantized. "
        "If 'auto', the data type is automatically inferred from the source model; "
        "however, if the source dtype is float32, it is converted to float16.")
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--disable_weight_only_quant_plugin',
        default=False,
        action="store_true",
        help=
        'By default, the plugin implementation is used for weight-only quantization. Setting disable_weight_only_quant_plugin uses the out-of-the-box (OOTB) implementation instead. '
        'You must also use --use_weight_only for this argument to have an effect.'
    )
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4', 'int4_gptq'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for this argument to have an effect.'
    )
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='ccdv/cnn_dailymail',
        help=
        "The Hugging Face dataset name or the local directory of the dataset used for calibration."
    )
    parser.add_argument(
        "--smoothquant",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to Smoothquant the model, and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1]")
    parser.add_argument(
        '--per_channel',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token chooses at run time, and for each token, a custom scaling factor. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action="store_true",
        help=
        'By default, we use dtype for the KV cache. int8_kv_cache chooses int8 quantization for the KV cache.'
    )
    parser.add_argument(
        '--per_group',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor to scale weights in the int4 range. '
        'per_group chooses at run time, and for each group, a custom scaling factor. '
        'The flag is built for GPTQ/AWQ quantization.')

    parser.add_argument('--group_size',
                        type=int,
                        default=128,
                        help='Group size used in GPTQ quantization.')

    parser.add_argument("--load_model_on_cpu", action="store_true")
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along the vocab dimension (embedding_sharding_dim=0). '
        'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
    )
    parser.add_argument('--output_dir',
                        type=str,
                        default='tllm_checkpoint',
                        help='The path to save the TensorRT-LLM checkpoint')
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='The number of workers for converting checkpoint in parallel')
    parser.add_argument(
        '--moe_tp_size',
        type=int,
        default=-1,
        help=
        'N-way tensor parallelism size for MoE; defaults to tp_size, which means TP-only for MoE'
    )
    parser.add_argument(
        '--moe_ep_size',
        type=int,
        default=-1,
        help=
        'N-way expert parallelism size for MoE; defaults to 1, which means TP-only for MoE'
    )
    args = parser.parse_args()
    return args


def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
    '''Return a QuantConfig populated with quantization info from the command-line args.
    '''
    quant_config = QuantConfig()
    if args.use_weight_only:
        if args.weight_only_precision == 'int8':
            quant_config.quant_algo = QuantAlgo.W8A16
        elif args.weight_only_precision == 'int4':
            quant_config.quant_algo = QuantAlgo.W4A16
    elif args.smoothquant:
        quant_config.smoothquant_val = args.smoothquant
        if args.per_channel:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        else:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN

    if args.int8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.INT8

    if args.weight_only_precision == 'int4_gptq':
        quant_config.group_size = args.group_size
        quant_config.has_zero_point = True
        quant_config.pre_quant_scale = False
        quant_config.quant_algo = QuantAlgo.W4A16_GPTQ

    return quant_config


def update_quant_config_from_hf(quant_config, hf_config,
                                override_fields) -> tuple[QuantConfig, dict]:
    hf_config_dict = hf_config.to_dict()
    if hf_config_dict.get('quantization_config'):
        # update the quant_algo, and clamp_val.
        if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict['quantization_config'].get(
                'group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('zero_point', False)
            override_fields.update({"use_autoawq": True})
        elif hf_config_dict['quantization_config'].get(
                'quant_method') == 'gptq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            desc_act = hf_config_dict['quantization_config'].get(
                'desc_act', False)
            if desc_act:
                raise ValueError("GPTQ with desc_act=True is not implemented!")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict['quantization_config'].get(
                'group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('sym', False)
    return quant_config, override_fields


def args_to_build_options(args):
    return {
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'disable_weight_only_quant_plugin':
        args.disable_weight_only_quant_plugin
    }


def convert_and_save_hf(args):
    model_dir = args.model_dir
    world_size = args.tp_size * args.pp_size
    # Convert the CLI args to key-value pairs and override them in the generated config dict.
    # Ideally these fields would be moved out of the config and passed to the build API; they are
    # kept here for compatibility until the refactor is done.
    override_fields = {}
    override_fields.update(args_to_build_options(args))
    quant_config = args_to_quant_config(args)

    try:
        hf_config = AutoConfig.from_pretrained(model_dir,
                                               trust_remote_code=True)
        quant_config, override_fields = update_quant_config_from_hf(
            quant_config, hf_config, override_fields)
    except Exception:
        logger.warning("AutoConfig cannot load the huggingface config.")

    if args.smoothquant is not None or args.int8_kv_cache:
        mapping = Mapping(
            world_size=world_size,
            tp_size=args.tp_size,
            pp_size=args.pp_size,
            moe_tp_size=args.moe_tp_size,
            moe_ep_size=args.moe_ep_size,
        )
        QWenForCausalLM.quantize(args.model_dir,
                                 args.output_dir,
                                 dtype=args.dtype,
                                 mapping=mapping,
                                 quant_config=quant_config,
                                 calib_dataset=args.calib_dataset,
                                 **override_fields)
    else:

        def convert_and_save_rank(args, rank):
            mapping = Mapping(world_size=world_size,
                              rank=rank,
                              tp_size=args.tp_size,
                              pp_size=args.pp_size,
                              moe_tp_size=args.moe_tp_size,
                              moe_ep_size=args.moe_ep_size)
            qwen = QWenForCausalLM.from_hugging_face(model_dir,
                                                     args.dtype,
                                                     mapping=mapping,
                                                     quant_config=quant_config,
                                                     **override_fields)
            qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
            del qwen

        execute(args.workers, [convert_and_save_rank] * world_size, args)
        release_gc()


def execute(workers, func, args):
    if workers == 1:
        for rank, f in enumerate(func):
            f(args, rank)
    else:
        with ThreadPoolExecutor(max_workers=workers) as p:
            futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
            exceptions = []
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    traceback.print_exc()
                    exceptions.append(e)
            assert len(
                exceptions
            ) == 0, "Checkpoint conversion failed, please check error log."


def main():
    print(tensorrt_llm.__version__)
    args = parse_arguments()

    if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
        # moe default to tp-only
        args.moe_tp_size = args.tp_size
        args.moe_ep_size = 1
    elif (args.moe_tp_size == -1):
        args.moe_tp_size = args.tp_size // args.moe_ep_size
    elif (args.moe_ep_size == -1):
        args.moe_ep_size = args.tp_size // args.moe_tp_size
    assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
            ), "moe_tp_size * moe_ep_size must equal tp_size"

    tik = time.time()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    assert args.model_dir is not None
    convert_and_save_hf(args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


if __name__ == '__main__':
    main()


================================================
FILE: runtime/triton_trtllm/scripts/fill_template.py
================================================
#! /usr/bin/env python3
from argparse import ArgumentParser
from string import Template


def split(string, delimiter):
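    """Split a string on a delimiter, honoring backslash escapes.

    A backslash makes the following character literal, so an escaped
    delimiter does not split the string; the backslash itself is dropped.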

    Args:
        string (str): The string to split.
        delimiter (str): The delimiter to split the string with.

    Returns:
        list: A list of strings.
    """
    result = []
    current = ""
    escape = False
    for char in string:
        if escape:
            current += char
            escape = False
        elif char == delimiter:
            result.append(current)
            current = ""
        elif char == "\\":
            escape = True
        else:
            current += char
    result.append(current)
    return result


def main(file_path, substitutions, in_place):
    with open(file_path) as f:
        pbtxt = Template(f.read())

    sub_dict = {
        "max_queue_size": 0,
        'max_queue_delay_microseconds': 0,
    }
    for sub in split(substitutions, ","):
        key, value = split(sub, ":")
        sub_dict[key] = value

        assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."

    pbtxt = pbtxt.safe_substitute(sub_dict)

    if in_place:
        with open(file_path, "w") as f:
            f.write(pbtxt)
    else:
        print(pbtxt)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("file_path", help="path of the .pbtxt to modify")
    parser.add_argument(
        "substitutions",
        help=
        "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
    )
    parser.add_argument("--in_place",
                        "-i",
                        action="store_true",
                        help="do the operation in-place")
    args = parser.parse_args()
    main(**vars(args))


================================================
FILE: sparktts/models/audio_tokenizer.py
================================================
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import numpy as np

from pathlib import Path
from typing import Any, Dict, Tuple
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

from sparktts.utils.file import load_config
from sparktts.utils.audio import load_audio
from sparktts.models.bicodec import BiCodec


class BiCodecTokenizer:
    """BiCodec tokenizer for handling audio input and tokenization."""

    def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
        """
        Args:
            model_dir: Path to the model directory.
            device: Device to run the model on (defaults to CUDA if available, otherwise CPU).
        """
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model_dir = model_dir
        self.config = load_config(f"{model_dir}/config.yaml")
        self._initialize_model()

    def _initialize_model(self):
        """Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
        self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
            self.device
        )
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
            f"{self.model_dir}/wav2vec2-large-xlsr-53"
        )
        self.feature_extractor = Wav2Vec2Model.from_pretrained(
            f"{self.model_dir}/wav2vec2-large-xlsr-53"
        ).to(self.device)
        self.feature_extractor.config.output_hidden_states = True

    def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
        """Get reference audio clip for speaker embedding."""
        ref_segment_length = (
            int(self.config["sample_rate"] * self.config["ref_segment_duration"])
            // self.config["latent_hop_length"]
            * self.config["latent_hop_length"]
        )
        wav_length = len(wav)

        if ref_segment_length > wav_length:
            # Repeat and truncate to handle insufficient length
            wav = np.tile(wav, ref_segment_length // wav_length + 1)

        return wav[:ref_segment_length]

    def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
        """Load audio from wav_path and return it with the reference clip (float tensor of shape (1, seq_len))."""
        wav = load_audio(
            wav_path,
            sampling_rate=self.config["sample_rate"],
            volume_normalize=self.config["volume_normalize"],
        )

        wav_ref = self.get_ref_clip(wav)

        wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
        return wav, wav_ref

    def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
        """Extract wav2vec2 features as the average of hidden states 11, 14 and 16."""
        inputs = self.processor(
            wavs,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
            output_hidden_states=True,
        ).input_values
        feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
        feats_mix = (
            feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
        ) / 3

        return feats_mix

    def tokenize_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a batch of audio.

        Args:
            batch: dict with the following keys
                wav (List[np.ndarray]): batch of audio.
                ref_wav (torch.Tensor): reference audio. shape: (batch_size, seq_len)

        Returns:
            global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
            semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
        """
        feats = self.extract_wav2vec2_features(batch["wav"])
        batch["feat"] = feats
        semantic_tokens, global_tokens = self.model.tokenize(batch)

        return global_tokens, semantic_tokens

    def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a single audio file into (global_tokens, semantic_tokens)."""
        wav, ref_wav = self.process_audio(audio_path)
        feat = self.extract_wav2vec2_features(wav)
        batch = {
            "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
            "ref_wav": ref_wav.to(self.device),
            "feat": feat.to(self.device),
        }
        semantic_tokens, global_tokens = self.model.tokenize(batch)

        return global_tokens, semantic_tokens

    def detokenize(
        self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
    ) -> np.ndarray:
        """detokenize the tokens to waveform

        Args:
            global_tokens: global tokens. shape: (batch_size, global_dim)
            semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)

        Returns:
            wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
        """
        global_tokens = global_tokens.unsqueeze(1)
        wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
        return wav_rec.detach().squeeze().cpu().numpy()
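

# Illustrative sketch (hypothetical helper, not used by BiCodecTokenizer): a standalone
# version of the arithmetic in get_ref_clip(). The segment length is rounded down to a
# multiple of latent_hop_length, and inputs shorter than the segment are tiled before
# truncation. Any argument values passed in are assumptions, not the shipped config.
def _demo_ref_clip(wav, sample_rate: int, ref_segment_duration: float, latent_hop_length: int):
    import numpy as np

    wav = np.asarray(wav)
    ref_segment_length = (
        int(sample_rate * ref_segment_duration) // latent_hop_length * latent_hop_length
    )
    if ref_segment_length > len(wav):
        # Repeat enough times to cover the whole segment, then cut to length.
        wav = np.tile(wav, ref_segment_length // len(wav) + 1)
    return wav[:ref_segment_length]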


# test
if __name__ == "__main__":
    import soundfile as sf

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BiCodecTokenizer(
        model_dir="pretrained_models/Spark-TTS-0.5B",
        device=device,
    )
    wav_path = "example/prompt_audio.wav"

    global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)

    wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
    sf.write("example/prompt_recon.wav", wav_rec, 16000)
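

# Illustrative sketch (hypothetical helper, not used by BiCodecTokenizer): the layer
# mixing in extract_wav2vec2_features() is a plain average of selected entries of
# `hidden_states`. This version works on any sequence whose elements support `+` and
# `/`, for example torch tensors or plain floats.
def _mix_hidden_states(hidden_states, layers=(11, 14, 16)):
    return sum(hidden_states[i] for i in layers) / len(layers)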


================================================
FILE: sparktts/models/bicodec.py
================================================
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, Any
from omegaconf import DictConfig
from safetensors.torch import load_file

from sparktts.utils.file import load_config
from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
from sparktts.modules.encoder_decoder.feat_encoder import Encoder
from sparktts.modules.encoder_decoder.feat_decoder import Decoder
from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize


class BiCodec(nn.Module):
    """
    BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
    quantizer, and wave generator.
    """

    def __init__(
        self,
        mel_params: Dict[str, Any],
        encoder: nn.Module,
        decoder: nn.Module,
        quantizer: nn.Module,
        speaker_encoder: nn.Module,
        prenet: nn.Module,
        postnet: nn.Module,
        **kwargs
    ) -> None:
        """
        Initializes the BiCodec model with the required components.

        Args:
            mel_params (dict): Parameters for the mel-spectrogram transformer.
            encoder (nn.Module): Encoder module.
            decoder (nn.Module): Decoder module.
            quantizer (nn.Module): Quantizer module.
            speaker_encoder (nn.Module): Speaker encoder module.
            prenet (nn.Module): Prenet network.
            postnet (nn.Module): Postnet network.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.speaker_encoder = speaker_encoder
        self.prenet = prenet
        self.postnet = postnet
        self.init_mel_transformer(mel_params)

    @classmethod
    def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
        """
        Loads the model from a checkpoint.

        Args:
            model_dir (Path): Path to the model directory containing checkpoint and config.
        
        Returns:
            BiCodec: The initialized BiCodec model.
        """
        ckpt_path = f'{model_dir}/model.safetensors'
        config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
        mel_params = config["mel_params"]
        encoder = Encoder(**config["encoder"])
        quantizer = FactorizedVectorQuantize(**config["quantizer"])
        prenet = Decoder(**config["prenet"])
        postnet = Decoder(**config["postnet"])
        decoder = WaveGenerator(**config["decoder"])
        speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])

        model = cls(
            mel_params=mel_params,
            encoder=encoder,
            decoder=decoder,
            quantizer=quantizer,
            speaker_encoder=speaker_encoder,
            prenet=prenet,
            postnet=postnet,
        )

        state_dict = load_file(ckpt_path)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        for key in missing_keys:
            print(f"Missing tensor: {key}")
        for key in unexpected_keys:
            print(f"Unexpected tensor: {key}")

        model.eval()
        model.remove_weight_norm()

        return model

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Performs a forward pass through the model.

        Args:
            batch (dict): A dictionary containing features, reference waveform, and target waveform.
        
        Returns:
            dict: A dictionary containing the reconstruction, features, and other metrics.
        """
        feat = batch["feat"]
        mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)

        z = self.encoder(feat.transpose(1, 2))
        vq_outputs = self.quantizer(z)

        x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))

        conditions = d_vector
        with_speaker_loss = False

        x = self.prenet(vq_outputs["z_q"], conditions)
        pred_feat = self.postnet(x)
        x = x + conditions.unsqueeze(-1)
        wav_recon = self.decoder(x)

        return {
            "vq_loss": vq_outputs["vq_loss"],
            "perplexity": vq_outputs["perplexity"],
            "cluster_size": vq_outputs["active_num"],
            "recons": wav_recon,
            "pred_feat": pred_feat,
            "x_vector": x_vector,
            "d_vector": d_vector,
            "audios": batch["wav"].unsqueeze(1),
            "with_speaker_loss": with_speaker_loss,
        }

    @torch.no_grad()
    def tokenize(self, batch: Dict[str, Any]):
        """
        Tokenizes the input audio into semantic and global tokens.

        Args:
            batch (dict): The input audio features and reference waveform.

        Returns:
            tuple: Semantic tokens and global tokens.
        """
        feat = batch["feat"]
        mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)

        z = self.encoder(feat.transpose(1, 2))
        semantic_tokens = self.quantizer.tokenize(z)
        global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))

        return semantic_tokens, global_tokens

    @torch.no_grad()
    def detokenize(self, semantic_tokens, global_tokens):
        """
        Detokenizes the semantic and global tokens into a waveform.

        Args:
            semantic_tokens (tensor): Semantic tokens.
            global_tokens (tensor): Global tokens.

        Returns:
            tensor: Reconstructed waveform.
        """
        z_q = self.quantizer.detokenize(semantic_tokens)
        d_vector = self.speaker_encoder.detokenize(global_tokens)
        x = self.prenet(z_q, d_vector)
        x = x + d_vector.unsqueeze(-1)
        wav_recon = self.decoder(x)

        return wav_recon

    def init_mel_transformer(self, config: Dict[str, Any]):
        """
        Initializes the MelSpectrogram transformer based on the provided configuration.

        Args:
            config (dict): Configuration parameters for MelSpectrogram.
        """
        import torchaudio.transforms as TT

        self.mel_transformer = TT.MelSpectrogram(
            config["sample_rate"],
            config["n_fft"],
            config["win_length"],
            config["hop_length"],
            config["mel_fmin"],
            config["mel_fmax"],
            n_mels=config["num_mels"],
            power=1,
            norm="slaney",
            mel_scale="slaney",
        )

    def remove_weight_norm(self):
        """Removes weight normalization from all layers."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                pass  # The module didn't have weight norm

        self.apply(_remove_weight_norm)


# Test the model
if __name__ == "__main__":

    model = BiCodec.load_from_checkpoint(
        model_dir="pretrained_models/Spark-TTS-0.5B/BiCodec",
    )

    # Generate random inputs for testing
    duration = 0.96
    x = torch.randn(20, 1, int(duration * 16000))
    feat = torch.randn(20, int(duration * 50), 1024)
    inputs = {"feat": feat, "wav": x, "ref_wav": x}

    # Forward pass
    outputs = model(inputs)
    semantic_tokens, global_tokens = model.tokenize(inputs)
    wav_recon = model.detokenize(semantic_tokens, global_tokens)

    # Verify if the reconstruction matches
    if torch.allclose(outputs["recons"].detach(), wav_recon):
        print("Test successful")
    else:
        print("Test failed")


================================================
FILE: sparktts/modules/blocks/layers.py
================================================
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0


import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this gives a 1.4x model speed-up
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x
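

# Illustrative scalar reference (hypothetical helper, not used by the model): the
# snake activation computes x + sin(alpha * x)^2 / alpha; the 1e-9 term mirrors
# snake() above and guards against division by zero when alpha is 0.
def _snake_scalar(x: float, alpha: float) -> float:
    import math

    return x + math.sin(alpha * x) ** 2 / (alpha + 1e-9)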


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


================================================
FILE: sparktts/modules/blocks/samper.py
================================================
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn
import torch.nn.functional as F


class SamplingBlock(nn.Module):
    """Sampling block for upsampling or downsampling"""

    def __init__(
        self,
        dim: int,
        groups: int = 1,
        upsample_scale: int = 1,
        downsample_scale: int = 1,
    ) -> None:
        """
        Args:
            dim: input dimension
            groups: number of groups
            upsample_scale: upsampling scale
            downsample_scale: downsampling scale
        """
        super(SamplingBlock, self).__init__()

        self.upsample_scale = upsample_scale
        self.downsample_scale = downsample_scale

        if self.upsample_scale > 1:
            self.de_conv_upsampler = nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.ConvTranspose1d(
                    dim,
                    dim,
                    kernel_size=upsample_scale * 2,
                    stride=upsample_scale,
                    padding=upsample_scale // 2 + upsample_scale % 2,
                    output_padding=upsample_scale % 2,
                    groups=groups,
                ),
            )

        if self.downsample_scale > 1:
            self.conv_downsampler = nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.Conv1d(
                    dim,
                    dim,
                    kernel_size=2 * downsample_scale,
                    stride=downsample_scale,
                    padding=downsample_scale // 2 + downsample_scale % 2,
                    groups=groups,
                ),
            )

    @staticmethod
    def repeat_upsampler(x, upsample_scale):
        return x.repeat_interleave(upsample_scale, dim=2)

    @staticmethod
    def skip_downsampler(x, downsample_scale):
        return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)

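    def forward(self, x):
        """
        Args:
            x: input of shape (batch, length, dim); it is transposed to (batch, dim, length).

        Returns:
            Tensor of shape (batch, dim, new_length), where the length is rescaled by
            upsample_scale / downsample_scale.
        """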
        x = x.transpose(1, 2)
        if self.upsample_scale > 1:
            repeat_res = self.repeat_upsampler(x, self.upsample_scale)
            deconv_res = self.de_conv_upsampler(x)
            upmerge_res = repeat_res + deconv_res
        else:
            upmerge_res = x
            repeat_res = x

        if self.downsample_scale > 1:
            conv_res = self.conv_downsampler(upmerge_res)
            skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
            skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
        else:
            conv_res = upmerge_res
            skip2_res = upmerge_res
            skip1_res = repeat_res

        final_res = conv_res + skip1_res + skip2_res

        return final_res
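

# Illustrative helper (hypothetical, not used by the model): computes the output length
# of SamplingBlock from the standard PyTorch length formulas, which makes the kernel and
# padding choices above easy to check by hand.
#   ConvTranspose1d: (L - 1) * stride - 2 * padding + kernel_size + output_padding
#   Conv1d:          (L + 2 * padding - kernel_size) // stride + 1
#   avg_pool1d:      (L - kernel_size) // stride + 1
def _sampling_block_output_length(
    length: int, upsample_scale: int = 1, downsample_scale: int = 1
) -> int:
    if upsample_scale > 1:
        s = upsample_scale
        # kernel_size = 2s, padding = s // 2 + s % 2, output_padding = s % 2 -> length * s
        length = (length - 1) * s - 2 * (s // 2 + s % 2) + 2 * s + s % 2
    if downsample_scale > 1:
        d = downsample_scale
        conv_len = (length + 2 * (d // 2 + d % 2) - 2 * d) // d + 1
        pool_len = (length - d) // d + 1
        # forward() sums the conv branch with the avg-pooled skip branches, so the
        # branches are expected to have equal length.
        if conv_len != pool_len:
            raise ValueError(f"branch lengths differ: conv={conv_len}, pool={pool_len}")
        length = conv_len
    return length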


# test
if __name__ == "__main__":
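    # Input is (batch, length, channels); forward() transposes it to (batch, channels, length).
    test_input = torch.randn(8, 50, 1024)  # Batch size = 8, length = 50, 1024 channels
    model = SamplingBlock(1024, groups=1024, upsample_scale=2)
    model_down = SamplingBlock(1024, groups=1024, downsample_scale=2)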
    output = model(test_input)
    output_down = model_down(test_input)
    print("shape after upsample * 2", output.shape)  # torch.Size([8, 1024, 100])
    print("shape after downsample * 2", output_down.shape)  # torch.Size([8, 1024, 25])
    if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
        [8, 1024, 25]
    ):
        print("test successful")


================================================
FILE: sparktts/modules/blocks/vocos.py
================================================
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

from typing import Optional, Tuple
from torch.nn.utils import weight_norm, remove_weight_norm


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float): Initial value for the layer scale. A value <= 0 disables scaling.
        condition_dim (int, optional): Dimension of the conditioning vector for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: float,
        condition_dim: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.adanorm = condition_dim is not None
        if condition_dim:
            self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x


class AdaLayerNorm(nn.Module):
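    """
    Adaptive Layer Normalization: the scale and shift applied after normalization are
    predicted from a conditioning vector by two linear layers.

    Args:
        condition_dim (int): Dimension of the condition.
        embedding_dim (int): Dimension of the embeddings (the normalized feature size).
        eps (float, optional): Value added to the variance for numerical stability. Defaults to 1e-6.
    """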

    def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Linear(condition_dim, embedding_dim)
        self.shift = nn.Linear(condition_dim, embedding_dim)
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding)
        shift = self.shift(cond_embedding)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
        return x
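

# Illustrative sketch (hypothetical helper, not used by the model): what AdaLayerNorm
# does to a single feature vector, in plain Python. `scale` and `shift` stand in for
# the outputs of self.scale(cond) and self.shift(cond); here they are passed directly.
def _ada_layer_norm_vector(x, scale, shift, eps: float = 1e-6):
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance, as in layer_norm
    normed = [(v - mean) / (var + eps) ** 0.5 for v in x]
    return [n * s + b for n, s, b in zip(normed, scale, shift)]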


class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=self.get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=self.get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=self.get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                ),
            ]
        )

        self.gamma = nn.ParameterList(
            [
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                ),
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                ),
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)
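

# Illustrative check (hypothetical helper, not used by the model): with stride 1,
# Conv1d produces L_out = L + 2 * padding - dilation * (kernel_size - 1). The padding
# from ResBlock1.get_padding keeps L_out == L for odd kernel sizes; with an even kernel
# size and an odd dilation the int() truncation drops one sample.
def _dilated_conv_output_length(length: int, kernel_size: int, dilation: int = 1) -> int:
    padding = int((kernel_size * dilation - dilation) / 2)
    return length + 2 * padding - dilation * (kernel_size - 1)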


class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                        C is the number of input feature channels, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                    and H denotes the model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        condition_dim (int, optional): Dimension of the conditioning vector for AdaLayerNorm.
                                       None means non-conditional model. Defaults to None.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
        condition_dim: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = condition_dim is not None
        if condition_dim:
            self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    condition_dim=condition_dim,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
        x = self.embed(x)
        if self.adanorm:
            assert condition is not None
            x = self.norm(x.transpose(1, 2), condition)
        else:
            x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, condition)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x


class VocosResNetBackbone(Backbone):
    """
    Vocos backbone module built with ResBlocks.

    Args:
        input_channels (int): Number of input feature channels.
        dim (int): Hidden dimension of the model.
        num_blocks (int): Number of ResBlock1 blocks.
        layer_scale_init_value (float, optional): Initial value for layer scaling.
            Defaults to `1 / num_blocks / 3` when None.
    """

    def __init__(
        self,
        input_channels,
        dim,
        num_blocks,
        layer_scale_init_value=None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = weight_norm(
            nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
        )
        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
        self.resnet = nn.Sequential(
            *[
                ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
                for _ in range(num_blocks)
            ]
        )

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.embed(x)
        x = self.resnet(x)
        x = x.transpose(1, 2)
        return x
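When `condition_dim` is given, `VocosBackbone` swaps `nn.LayerNorm` for `AdaLayerNorm`, so the conditioning vector sets the per-channel scale and shift applied after normalization. The pure-Python sketch below shows that computation for one frame. It assumes the scale and shift come from learned projections of the condition (the `AdaLayerNorm` body is not part of this excerpt); here they are passed in as plain lists, and `ada_layer_norm` is a hypothetical name.

```python
import math
from typing import List


def ada_layer_norm(
    x: List[float], scale: List[float], shift: List[float], eps: float = 1e-6
) -> List[float]:
    """Hypothetical helper: adaptive layer norm over one frame's channels.

    `scale` and `shift` stand in for projections of the conditioning vector.
    """
    if not (len(x) == len(scale) == len(shift)):
        raise ValueError("x, scale and shift must have the same length")
    mean = sum(x) / len(x)
    # Biased variance, as used by layer normalization.
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    # Plain normalization first, then the condition-dependent affine transform.
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, scale, shift)]


if __name__ == "__main__":
    frame = [1.0, 2.0, 3.0, 4.0]
    # Unit scale and zero shift: an ordinary (non-affine) layer norm.
    print(ada_layer_norm(frame, [1.0] * 4, [0.0] * 4))
    # A different "condition" changes the output statistics.
    print(ada_layer_norm(frame, [2.0] * 4, [0.5] * 4))
```

In the PyTorch module the same operation presumably runs over every frame of a `(batch, time, dim)` tensor, with one scale/shift pair per batch item broadcast across time.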


================================================
FILE: sparktts/modules/encoder_decoder/feat_decoder.py
================================================
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

from typing import List

from sparktts.modules.blocks.vocos import VocosBackbone
from sparktts.modules.blocks.samper import SamplingBlock


class Decoder(nn.Module):
    """Decoder module with ConvNeXt (Vocos) and upsampling blocks.

    Args:
        sample_ratios (List[int]): Upsampling ratio of each stage.
            Example: [2, 2] means two successive 2x upsampling stages (4x in total).
    """

    def __init__(
        self,
        input_channels: int,
        vocos_dim: int,
        vocos_intermediate_dim: int,
        vocos_num_layers: int,
        out_channels: int,
        condition_dim: int = None,
        sample_ratios: List[int] = [1, 1],
        use_tanh_at_final: bool = False,
    ):
        super().__init__()

        self.linear_pre = nn.Linear(input_channels, vocos_dim)
        modules = [
            nn.Sequential(
                SamplingBlock(
                    dim=vocos_dim,
                    groups=vocos_dim,
                    upsample_scale=ratio,
                ),
                VocosBackbone(
                    input_channels=vocos_dim,
                    dim=vocos_dim,
                    intermediate_dim=vocos_intermediate_dim,
                    num_layers=2,
                    condition_dim=None,
                ),
            )
            for ratio in sample_ratios
        ]

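        # Note: despite its name, this stack upsamples (each SamplingBlock uses
        # upsample_scale). Renaming the attribute would change the state_dict keys.
        self.downsample = nn.Sequential(*modules)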

        self.vocos_backbone = VocosBackbone(
            input_channels=vocos_dim,
            dim=vocos_dim,
            intermediate_dim=vocos_intermediate_dim,
            num_layers=vocos_num_layers,
            condition_dim=condition_dim,
        )
        self.linear = nn.Linear(vocos_dim, out_channels)
        self.use_tanh_at_final = use_tanh_at_final

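    def forward(self, x: torch.Tensor, c: torch.Tensor = None):
        """Decoder forward.

        Args:
            x (torch.Tensor): (batch_size, input_channels, length)
            c (torch.Tensor, optional): (batch_size, condition_dim) conditioning
                vector; required when the decoder is built with condition_dim.

        Returns:
            x (torch.Tensor): (batch_size, out_channels, length * prod(sample_ratios))
        """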
        x = self.linear_pre(x.transpose(1, 2))
        x = self.downsample(x).transpose(1, 2)
        x = self.vocos_backbone(x, condition=c)
        x = self.linear(x).transpose(1, 2)
        if self.use_tanh_at_final:
            x = torch.tanh(x)

        return x


# test
if __name__ == "__main__":
    test_input = torch.randn(8, 1024, 50)  # Batch size = 8, 1024 channels, length = 50
    condition = torch.randn(8, 256)
    decoder = Decoder(
        input_channels=1024,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        out_channels=256,
        condition_dim=256,
        sample_ratios=[2, 2],
    )
    output = decoder(test_input, condition)
    print(output.shape)  # torch.Size([8, 256, 200])
    if output.shape == torch.Size([8, 256, 200]):
        print("Decoder test passed")
    else:
        print("Decoder test failed")
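The decoder's test expects `sample_ratios=[2, 2]` to turn 50 frames into 200, i.e. each stage multiplies the length by its ratio. The sketch below checks that arithmetic with the output-length formula documented for `torch.nn.ConvTranspose1d` and compares it against a repeat-style branch of exactly `L * r` frames. The layer settings (`kernel_size=2*r`, `stride=r`, `padding=r//2 + r%2`, `output_padding=r%2`) and the summed two-branch structure are assumptions about `SamplingBlock` in `samper.py`, which is not shown here; the helper names are hypothetical.

```python
from math import prod
from typing import List


def conv_transpose1d_out_len(
    length: int,
    kernel_size: int,
    stride: int,
    padding: int = 0,
    output_padding: int = 0,
    dilation: int = 1,
) -> int:
    """Output length of torch.nn.ConvTranspose1d (formula from the PyTorch docs)."""
    return (
        (length - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )


def decoder_out_len(length: int, sample_ratios: List[int]) -> int:
    """Hypothetical helper: frame count after the decoder's upsampling stack."""
    for r in sample_ratios:
        if r == 1:
            continue  # assumed: a ratio of 1 leaves the length unchanged
        deconv = conv_transpose1d_out_len(
            length, kernel_size=2 * r, stride=r,
            padding=r // 2 + r % 2, output_padding=r % 2,
        )
        repeat = length * r  # a repeat-style upsampler yields exactly L * r frames
        # The two branches are summed, so their lengths must agree.
        if deconv != repeat:
            raise ValueError(f"branch lengths differ for r={r}: {deconv} vs {repeat}")
        length = deconv
    return length


if __name__ == "__main__":
    print(decoder_out_len(50, [2, 2]))  # 200
    print(50 * prod([2, 2]))            # 200
```

Under these settings the transposed convolution yields exactly `L * r` frames for any ratio, so the overall factor is `prod(sample_ratios)`.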


================================================
FILE: sparktts/modules/encoder_decoder/feat_encoder.py
================================================
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

from typing import List

from sparktts.modules.blocks.vocos import VocosBackbone
from sparktts.modules.blocks.samper import SamplingBlock


class Encoder(nn.Module):
    """Encoder module with VocosBackbone (ConvNeXt) and downsampling blocks.

    Args:
        sample_ratios (List[int]): Downsampling ratio of each stage.
            Example: [2, 2] means two successive 2x downsampling stages (4x in total).
    """

    def __init__(
        self,
        input_channels: int,
        vocos_dim: int,
        vocos_intermediate_dim: int,
        vocos_num_layers: int,
        out_channels: int,
        sample_ratios: List[int] = [1, 1],
    ):
        super().__init__()
        self.encoder = VocosBackbone(
            input_channels=input_channels,
            dim=vocos_dim,
            intermediate_dim=vocos_intermediate_dim,
            num_layers=vocos_num_layers,
            condition_dim=None,
        )

        modules = [
            nn.Sequential(
                SamplingBlock(
                    dim=vocos_dim,
                    groups=vocos_dim,
                    downsample_scale=ratio,
                ),
                VocosBackbone(
                    input_channels=vocos_dim,
                    dim=vocos_dim,
                    intermediate_dim=vocos_intermediate_dim,
                    num_layers=2,
                    condition_dim=None,
                ),
            )
            for ratio in sample_ratios
        ]

        self.downsample = nn.Sequential(*modules)

        self.project = nn.Linear(vocos_dim, out_channels)

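    def forward(self, x: torch.Tensor, *args):
        """
        Args:
            x (torch.Tensor): (batch_size, input_channels, length)
            *args: Unused.

        Returns:
            x (torch.Tensor): (batch_size, out_channels, reduced_length), where
                reduced_length is roughly length / prod(sample_ratios)
                (e.g. 50 -> 12 for [2, 2]).
        """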
        x = self.encoder(x)
        x = self.downsample(x)
        x = self.project(x)
        return x.transpose(1, 2)


# test
if __name__ == "__main__":
    test_input = torch.randn(8, 1024, 50)  # Batch size = 8, 1024 channels, length = 50
    encoder = Encoder(
        input_channels=1024,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        out_channels=256,
        sample_ratios=[2, 2],
    )

    output = encoder(test_input)
    print(output.shape)  # torch.Size([8, 256, 12])
    if output.shape == torch.Size([8, 256, 12]):
        print("test successful")
    else:
        print("test failed")
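The encoder's test expects 12 frames from 50 with `sample_ratios=[2, 2]`, i.e. each 2x stage halves the length and rounds down (50 → 25 → 12). The sketch below reproduces that count with the output-length formula documented for `torch.nn.Conv1d` and checks it against an average-pooling branch. The settings (`kernel_size=2*r`, `stride=r`, `padding=r//2 + r%2`, pooling with kernel and stride `r`) and the summed two-branch structure are assumptions about `SamplingBlock` in `samper.py`, not shown in this excerpt; the helper names are hypothetical.

```python
from typing import List


def conv1d_out_len(
    length: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1
) -> int:
    """Output length of torch.nn.Conv1d (formula from the PyTorch docs)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def avg_pool1d_out_len(length: int, kernel_size: int) -> int:
    """Output length of avg_pool1d with stride == kernel_size and no padding."""
    return (length - kernel_size) // kernel_size + 1


def encoder_out_len(length: int, sample_ratios: List[int]) -> int:
    """Hypothetical helper: frame count after the encoder's downsampling stack."""
    for r in sample_ratios:
        if r == 1:
            continue  # assumed: a ratio of 1 leaves the length unchanged
        conv = conv1d_out_len(length, kernel_size=2 * r, stride=r, padding=r // 2 + r % 2)
        pool = avg_pool1d_out_len(length, kernel_size=r)
        # Both branches are summed, so their lengths must agree.
        if conv != pool:
            raise ValueError(f"branch lengths differ for r={r}: {conv} vs {pool}")
        length = conv
    return length


if __name__ == "__main__":
    print(encoder_out_len(50, [2, 2]))  # 12
```

For even ratios both formulas reduce to `L // r`; the helper raises if a ratio/length combination would make the branches disagree under these assumptions.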


================================================
FILE: sparktts/modules/encoder_decoder/wave_generator.py
================================================
# Copyright (c) 2024 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0


import torch.nn as nn

from sparktts.modules.blocks.layers import (
    Snake1d,
    WNConv1d,
    ResidualUnit,
    WNConvTranspose1d,
    init_weights,
)


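class DecoderBlock(nn.Module):
    """Upsampling block: Snake activation, a weight-normalized transposed
    convolution with the given stride (upsampling by about `stride`), then
    three residual units with dilations 1, 3 and 9."""
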
    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        kernel_size: int = 2,
        stride: int = 1,
    ):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=(kernel_size - stride) // 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class WaveGenerator(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        kernel_sizes,
        d_out: int = 1,
    ):
     
SYMBOL INDEX (273 symbols across 28 files)

FILE: cli/SparkTTS.py
  class SparkTTS (line 27) | class SparkTTS:
    method __init__ (line 32) | def __init__(self, model_dir: Path, device: torch.device = torch.devic...
    method _initialize_inference (line 46) | def _initialize_inference(self):
    method process_prompt (line 53) | def process_prompt(
    method process_prompt_control (line 110) | def process_prompt_control(
    method inference (line 158) | def inference(

FILE: cli/inference.py
  function parse_args (line 28) | def parse_args():
  function run_tts (line 64) | def run_tts(args):

FILE: runtime/triton_trtllm/client_grpc.py
  class UserData (line 65) | class UserData:
    method __init__ (line 66) | def __init__(self):
    method record_start_time (line 71) | def record_start_time(self):
    method get_first_chunk_latency (line 74) | def get_first_chunk_latency(self):
  function callback (line 79) | def callback(user_data, result, error):
  function write_triton_stats (line 89) | def write_triton_stats(stats, summary_file):
  function get_args (line 145) | def get_args():
  function load_audio (line 262) | def load_audio(wav_path, target_sample_rate=16000):
  function prepare_request_input_output (line 276) | def prepare_request_input_output(
  function run_sync_streaming_inference (line 332) | def run_sync_streaming_inference(
  function send_streaming (line 433) | async def send_streaming(
  function send (line 517) | async def send(
  function load_manifests (line 565) | def load_manifests(manifest_path):
  function split_data (line 586) | def split_data(data, k):
  function main (line 608) | async def main():
  function run_main (line 823) | async def run_main():

FILE: runtime/triton_trtllm/client_http.py
  function get_args (line 32) | def get_args():
  function prepare_request (line 83) | def prepare_request(

FILE: runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py
  class TritonPythonModel (line 37) | class TritonPythonModel:
    method initialize (line 44) | def initialize(self, args):
    method get_ref_clip (line 59) | def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
    method execute (line 86) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo/spark_tts/1/model.py
  function process_prompt (line 41) | def process_prompt(
  class TritonPythonModel (line 101) | class TritonPythonModel:
    method initialize (line 108) | def initialize(self, args):
    method forward_llm (line 139) | def forward_llm(self, input_ids):
    method forward_audio_tokenizer (line 223) | def forward_audio_tokenizer(self, wav, wav_len):
    method forward_vocoder (line 252) | def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semanti...
    method token2wav (line 283) | def token2wav(self, generated_token_ids, global_token_ids):
    method execute (line 305) | def execute(self, requests):

FILE: runtime/triton_trtllm/model_repo/vocoder/1/model.py
  class TritonPythonModel (line 43) | class TritonPythonModel:
    method initialize (line 50) | def initialize(self, args):
    method execute (line 72) | def execute(self, requests):

FILE: runtime/triton_trtllm/scripts/convert_checkpoint.py
  function parse_arguments (line 18) | def parse_arguments():
  function args_to_quant_config (line 160) | def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
  function update_quant_config_from_hf (line 194) | def update_quant_config_from_hf(quant_config, hf_config,
  function args_to_build_options (line 224) | def args_to_build_options(args):
  function convert_and_save_hf (line 233) | def convert_and_save_hf(args):
  function execute (line 287) | def execute(workers, func, args):
  function main (line 306) | def main():

FILE: runtime/triton_trtllm/scripts/fill_template.py
  function split (line 6) | def split(string, delimiter):
  function main (line 34) | def main(file_path, substitutions, in_place):

FILE: sparktts/models/audio_tokenizer.py
  class BiCodecTokenizer (line 29) | class BiCodecTokenizer:
    method __init__ (line 32) | def __init__(self, model_dir: Path, device: torch.device = None, **kwa...
    method _initialize_model (line 44) | def _initialize_model(self):
    method get_ref_clip (line 57) | def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
    method process_audio (line 72) | def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Ten...
    method extract_wav2vec2_features (line 85) | def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
    method tokenize_batch (line 101) | def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
    method tokenize (line 119) | def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    method detokenize (line 132) | def detokenize(

FILE: sparktts/models/bicodec.py
  class BiCodec (line 31) | class BiCodec(nn.Module):
    method __init__ (line 37) | def __init__(
    method load_from_checkpoint (line 70) | def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
    method forward (line 113) | def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
    method tokenize (line 152) | def tokenize(self, batch: Dict[str, Any]):
    method detokenize (line 172) | def detokenize(self, semantic_tokens, global_tokens):
    method init_mel_transformer (line 191) | def init_mel_transformer(self, config: Dict[str, Any]):
    method remove_weight_norm (line 213) | def remove_weight_norm(self):

FILE: sparktts/modules/blocks/layers.py
  function WNConv1d (line 24) | def WNConv1d(*args, **kwargs):
  function WNConvTranspose1d (line 28) | def WNConvTranspose1d(*args, **kwargs):
  function snake (line 34) | def snake(x, alpha):
  class Snake1d (line 42) | class Snake1d(nn.Module):
    method __init__ (line 43) | def __init__(self, channels):
    method forward (line 47) | def forward(self, x):
  class ResidualUnit (line 51) | class ResidualUnit(nn.Module):
    method __init__ (line 52) | def __init__(self, dim: int = 16, dilation: int = 1):
    method forward (line 62) | def forward(self, x):
  function init_weights (line 70) | def init_weights(m):

FILE: sparktts/modules/blocks/samper.py
  class SamplingBlock (line 22) | class SamplingBlock(nn.Module):
    method __init__ (line 25) | def __init__(
    method repeat_upsampler (line 72) | def repeat_upsampler(x, upsample_scale):
    method skip_downsampler (line 76) | def skip_downsampler(x, downsample_scale):
    method forward (line 79) | def forward(self, x):

FILE: sparktts/modules/blocks/vocos.py
  class ConvNeXtBlock (line 26) | class ConvNeXtBlock(nn.Module):
    method __init__ (line 38) | def __init__(
    method forward (line 65) | def forward(
  class AdaLayerNorm (line 87) | class AdaLayerNorm(nn.Module):
    method __init__ (line 96) | def __init__(self, condition_dim: int, embedding_dim: int, eps: float ...
    method forward (line 105) | def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> to...
  class ResBlock1 (line 113) | class ResBlock1(nn.Module):
    method __init__ (line 129) | def __init__(
    method forward (line 235) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method remove_weight_norm (line 246) | def remove_weight_norm(self):
    method get_padding (line 253) | def get_padding(kernel_size: int, dilation: int = 1) -> int:
  class Backbone (line 257) | class Backbone(nn.Module):
    method forward (line 260) | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
  class VocosBackbone (line 273) | class VocosBackbone(Backbone):
    method __init__ (line 287) | def __init__(
    method _init_weights (line 319) | def _init_weights(self, m):
    method forward (line 324) | def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> ...
  class VocosResNetBackbone (line 338) | class VocosResNetBackbone(Backbone):
    method __init__ (line 349) | def __init__(
    method forward (line 369) | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:

FILE: sparktts/modules/encoder_decoder/feat_decoder.py
  class Decoder (line 26) | class Decoder(nn.Module):
    method __init__ (line 34) | def __init__(
    method forward (line 78) | def forward(self, x: torch.Tensor, c: torch.Tensor = None):

FILE: sparktts/modules/encoder_decoder/feat_encoder.py
  class Encoder (line 26) | class Encoder(nn.Module):
    method __init__ (line 29) | def __init__(
    method forward (line 76) | def forward(self, x: torch.Tensor, *args):

FILE: sparktts/modules/encoder_decoder/wave_generator.py
  class DecoderBlock (line 29) | class DecoderBlock(nn.Module):
    method __init__ (line 30) | def __init__(
    method forward (line 52) | def forward(self, x):
  class WaveGenerator (line 56) | class WaveGenerator(nn.Module):
    method __init__ (line 57) | def __init__(
    method forward (line 87) | def forward(self, x):

FILE: sparktts/modules/fsq/finite_scalar_quantization.py
  function exists (line 22) | def exists(v):
  function default (line 26) | def default(*args):
  function maybe (line 33) | def maybe(fn):
  function pack_one (line 43) | def pack_one(t, pattern):
  function unpack_one (line 47) | def unpack_one(t, ps, pattern):
  function round_ste (line 54) | def round_ste(z: Tensor) -> Tensor:
  class FSQ (line 63) | class FSQ(Module):
    method __init__ (line 64) | def __init__(
    method bound (line 126) | def bound(self, z, eps: float = 1e-3):
    method quantize (line 133) | def quantize(self, z):
    method _scale_and_shift (line 139) | def _scale_and_shift(self, zhat_normalized):
    method _scale_and_shift_inverse (line 143) | def _scale_and_shift_inverse(self, zhat):
    method _indices_to_codes (line 147) | def _indices_to_codes(self, indices):
    method codes_to_indices (line 152) | def codes_to_indices(self, zhat):
    method indices_to_level_indices (line 158) | def indices_to_level_indices(self, indices):
    method indices_to_codes (line 164) | def indices_to_codes(self, indices):
    method forward (line 182) | def forward(self, z):

FILE: sparktts/modules/fsq/residual_fsq.py
  function exists (line 16) | def exists(val):
  function first (line 20) | def first(l):
  function default (line 24) | def default(val, d):
  function round_up_multiple (line 28) | def round_up_multiple(num, mult):
  function is_distributed (line 35) | def is_distributed():
  function get_maybe_sync_seed (line 39) | def get_maybe_sync_seed(device, max_size=10_000):
  class ResidualFSQ (line 48) | class ResidualFSQ(Module):
    method __init__ (line 51) | def __init__(
    method codebooks (line 107) | def codebooks(self):
    method get_codes_from_indices (line 112) | def get_codes_from_indices(self, indices):
    method get_output_from_indices (line 153) | def get_output_from_indices(self, indices):
    method forward (line 158) | def forward(self, x, return_all_codes=False, rand_quantize_dropout_fix...
  class GroupedResidualFSQ (line 269) | class GroupedResidualFSQ(Module):
    method __init__ (line 270) | def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
    method codebooks (line 287) | def codebooks(self):
    method split_dim (line 291) | def split_dim(self):
    method get_codes_from_indices (line 294) | def get_codes_from_indices(self, indices):
    method get_output_from_indices (line 301) | def get_output_from_indices(self, indices):
    method forward (line 308) | def forward(self, x, return_all_codes=False):

FILE: sparktts/modules/speaker/ecapa_tdnn.py
  class Res2Conv1dReluBn (line 28) | class Res2Conv1dReluBn(nn.Module):
    method __init__ (line 33) | def __init__(
    method forward (line 67) | def forward(self, x):
  class Conv1dReluBn (line 89) | class Conv1dReluBn(nn.Module):
    method __init__ (line 91) | def __init__(
    method forward (line 107) | def forward(self, x):
  class SE_Connect (line 115) | class SE_Connect(nn.Module):
    method __init__ (line 117) | def __init__(self, channels, se_bottleneck_dim=128):
    method forward (line 122) | def forward(self, x):
  class SE_Res2Block (line 135) | class SE_Res2Block(nn.Module):
    method __init__ (line 137) | def __init__(self, channels, kernel_size, stride, padding, dilation, s...
    method forward (line 148) | def forward(self, x):
  class ECAPA_TDNN (line 152) | class ECAPA_TDNN(nn.Module):
    method __init__ (line 154) | def __init__(
    method forward (line 191) | def forward(self, x, return_latent=False):
  function ECAPA_TDNN_c1024 (line 211) | def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=Fa...
  function ECAPA_TDNN_GLOB_c1024 (line 221) | def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_...
  function ECAPA_TDNN_c512 (line 232) | def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=Fal...
  function ECAPA_TDNN_GLOB_c512 (line 242) | def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_b...

FILE: sparktts/modules/speaker/perceiver_encoder.py
  function exists (line 29) | def exists(val):
  function once (line 33) | def once(fn):
  class Attend (line 52) | class Attend(nn.Module):
    method __init__ (line 53) | def __init__(self, dropout=0.0, causal=False, use_flash=False):
    method get_mask (line 90) | def get_mask(self, n, device):
    method flash_attn (line 98) | def flash_attn(self, q, k, v, mask=None):
    method forward (line 135) | def forward(self, q, k, v, mask=None):
  function Sequential (line 181) | def Sequential(*mods):
  function exists (line 185) | def exists(x):
  function default (line 189) | def default(val, d):
  class RMSNorm (line 195) | class RMSNorm(nn.Module):
    method __init__ (line 196) | def __init__(self, dim, scale=True, dim_cond=None):
    method forward (line 204) | def forward(self, x, cond=None):
  class CausalConv1d (line 217) | class CausalConv1d(nn.Conv1d):
    method __init__ (line 218) | def __init__(self, *args, **kwargs):
    method forward (line 227) | def forward(self, x):
  class GEGLU (line 232) | class GEGLU(nn.Module):
    method forward (line 233) | def forward(self, x):
  function FeedForward (line 238) | def FeedForward(dim, mult=4, causal_conv=False):
  class Attention (line 254) | class Attention(nn.Module):
    method __init__ (line 255) | def __init__(
    method forward (line 280) | def forward(self, x, context=None, mask=None):
  class PerceiverResampler (line 297) | class PerceiverResampler(nn.Module):
    method __init__ (line 298) | def __init__(
    method forward (line 339) | def forward(self, x, mask=None):

FILE: sparktts/modules/speaker/pooling_layers.py
  class TAP (line 27) | class TAP(nn.Module):
    method __init__ (line 32) | def __init__(self, in_dim=0, **kwargs):
    method forward (line 36) | def forward(self, x):
    method get_out_dim (line 42) | def get_out_dim(self):
  class TSDP (line 47) | class TSDP(nn.Module):
    method __init__ (line 52) | def __init__(self, in_dim=0, **kwargs):
    method forward (line 56) | def forward(self, x):
    method get_out_dim (line 62) | def get_out_dim(self):
  class TSTP (line 67) | class TSTP(nn.Module):
    method __init__ (line 74) | def __init__(self, in_dim=0, **kwargs):
    method forward (line 78) | def forward(self, x):
    method get_out_dim (line 87) | def get_out_dim(self):
  class ASTP (line 92) | class ASTP(nn.Module):
    method __init__ (line 97) | def __init__(self,
    method forward (line 119) | def forward(self, x):
    method get_out_dim (line 146) | def get_out_dim(self):
  class MHASTP (line 151) | class MHASTP(torch.nn.Module):
    method __init__ (line 158) | def __init__(self,
    method forward (line 193) | def forward(self, input):
    method get_out_dim (line 220) | def get_out_dim(self):
  class MQMHASTP (line 225) | class MQMHASTP(torch.nn.Module):
    method __init__ (line 247) | def __init__(self,
    method forward (line 266) | def forward(self, input):
    method get_out_dim (line 283) | def get_out_dim(self):

FILE: sparktts/modules/speaker/speaker_encoder.py
  class SpeakerEncoder (line 29) | class SpeakerEncoder(nn.Module):
    method __init__ (line 44) | def __init__(
    method get_codes_from_indices (line 71) | def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
    method get_indices (line 75) | def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
    method forward (line 81) | def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten...
    method tokenize (line 100) | def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
    method detokenize (line 107) | def detokenize(self, indices: torch.Tensor) -> torch.Tensor:

FILE: sparktts/modules/vq/factorized_vector_quantize.py
  function WNConv1d (line 28) | def WNConv1d(*args, **kwargs):
  function ema_inplace (line 32) | def ema_inplace(moving_avg, new, decay):
  class FactorizedVectorQuantize (line 36) | class FactorizedVectorQuantize(nn.Module):
    method __init__ (line 37) | def __init__(
    method forward (line 70) | def forward(self, z: torch.Tensor) -> Dict[str, Any]:
    method vq2emb (line 142) | def vq2emb(self, vq, out_proj=True):
    method tokenize (line 148) | def tokenize(self, z: torch.Tensor) -> torch.Tensor:
    method detokenize (line 154) | def detokenize(self, indices):
    method get_emb (line 160) | def get_emb(self):
    method embed_code (line 163) | def embed_code(self, embed_id):
    method decode_code (line 166) | def decode_code(self, embed_id):
    method decode_latents (line 169) | def decode_latents(self, latents):

FILE: sparktts/utils/audio.py
  function audio_volume_normalize (line 33) | def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np....
  function load_audio (line 76) | def load_audio(
  function random_select_audio_segment (line 122) | def random_select_audio_segment(audio: np.ndarray, length: int) -> np.nd...
  function audio_highpass_filter (line 137) | def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq):
  function stft (line 152) | def stft(
  function detect_speech_boundaries (line 186) | def detect_speech_boundaries(
  function remove_silence_on_both_ends (line 228) | def remove_silence_on_both_ends(
  function hertz_to_mel (line 258) | def hertz_to_mel(pitch: float) -> float:

FILE: sparktts/utils/file.py
  function resolve_symbolic_link (line 34) | def resolve_symbolic_link(symbolic_link_path: Path) -> Path:
  function write_jsonl (line 50) | def write_jsonl(metadata: List[dict], file_path: Path) -> None:
  function read_jsonl (line 69) | def read_jsonl(file_path: Path) -> List[dict]:
  function read_json_as_jsonl (line 94) | def read_json_as_jsonl(file_path: Path) -> List[dict]:
  function decode_unicode_strings (line 106) | def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]:
  function load_config (line 116) | def load_config(config_path: Path) -> DictConfig:
  function jsonl_to_csv (line 134) | def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None:
  function save_metadata (line 169) | def save_metadata(data, filename, headers=None):
  function read_metadata (line 192) | def read_metadata(filename, headers=None):

FILE: sparktts/utils/token_parser.py
  class TokenParser (line 66) | class TokenParser:
    method __init__ (line 69) | def __init__(self):
    method __init__ (line 74) | def __init__(self):
    method age (line 78) | def age(age: str) -> str:
    method gender (line 84) | def gender(gender: str) -> str:
    method mel_value (line 90) | def mel_value(mel: int):
    method mel_level (line 97) | def mel_level(level: str):
    method pitch_var_value (line 103) | def pitch_var_value(pitch_std: int):
    method pitch_var_level (line 111) | def pitch_var_level(level: str):
    method loudness_value (line 117) | def loudness_value(loudness: int):
    method loudness_level (line 125) | def loudness_level(level: str):
    method speed_value (line 131) | def speed_value(speed: int):
    method speed_level (line 138) | def speed_level(level: str):
    method task (line 144) | def task(task: str) -> str:
    method emotion (line 151) | def emotion(emotion: str):

FILE: webui.py
  function initialize_model (line 29) | def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", devic...
  function run_tts (line 51) | def run_tts(
  function build_ui (line 94) | def build_ui(model_dir, device=0):
  function parse_arguments (line 224) | def parse_arguments():
Condensed preview: 44 files, each showing path, character count, and a content snippet (the full structured content is 302K chars).
[
  {
    "path": ".gitignore",
    "chars": 3470,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\npretrained_models/\nresults/\ndemo/\n# C extensio"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 11174,
    "preview": "<div align=\"center\">\n    <h1>\n    Spark-TTS\n    </h1>\n    <p>\n    Official PyTorch code for inference of <br>\n    <b><em"
  },
  {
    "path": "cli/SparkTTS.py",
    "chars": 8110,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "cli/inference.py",
    "chars": 3726,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "example/infer.sh",
    "chars": 1424,
    "preview": "#!/bin/bash\n\n# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed unde"
  },
  {
    "path": "requirements.txt",
    "chars": 206,
    "preview": "einops==0.8.1\neinx==0.3.0\nnumpy==2.2.3\nomegaconf==2.3.0\npackaging==24.2\nsafetensors==0.5.2\nsoundfile==0.12.1\nsoxr==0.5.0"
  },
  {
    "path": "runtime/triton_trtllm/Dockerfile.server",
    "chars": 377,
    "preview": "FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3\nRUN apt-get update && apt-get install -y cmake\nRUN git clone ht"
  },
  {
    "path": "runtime/triton_trtllm/README.md",
    "chars": 4631,
    "preview": "## Nvidia Triton Inference Serving Best Practice for Spark TTS\n\n### Quick Start\nDirectly launch the service using docker"
  },
  {
    "path": "runtime/triton_trtllm/client_grpc.py",
    "chars": 33102,
    "preview": "#!/usr/bin/env python3\n# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)\n#                2023  Nvidia"
  },
  {
    "path": "runtime/triton_trtllm/client_http.py",
    "chars": 5367,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/docker-compose.yml",
    "chars": 558,
    "preview": "services:\n  tts:\n    image: soar97/triton-spark-tts:25.02\n    shm_size: '1gb'\n    ports:\n      - \"8000:8000\"\n      - \"80"
  },
  {
    "path": "runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py",
    "chars": 5399,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt",
    "chars": 1253,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/spark_tts/1/model.py",
    "chars": 18005,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt",
    "chars": 1902,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt",
    "chars": 18057,
    "preview": "# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/vocoder/1/model.py",
    "chars": 4514,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "runtime/triton_trtllm/model_repo/vocoder/config.pbtxt",
    "chars": 1163,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "runtime/triton_trtllm/run.sh",
    "chars": 4997,
    "preview": "export PYTHONPATH=../../../Spark-TTS/\nexport CUDA_VISIBLE_DEVICES=0\nstage=$1\nstop_stage=$2\nservice_type=$3\necho \"Start s"
  },
  {
    "path": "runtime/triton_trtllm/scripts/convert_checkpoint.py",
    "chars": 13074,
    "preview": "import argparse\nimport os\nimport time\nimport traceback\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\n\n"
  },
  {
    "path": "runtime/triton_trtllm/scripts/fill_template.py",
    "chars": 1837,
    "preview": "#! /usr/bin/env python3\nfrom argparse import ArgumentParser\nfrom string import Template\n\n\ndef split(string, delimiter):\n"
  },
  {
    "path": "sparktts/models/audio_tokenizer.py",
    "chars": 5981,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/models/bicodec.py",
    "chars": 8256,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/modules/blocks/layers.py",
    "chars": 2161,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/modules/blocks/samper.py",
    "chars": 3938,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/modules/blocks/vocos.py",
    "chars": 13041,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/modules/encoder_decoder/feat_decoder.py",
    "chars": 3555,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/modules/encoder_decoder/feat_encoder.py",
    "chars": 3145,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/modules/encoder_decoder/wave_generator.py",
    "chars": 2500,
    "preview": "# Copyright (c) 2024 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "sparktts/modules/fsq/finite_scalar_quantization.py",
    "chars": 7502,
    "preview": "\"\"\"\nFinite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505\nCode adapted from Jax version in A"
  },
  {
    "path": "sparktts/modules/fsq/residual_fsq.py",
    "chars": 10346,
    "preview": "import random\nimport torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\n\nfrom typing import List\nfro"
  },
  {
    "path": "sparktts/modules/speaker/ecapa_tdnn.py",
    "chars": 7563,
    "preview": "# Copyright (c) 2021 Zhengyang Chen (chenzhengyang117@gmail.com)\n#               2022 Hongji Wang (jijijiang77@gmail.com"
  },
  {
    "path": "sparktts/modules/speaker/perceiver_encoder.py",
    "chars": 10552,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/modules/speaker/pooling_layers.py",
    "chars": 10322,
    "preview": "# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\""
  },
  {
    "path": "sparktts/modules/speaker/speaker_encoder.py",
    "chars": 4693,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/modules/vq/factorized_vector_quantize.py",
    "chars": 6408,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "sparktts/utils/audio.py",
    "chars": 8462,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/utils/file.py",
    "chars": 7170,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  },
  {
    "path": "sparktts/utils/parse_options.sh",
    "chars": 3663,
    "preview": "#!/bin/bash\n\n# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);\n#                 Arnab Ghoshal, Karel V"
  },
  {
    "path": "sparktts/utils/token_parser.py",
    "chars": 5176,
    "preview": "TASK_TOKEN_MAP = {\n    \"vc\": \"<|task_vc|>\",\n    \"tts\": \"<|task_tts|>\",\n    \"asr\": \"<|task_asr|>\",\n    \"s2s\": \"<|task_s2s"
  },
  {
    "path": "webui.py",
    "chars": 8971,
    "preview": "# Copyright (c) 2025 SparkAudio\n#               2025 Xinsheng Wang (w.xinshawn@gmail.com)\n#\n# Licensed under the Apache "
  }
]

About this extraction

This page contains the full source code of the SparkAudio/Spark-TTS GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 44 files (280.4 KB), approximately 69.9k tokens, and a symbol index with 273 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract, a GitHub-repository-to-text converter built by Nikandr Surkov.
