Repository: SWivid/F5-TTS
Branch: main
Commit: 623c96c29496
Files: 90
Total size: 571.3 KB
Directory structure:
gitextract_4jvakfwh/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.yml
│ │ ├── config.yml
│ │ ├── feature_request.yml
│ │ ├── help_wanted.yml
│ │ └── question.yml
│ └── workflows/
│ ├── pre-commit.yaml
│ ├── publish-docker-image.yaml
│ └── publish-pypi.yaml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── README.md
├── pyproject.toml
├── ruff.toml
└── src/
└── f5_tts/
├── api.py
├── configs/
│ ├── E2TTS_Base.yaml
│ ├── E2TTS_Small.yaml
│ ├── F5TTS_Base.yaml
│ ├── F5TTS_Small.yaml
│ └── F5TTS_v1_Base.yaml
├── eval/
│ ├── README.md
│ ├── ecapa_tdnn.py
│ ├── eval_infer_batch.py
│ ├── eval_infer_batch.sh
│ ├── eval_infer_batch_example.sh
│ ├── eval_librispeech_test_clean.py
│ ├── eval_seedtts_testset.py
│ ├── eval_utmos.py
│ └── utils_eval.py
├── infer/
│ ├── README.md
│ ├── SHARED.md
│ ├── examples/
│ │ ├── basic/
│ │ │ └── basic.toml
│ │ ├── multi/
│ │ │ ├── country.flac
│ │ │ ├── main.flac
│ │ │ ├── story.toml
│ │ │ ├── story.txt
│ │ │ └── town.flac
│ │ └── vocab.txt
│ ├── infer_cli.py
│ ├── infer_gradio.py
│ ├── speech_edit.py
│ └── utils_infer.py
├── model/
│ ├── __init__.py
│ ├── backbones/
│ │ ├── README.md
│ │ ├── dit.py
│ │ ├── mmdit.py
│ │ └── unett.py
│ ├── cfm.py
│ ├── dataset.py
│ ├── modules.py
│ ├── trainer.py
│ └── utils.py
├── runtime/
│ └── triton_trtllm/
│ ├── .gitignore
│ ├── Dockerfile.server
│ ├── README.md
│ ├── benchmark.py
│ ├── client_grpc.py
│ ├── client_http.py
│ ├── docker-compose.yml
│ ├── model_repo_f5_tts/
│ │ ├── f5_tts/
│ │ │ ├── 1/
│ │ │ │ ├── f5_tts_trtllm.py
│ │ │ │ └── model.py
│ │ │ └── config.pbtxt
│ │ └── vocoder/
│ │ ├── 1/
│ │ │ └── .gitkeep
│ │ └── config.pbtxt
│ ├── patch/
│ │ ├── __init__.py
│ │ └── f5tts/
│ │ ├── model.py
│ │ └── modules.py
│ ├── run.sh
│ └── scripts/
│ ├── conv_stft.py
│ ├── convert_checkpoint.py
│ ├── export_vocoder_to_onnx.py
│ ├── export_vocos_trt.sh
│ └── fill_template.py
├── scripts/
│ ├── count_max_epoch.py
│ ├── count_max_epoch_precise.py
│ └── count_params_gflops.py
├── socket_client.py
├── socket_server.py
└── train/
├── README.md
├── datasets/
│ ├── prepare_csv_wavs.py
│ ├── prepare_emilia.py
│ ├── prepare_emilia_v2.py
│ ├── prepare_libritts.py
│ ├── prepare_ljspeech.py
│ └── prepare_wenetspeech4tts.py
├── finetune_cli.py
├── finetune_gradio.py
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yml
================================================
name: "Bug Report"
description: |
Please provide as many details as possible to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- bug
body:
- type: checkboxes
attributes:
label: Checks
description: "To ensure timely help, please confirm the following:"
options:
- label: This template is only for bug reports, usage problems go with 'Help Wanted'.
required: true
- label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details including OS, GPU info, Python version, any relevant software or dependencies, and trainer setting."
placeholder: e.g., CentOS Linux 7, 4 * RTX 3090, Python 3.10, torch==2.3.0+cu118, cuda 11.8, config yaml is ...
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: |
Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
placeholder: |
1. Create a new conda environment.
2. Clone the repository, install as local editable and properly set up.
3. Run the command: `accelerate launch src/f5_tts/train/train.py`.
4. Have following error message... (attach logs).
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe in detail what you expected to happen.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe in detail what actually happened.
validations:
required: false
================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
blank_issues_enabled: false
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.yml
================================================
name: "Feature Request"
description: |
Some constructive suggestions and new ideas regarding current repo.
labels:
- enhancement
body:
- type: checkboxes
attributes:
label: Checks
description: "To help us grasp quickly, please confirm the following:"
options:
- label: This template is only for feature request.
required: true
- label: I have thoroughly reviewed the project documentation but couldn't find any relevant information that meets my needs.
required: true
- label: I have searched for existing issues, including closed ones, and found no discussion yet.
required: true
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: 1. Is this request related to a challenge you're experiencing? Tell us your story.
description: |
Describe the specific problem or scenario you're facing in detail. For example:
*"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because...."*
placeholder: Please describe the situation in as much detail as possible.
validations:
required: true
- type: textarea
attributes:
label: 2. What is your suggested solution?
description: |
Provide a clear description of the feature or enhancement you'd like to propose.
How would this feature solve your issue or improve the project?
placeholder: Describe your idea or proposed solution here.
validations:
required: true
- type: textarea
attributes:
label: 3. Additional context or comments
description: |
Any other relevant information, links, documents, or screenshots that provide clarity.
Use this section for anything not covered above.
placeholder: Add any extra details here.
validations:
required: false
- type: checkboxes
attributes:
label: 4. Can you help us with this feature?
description: |
Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration.
options:
- label: I am interested in contributing to this feature.
required: false
- type: markdown
attributes:
value: |
**Note:** Please submit only one request per issue to keep discussions focused and manageable.
================================================
FILE: .github/ISSUE_TEMPLATE/help_wanted.yml
================================================
name: "Help Wanted"
description: |
Please provide as many details as possible to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- help wanted
body:
- type: checkboxes
attributes:
label: Checks
description: "To ensure timely help, please confirm the following:"
options:
- label: This template is only for usage issues encountered.
required: true
- label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: |
e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
If training or finetuning related, provide detailed configuration including GPU info and training setup.
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: |
Include detailed steps, screenshots, and logs. Provide used prompt wav and text. Use the correct markdown syntax for code blocks.
placeholder: |
1. Create a new conda environment.
2. Clone the repository and install as pip package.
3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
5. Prompt & generated wavs are [change suffix to .mp4 to enable direct upload or pack all to .zip].
6. Reference audio's transcription or provided ref_text is `xxx`, and text to generate is `xxx`.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen in detail, e.g. output a generated audio.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened in detail, failure messages, etc.
validations:
required: false
================================================
FILE: .github/ISSUE_TEMPLATE/question.yml
================================================
name: "Question"
description: |
Research question or pure inquiry about the project, usage issue goes with "help wanted".
labels:
- question
body:
- type: checkboxes
attributes:
label: Checks
description: "To help us grasp quickly, please confirm the following:"
options:
- label: This template is only for research question, not usage problems, feature requests or bug reports.
required: true
- label: I have thoroughly reviewed the project documentation and read the related paper(s).
required: true
- label: I have searched for existing issues, including closed ones, no similar questions.
required: true
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Question details
description: |
Question details, clearly stated using proper markdown syntax.
validations:
required: true
================================================
FILE: .github/workflows/pre-commit.yaml
================================================
# Runs the repository's pre-commit hooks on every pull request and on pushes to main.
name: pre-commit

on:
  pull_request:
  push:
    branches: [main]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      # Action versions are kept in line with the publish-pypi workflow
      # (checkout@v4 / setup-python@v5); the v3 releases target the
      # deprecated Node 16 action runtime.
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
      - uses: pre-commit/action@v3.0.1
================================================
FILE: .github/workflows/publish-docker-image.yaml
================================================
name: Create and publish a Docker image
# Configures this workflow to run every time a change is pushed to the branch called `main`.
on:
push:
branches: ['main']
# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
jobs:
build-and-push-image:
runs-on: ubuntu-latest
# Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
permissions:
contents: read
packages: write
#
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
uses: jlumbroso/free-disk-space@main
with:
# This might remove tools that are actually needed, if set to "true" but frees about 6 GB
tool-cache: false
# All of these default to true, but feel free to set to "false" if necessary for your workflow
android: true
dotnet: true
haskell: true
large-packages: false
swap-storage: false
docker-images: false
# Uses the `docker/login-action` action to log in to the Container registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
- name: Log in to the Container registry
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
# This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
# It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
# It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
- name: Build and push Docker image
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
================================================
FILE: .github/workflows/publish-pypi.yaml
================================================
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
# GitHub recommends pinning actions to a commit SHA.
# To get a newer version, you will need to update the SHA.
# You can also reference a tag or branch, but the action may change without warning.
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
release-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build release distributions
run: |
# NOTE: put your own distribution build steps here.
python -m pip install build
python -m build
- name: Upload distributions
uses: actions/upload-artifact@v4
with:
name: release-dists
path: dist/
pypi-publish:
runs-on: ubuntu-latest
needs:
- release-build
permissions:
# IMPORTANT: this permission is mandatory for trusted publishing
id-token: write
# Dedicated environments with protections for publishing are strongly recommended.
environment:
name: pypi
# OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
# url: https://pypi.org/p/YOURPROJECT
steps:
- name: Retrieve release distributions
uses: actions/download-artifact@v4
with:
name: release-dists
path: dist/
- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
================================================
FILE: .gitignore
================================================
# Custom
.vscode/
tests/
runs/
data/
ckpts/
wandb/
results/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
================================================
FILE: .gitmodules
================================================
[submodule "src/third_party/BigVGAN"]
path = src/third_party/BigVGAN
url = https://github.com/NVIDIA/BigVGAN.git
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.2
hooks:
- id: ruff
name: ruff linter
args: [--fix]
- id: ruff-format
name: ruff formatter
- id: ruff
name: ruff sorter
args: [--select, I, --fix]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-yaml
================================================
FILE: Dockerfile
================================================
# Image for running F5-TTS, built on the PyTorch 2.4.0 / CUDA 12.4 / cuDNN 9 devel base image.
FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
USER root
# Build-time only: keeps apt from prompting during package installation.
ARG DEBIAN_FRONTEND=noninteractive
LABEL github_repo="https://github.com/SWivid/F5-TTS"
# System packages: general tooling, an SSH server, audio libraries/codecs
# (sox, libsndfile, ffmpeg) and RDMA/InfiniBand userspace libraries.
# The apt package lists are removed in the same layer so they are not kept in the image.
RUN set -x \
&& apt-get update \
&& apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
&& apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
&& apt-get install -y librdmacm1 libibumad3 librdmacm-dev libibverbs1 libibverbs-dev ibverbs-utils ibverbs-providers \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
WORKDIR /workspace
# NOTE: the source is cloned from GitHub rather than copied from the build context,
# so the image contains the upstream default branch, not local working-tree changes.
# Submodules are initialised and the package is installed in editable mode.
RUN git clone https://github.com/SWivid/F5-TTS.git \
&& cd F5-TTS \
&& git submodule update --init --recursive \
&& pip install -e . --no-cache-dir
ENV SHELL=/bin/bash
# Hugging Face hub cache directory, declared as a volume so it can be mounted from outside the container.
VOLUME /root/.cache/huggingface/hub/
# Port published by the documented `docker run -p 7860:7860` commands.
EXPOSE 7860
WORKDIR /workspace/F5-TTS
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 Yushen CHEN
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
[](https://github.com/SWivid/F5-TTS)
[](https://arxiv.org/abs/2410.06885)
[](https://swivid.github.io/F5-TTS/)
[](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
[](https://modelscope.cn/studios/AI-ModelScope/E2-F5-TTS)
[](https://x-lance.sjtu.edu.cn/)
[](https://www.sii.edu.cn/)
[](https://www.pcl.ac.cn)
**F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster training and inference.
**E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009).
**Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance.
### Thanks to all the contributors !
## News
- **2025/03/12**: 🔥 F5-TTS v1 base model with better training and inference performance. [A few demos](https://swivid.github.io/F5-TTS_updates).
- **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
## Installation
### Create a separate environment if needed
```bash
# Create a conda env with python_version>=3.10 (you could also use virtualenv)
conda create -n f5-tts python=3.11
conda activate f5-tts
# Install FFmpeg if you haven't yet
conda install ffmpeg
```
### Install PyTorch with matched device
NVIDIA GPU
> ```bash
> # Install pytorch with your CUDA version, e.g.
> pip install torch==2.8.0+cu128 torchaudio==2.8.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
>
> # And also possible previous versions, e.g.
> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
> # etc.
> ```
AMD GPU
> ```bash
> # Install pytorch with your ROCm version (Linux only), e.g.
> pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
> ```
Intel GPU
> ```bash
> # Install pytorch with your XPU version, e.g.
> # Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed
> pip install torch torchaudio --index-url https://download.pytorch.org/whl/test/xpu
>
> # Intel GPU support is also available through IPEX (Intel® Extension for PyTorch)
> # IPEX does not require the Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit
> # See: https://pytorch-extension.intel.com/installation?request=platform
> ```
Apple Silicon
> ```bash
> # Install the stable pytorch, e.g.
> pip install torch torchaudio
> ```
### Then you can choose one from below:
> ### 1. As a pip package (if just for inference)
>
> ```bash
> pip install f5-tts
> ```
>
> ### 2. Local editable (if also do training, finetuning)
>
> ```bash
> git clone https://github.com/SWivid/F5-TTS.git
> cd F5-TTS
> # git submodule update --init --recursive # (optional, if use bigvgan as vocoder)
> pip install -e .
> ```
### Docker usage also available
```bash
# Build from Dockerfile
docker build -t f5tts:v1 .
# Run from GitHub Container Registry
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main
# Quickstart if you want to just run the web interface (not CLI)
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
```
### Runtime
Deployment solution with Triton and TensorRT-LLM.
#### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
## Inference
- In order to achieve desired performance, take a moment to read [detailed guidance](src/f5_tts/infer).
- Searching the [issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) for keywords related to the problem you encountered is often very helpful.
### 1. Gradio App
Currently supported features:
- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct
- [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
```bash
# Launch a Gradio app (web interface)
f5-tts_infer-gradio
# Specify the port/host
f5-tts_infer-gradio --port 7860 --host 0.0.0.0
# Launch a share link
f5-tts_infer-gradio --share
```
NVIDIA device docker compose file example
```yaml
services:
  f5-tts:
    image: ghcr.io/swivid/f5-tts:main
    ports:
      - "7860:7860"
    environment:
      GRADIO_SERVER_PORT: 7860
    entrypoint: ["f5-tts_infer-gradio", "--port", "7860", "--host", "0.0.0.0"]
    # Mount the named volume declared below on the Hugging Face cache path
    # (the same target used by the `docker run --mount` examples), so
    # downloaded models persist across container restarts.
    volumes:
      - f5-tts:/root/.cache/huggingface/hub/
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

volumes:
  f5-tts:
    driver: local
```
### 2. CLI Inference
```bash
# Run with flags
# Leaving --ref_text "" will have the ASR model transcribe the reference audio (extra GPU memory usage)
f5-tts_infer-cli --model F5TTS_v1_Base \
--ref_audio "provide_prompt_wav_path_here.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."
# Run with default setting. src/f5_tts/infer/examples/basic/basic.toml
f5-tts_infer-cli
# Or with your own .toml file
f5-tts_infer-cli -c custom.toml
# Multi voice. See src/f5_tts/infer/README.md
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
```
## Training
### 1. With Hugging Face Accelerate
Refer to [training & finetuning guidance](src/f5_tts/train) for best practice.
### 2. With Gradio App
```bash
# Quick start with Gradio web interface
f5-tts_finetune-gradio
```
Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
## [Evaluation](src/f5_tts/eval)
## Development
Use pre-commit to ensure code quality (will run linters and formatters automatically):
```bash
pip install pre-commit
pre-commit install
```
When making a pull request, before each commit, run:
```bash
pre-commit run --all-files
```
Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
## Acknowledgements
- [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
- [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
- [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder
- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
- [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
- [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
- [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
- [Yuekai Zhang](https://github.com/yuekaizhang) Triton and TensorRT-LLM support ~
## Citation
If our work and codebase are useful for you, please cite as:
```
@article{chen-etal-2024-f5tts,
title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
journal={arXiv preprint arXiv:2410.06885},
year={2024},
}
```
## License
Our code is released under MIT License. The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause.
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.1.17"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
]
dependencies = [
"accelerate>=0.33.0",
"bitsandbytes>0.37.0; platform_machine!='arm64' and platform_system!='Darwin'",
"cached_path",
"click",
"datasets",
"ema_pytorch>=0.5.2",
"gradio>=6.0.0",
"hydra-core>=1.3.0",
"librosa",
"matplotlib",
"numpy<=1.26.4; python_version<='3.10'",
"pydub",
"pypinyin",
"rjieba",
"safetensors",
"soundfile",
"tomli",
"torch>=2.0.0",
"torchaudio>=2.0.0",
"torchcodec",
"torchdiffeq",
"tqdm>=4.65.0",
"transformers",
"transformers_stream_generator",
"unidecode",
"vocos",
"wandb",
"x_transformers>=1.31.14",
]
[project.optional-dependencies]
eval = [
"faster_whisper==0.10.1",
"funasr",
"jiwer",
"modelscope",
"zhconv",
"zhon",
]
[project.urls]
Homepage = "https://github.com/SWivid/F5-TTS"
[project.scripts]
"f5-tts_infer-cli" = "f5_tts.infer.infer_cli:main"
"f5-tts_infer-gradio" = "f5_tts.infer.infer_gradio:main"
"f5-tts_finetune-cli" = "f5_tts.train.finetune_cli:main"
"f5-tts_finetune-gradio" = "f5_tts.train.finetune_gradio:main"
================================================
FILE: ruff.toml
================================================
line-length = 120
target-version = "py310"
[lint]
# Only ignore variables with names starting with "_".
dummy-variable-rgx = "^_.*$"
[lint.isort]
force-single-line = false
lines-after-imports = 2
================================================
FILE: src/f5_tts/api.py
================================================
import random
import sys
from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
transcribe,
)
from f5_tts.model.utils import seed_everything
class F5TTS:
    """Convenience wrapper that loads a model and vocoder once and exposes inference helpers.

    The constructor reads ``configs/{model}.yaml`` from the installed ``f5_tts``
    package, resolves the backbone class named in that config, loads the vocoder
    and the model checkpoint, and keeps both on the instance for later calls to
    :meth:`infer`.
    """

    def __init__(
        self,
        model="F5TTS_v1_Base",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        vocoder_local_path=None,
        device=None,
        hf_cache_dir=None,
    ):
        """Load config, vocoder and model weights.

        Args:
            model: Config name; ``configs/{model}.yaml`` inside the package is loaded.
            ckpt_file: Checkpoint path. When empty, a released checkpoint is
                resolved through ``cached_path`` from ``hf://SWivid/...``.
            vocab_file: Forwarded to ``load_model``.
            ode_method: Stored on the instance and forwarded to ``load_model``.
            use_ema: Stored on the instance and forwarded to ``load_model``.
            vocoder_local_path: When not ``None`` the vocoder is loaded as a
                local checkpoint from this path.
            device: Device string. When ``None`` the first available of
                cuda / xpu / mps is used, falling back to cpu.
            hf_cache_dir: Cache directory passed to ``load_vocoder`` and ``cached_path``.
        """
        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate

        self.ode_method = ode_method
        self.use_ema = use_ema

        if device is not None:
            self.device = device
        else:
            import torch

            # ``torch.xpu`` is missing on older torch builds that the package
            # still permits (torch>=2.0.0), so probe for the attribute instead
            # of assuming it exists.
            xpu = getattr(torch, "xpu", None)
            if torch.cuda.is_available():
                self.device = "cuda"
            elif xpu is not None and xpu.is_available():
                self.device = "xpu"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # Load models
        self.vocoder = load_vocoder(
            self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
        )

        repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

        # override for previous models
        if model == "F5TTS_Base":
            if self.mel_spec_type == "vocos":
                ckpt_step = 1200000
            elif self.mel_spec_type == "bigvgan":
                model = "F5TTS_Base_bigvgan"
                ckpt_type = "pt"
        elif model == "E2TTS_Base":
            repo_name = "E2-TTS"
            ckpt_step = 1200000

        if not ckpt_file:
            ckpt_file = str(
                cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
            )
        self.ema_model = load_model(
            model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
        )

    def transcribe(self, ref_audio, language=None):
        """Delegate to ``utils_infer.transcribe`` and return its result."""
        return transcribe(ref_audio, language)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at the configured sample rate.

        When ``remove_silence`` is true, the written file is post-processed by
        ``remove_silence_for_generated_wav``.
        """
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spec, file_spec):
        """Save ``spec`` to ``file_spec`` via ``save_spectrogram``."""
        save_spectrogram(spec, file_spec)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spec=None,
        seed=None,
    ):
        """Generate speech for ``gen_text`` conditioned on a reference clip.

        The reference audio/text pair is first passed through
        ``preprocess_ref_audio_text``; generation itself is done by
        ``infer_process``, to which the sampling arguments are forwarded unchanged.

        Args:
            ref_file: Reference audio, as accepted by ``preprocess_ref_audio_text``.
            ref_text: Text for the reference audio.
            gen_text: Text to synthesize.
            show_info: Callable used for status output.
            progress: Progress helper forwarded to ``infer_process``.
            remove_silence: Forwarded to :meth:`export_wav` when ``file_wave`` is set.
            file_wave: If not ``None``, path the waveform is written to.
            file_spec: If not ``None``, path the spectrogram is written to.
            seed: Seed passed to ``seed_everything``; a random one is drawn when
                ``None``. The seed actually used is kept in ``self.seed``.

        Returns:
            The ``(wav, sr, spec)`` tuple produced by ``infer_process``.
        """
        if seed is None:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, show_info=show_info)

        wav, sr, spec = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spec is not None:
            self.export_spectrogram(spec, file_spec)

        return wav, sr, spec
if __name__ == "__main__":
    # Smoke-test demo: synthesize one sentence with the default model and
    # write the waveform and spectrogram next to the repository's tests.
    package_root = files("f5_tts")
    reference_audio = str(package_root.joinpath("infer/examples/basic/basic_ref_en.wav"))
    wave_target = str(package_root.joinpath("../../tests/api_out.wav"))
    spec_target = str(package_root.joinpath("../../tests/api_out.png"))

    f5tts = F5TTS()
    wav, sr, spec = f5tts.infer(
        ref_file=reference_audio,
        ref_text="Some call me nature, others call me mother nature.",
        gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
        file_wave=wave_target,
        file_spec=spec_target,
        seed=None,
    )

    # Report the seed that was drawn so the run can be reproduced.
    print("seed :", f5tts.seed)
================================================
FILE: src/f5_tts/configs/E2TTS_Base.yaml
================================================
# Hydra config for the E2TTS_Base model (UNetT backbone).
hydra:
run:
# per-run output directory: built from the model/dataset fields below plus a date/time stamp
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch when using frame-wise batch_size; we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # whether to use the bnb 8-bit AdamW optimizer
model:
name: E2TTS_Base
tokenizer: pinyin
tokenizer_path: null # for the 'custom' tokenizer, path of the vocab file to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 1024
depth: 24
heads: 16
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # whether to use a local offline checkpoint
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer a random sample at each checkpoint save; WIP, may fail on extra-long samples
save_per_updates: 50000 # save a checkpoint every N updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save the "last" checkpoint every N updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
================================================
FILE: src/f5_tts/configs/E2TTS_Small.yaml
================================================
# Hydra config for the E2TTS_Small model (UNetT backbone).
hydra:
run:
# per-run output directory: built from the model/dataset fields below plus a date/time stamp
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
# NOTE(review): this is a Small config but max_samples is 64, while the comment says 32 for small models — confirm which is intended
max_samples: 64 # max sequences per batch when using frame-wise batch_size; we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # whether to use the bnb 8-bit AdamW optimizer
model:
name: E2TTS_Small
tokenizer: pinyin
tokenizer_path: null # for the 'custom' tokenizer, path of the vocab file to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 768
depth: 20
heads: 12
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # whether to use a local offline checkpoint
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer a random sample at each checkpoint save; WIP, may fail on extra-long samples
save_per_updates: 50000 # save a checkpoint every N updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save the "last" checkpoint every N updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
================================================
FILE: src/f5_tts/configs/F5TTS_Base.yaml
================================================
# Hydra config for the F5TTS_Base model (DiT backbone).
hydra:
run:
# per-run output directory: built from the model/dataset fields below plus a date/time stamp
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch when using frame-wise batch_size; we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # whether to use the bnb 8-bit AdamW optimizer
model:
name: F5TTS_Base # model name
tokenizer: pinyin # tokenizer type
tokenizer_path: null # for the 'custom' tokenizer, path of the vocab file to use (should be vocab.txt)
backbone: DiT
arch:
dim: 1024
depth: 22
heads: 16
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations to save memory at the cost of extra compute
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # whether to use a local offline checkpoint
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer a random sample at each checkpoint save; WIP, may fail on extra-long samples
save_per_updates: 50000 # save a checkpoint every N updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save the "last" checkpoint every N updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
================================================
FILE: src/f5_tts/configs/F5TTS_Small.yaml
================================================
# Hydra config for the F5TTS_Small model (DiT backbone).
hydra:
run:
# per-run output directory: built from the model/dataset fields below plus a date/time stamp
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
# NOTE(review): this is a Small config but max_samples is 64, while the comment says 32 for small models — confirm which is intended
max_samples: 64 # max sequences per batch when using frame-wise batch_size; we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11 # only suitable for Emilia; to train on LibriTTS, set epochs to 686
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # whether to use the bnb 8-bit AdamW optimizer
model:
name: F5TTS_Small
tokenizer: pinyin
tokenizer_path: null # for the 'custom' tokenizer, path of the vocab file to use (should be vocab.txt)
backbone: DiT
arch:
dim: 768
depth: 18
heads: 12
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations to save memory at the cost of extra compute
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # whether to use a local offline checkpoint
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer a random sample at each checkpoint save; WIP, may fail on extra-long samples
save_per_updates: 50000 # save a checkpoint every N updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save the "last" checkpoint every N updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
================================================
FILE: src/f5_tts/configs/F5TTS_v1_Base.yaml
================================================
# Hydra config for the F5TTS_v1_Base model (DiT backbone).
hydra:
run:
# per-run output directory: built from the model/dataset fields below plus a date/time stamp
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch when using frame-wise batch_size; we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # whether to use the bnb 8-bit AdamW optimizer
model:
name: F5TTS_v1_Base # model name
tokenizer: pinyin # tokenizer type
tokenizer_path: null # for the 'custom' tokenizer, path of the vocab file to use (should be vocab.txt)
backbone: DiT
arch:
dim: 1024
depth: 22
heads: 16
ff_mult: 2
text_dim: 512
text_mask_padding: True
qk_norm: null # null | rms_norm
conv_layers: 4
pe_attn_head: null
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations to save memory at the cost of extra compute
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # whether to use a local offline checkpoint
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer a random sample at each checkpoint save; WIP, may fail on extra-long samples
save_per_updates: 50000 # save a checkpoint every N updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save the "last" checkpoint every N updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
================================================
FILE: src/f5_tts/eval/README.md
================================================
# Evaluation
Install packages for evaluation:
```bash
pip install -e .[eval]
```
> [!IMPORTANT]
> For [faster-whisper](https://github.com/SYSTRAN/faster-whisper), for various compatibilities:
> `pip install ctranslate2==4.5.0` if CUDA 12 and cuDNN 9;
> `pip install ctranslate2==4.4.0` if CUDA 12 and cuDNN 8;
> `pip install ctranslate2==3.24.0` if CUDA 11 and cuDNN 8.
## Generating Samples for Evaluation
### Prepare Test Datasets
1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
3. Unzip the downloaded datasets and place them in the `data/` directory.
4. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
### Batch Inference for Test Set
To run batch inference for evaluations, execute the following commands:
```bash
# if the accelerate config has not been set up yet
accelerate config
# to perform inference only
bash src/f5_tts/eval/eval_infer_batch.sh --infer-only
# to run inference followed by the corresponding evaluation, set up the evaluation tools described below first
bash src/f5_tts/eval/eval_infer_batch.sh
```
## Objective Evaluation on Generated Results
### Download Evaluation Model Checkpoints
1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
> [!NOTE]
> ASR model will be automatically downloaded if `--local` not set for evaluation scripts.
> Otherwise, you should update the `asr_ckpt_dir` path values in `eval_librispeech_test_clean.py` or `eval_seedtts_testset.py`.
>
> WavLM model must be downloaded and your `wavlm_ckpt_dir` path updated in `eval_librispeech_test_clean.py` and `eval_seedtts_testset.py`.
### Objective Evaluation Examples
Update the paths to point at your batch-inference results, then run the WER / SIM / UTMOS evaluations:
```bash
# Evaluation [WER] for Seed-TTS test [ZH] set
python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8
# Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
# Evaluation [UTMOS]. --ext: Audio extension
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
```
> [!NOTE]
> Evaluation results can also be found in `_*_results.jsonl` files saved in `<GEN_WAV_DIR>`/`<WAV_DIR>`.
================================================
FILE: src/f5_tts/eval/ecapa_tdnn.py
================================================
# just for speaker similarity evaluation, third-party code
# From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
""" Res2Conv1d + BatchNorm1d + ReLU
"""
class Res2Conv1dReluBn(nn.Module):
    """
    Multi-branch Conv1d -> ReLU -> BatchNorm1d block.

    in_channels == out_channels == channels
    """

    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        # With scale > 1 the last channel group gets no conv and is passed through in forward().
        self.nums = scale if scale == 1 else scale - 1

        self.convs = nn.ModuleList(
            [
                nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias)
                for _ in range(self.nums)
            ]
        )
        self.bns = nn.ModuleList([nn.BatchNorm1d(self.width) for _ in range(self.nums)])

    def forward(self, x):
        # Split the channel axis into groups of `width` channels.
        groups = torch.split(x, self.width, 1)

        branch_outputs = []
        carried = None
        for idx, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            # Every branch after the first also receives the previous branch's output.
            carried = groups[idx] if carried is None else carried + groups[idx]
            # Order: conv -> relu -> bn
            carried = bn(F.relu(conv(carried)))
            branch_outputs.append(carried)

        if self.scale != 1:
            branch_outputs.append(groups[self.nums])

        return torch.cat(branch_outputs, dim=1)
""" Conv1d + BatchNorm1d + ReLU
"""
class Conv1dReluBn(nn.Module):
    """Conv1d followed by ReLU and then BatchNorm1d (in that order)."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        return self.bn(activated)
""" The SE connection of 1D case.
"""
class SE_Connect(nn.Module):
    """Channel gating: scales each channel of the input by a value in (0, 1) derived from its mean over dim 2."""

    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        # Average over the last axis, then bottleneck -> ReLU -> expand -> sigmoid.
        pooled = x.mean(dim=2)
        gate = torch.sigmoid(self.linear2(F.relu(self.linear1(pooled))))
        # Broadcast the per-channel gate back over the last axis.
        return x * gate.unsqueeze(2)
""" SE-Res2Block of the ECAPA-TDNN architecture.
"""
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
# return nn.Sequential(
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
# SE_Connect(channels)
# )
class SE_Res2Block(nn.Module):
    """Residual block: Conv1dReluBn -> Res2Conv1dReluBn -> Conv1dReluBn -> SE_Connect, plus a skip connection."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
        self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        # A 1x1 conv projects the skip path only when the channel counts differ.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x):
        skip = self.shortcut(x) if self.shortcut else x

        hidden = self.Conv1dReluBn1(x)
        hidden = self.Res2Conv1dReluBn(hidden)
        hidden = self.Conv1dReluBn2(hidden)
        hidden = self.SE_Connect(hidden)

        return hidden + skip
""" Attentive weighted mean and standard deviation pooling.
"""
class AttentiveStatsPool(nn.Module):
    """Attention-weighted mean and standard deviation over the last axis, concatenated along dim 1."""

    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        # With global context, the input is concatenated with its mean and std, tripling the channels.
        attention_in_dim = in_dim * 3 if global_context_att else in_dim
        self.linear1 = nn.Conv1d(attention_in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        attention_input = x
        if self.global_context_att:
            global_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            global_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            attention_input = torch.cat((x, global_mean, global_std), dim=1)

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        scores = self.linear2(torch.tanh(self.linear1(attention_input)))
        weights = torch.softmax(scores, dim=2)

        weighted_mean = torch.sum(weights * x, dim=2)
        variance = torch.sum(weights * (x**2), dim=2) - weighted_mean**2
        # Clamp before the square root so tiny negative values from rounding do not produce NaN.
        weighted_std = torch.sqrt(variance.clamp(min=1e-9))
        return torch.cat([weighted_mean, weighted_std], dim=1)
class ECAPA_TDNN(nn.Module):
    """ECAPA-TDNN embedding network stacked on a feature extractor loaded through ``torch.hub``.

    The extractor's selected features are combined with learned softmax weights
    (``feature_weight``), instance-normalised, run through one ``Conv1dReluBn``
    and three ``SE_Res2Block`` layers whose outputs are concatenated, then
    pooled with ``AttentiveStatsPool`` and projected to ``emb_dim``.
    """

    def __init__(
        self,
        feat_dim=80,
        channels=512,
        emb_dim=192,
        global_context_att=False,
        feat_type="wavlm_large",
        sr=16000,
        feature_selection="hidden_states",
        update_extract=False,
        config_path=None,
    ):
        """Build the extractor and the TDNN layers.

        Args:
            feat_dim: Channel count fed to ``instance_norm`` and ``layer1``.
            channels: Width of the first four layers.
            emb_dim: Size of the output embedding.
            global_context_att: Forwarded to ``AttentiveStatsPool``.
            feat_type: Entry point name passed to ``torch.hub.load``.
            sr: Length of the random probe waveform used in ``get_feat_num``.
            feature_selection: Key used to index the extractor output.
            update_extract: When false, all extractor parameters are frozen and
                extraction runs under ``torch.no_grad``.
            config_path: Forwarded to ``torch.hub.load`` for the local load only.
        """
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        # Replace torch.hub's fork validation with a stub that always passes.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        try:
            local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
            self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
        except Exception:
            # Local copy unusable: fall back to loading "s3prl/s3prl" via torch.hub.
            # Catching Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate instead of silently triggering the fallback.
            self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)

        # For 24-layer encoders, switch off fp32_attention on layers 23 and 11 when the attribute exists.
        for layer_idx in (23, 11):
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
                self.feature_extract.model.encoder.layers[layer_idx].self_attn, "fp32_attention"
            ):
                self.feature_extract.model.encoder.layers[layer_idx].self_attn.fp32_attention = False

        self.feat_num = self.get_feat_num()
        # One learnable weight per selected feature; softmax-normalised in get_feat().
        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != "fbank" and feat_type != "mfcc":
            # Parameters whose names contain any of these substrings are always frozen.
            freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(
            self.channels[0],
            self.channels[1],
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer3 = SE_Res2Block(
            self.channels[1],
            self.channels[2],
            kernel_size=3,
            stride=1,
            padding=3,
            dilation=3,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer4 = SE_Res2Block(
            self.channels[2],
            self.channels[3],
            kernel_size=3,
            stride=1,
            padding=4,
            dilation=4,
            scale=8,
            se_bottleneck_dim=128,
        )

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        # forward() concatenates the outputs of layer2..layer4, hence channels * 3 inputs.
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(
            self.channels[-1], attention_channels=128, global_context_att=global_context_att
        )
        # The pooling layer returns mean and std concatenated, so the width doubles.
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        """Return how many features the extractor yields under ``feature_selection``.

        Runs the extractor once on a random waveform of ``self.sr`` samples and
        returns the length of the selected entry if it is a list/tuple, else 1.
        """
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        """Extract, weight and normalise features for a batch ``x``.

        Gradients flow through the extractor only when ``update_extract`` is set.
        """
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == "fbank" or self.feat_type == "mfcc":
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == "fbank":
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            # Softmax over the learned per-feature weights, broadcast across the three trailing axes.
            norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        """Return the embedding produced for input batch ``x``."""
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        # Aggregate the three SE-Res2Block outputs along the channel axis.
        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out
def ECAPA_TDNN_SMALL(
    feat_dim,
    emb_dim=256,
    feat_type="wavlm_large",
    sr=16000,
    feature_selection="hidden_states",
    update_extract=False,
    config_path=None,
):
    """Build an ``ECAPA_TDNN`` with ``channels`` fixed to 512; all other arguments are passed through."""
    model_kwargs = dict(
        feat_dim=feat_dim,
        channels=512,
        emb_dim=emb_dim,
        feat_type=feat_type,
        sr=sr,
        feature_selection=feature_selection,
        update_extract=update_extract,
        config_path=config_path,
    )
    return ECAPA_TDNN(**model_kwargs)
================================================
FILE: src/f5_tts/eval/eval_infer_batch.py
================================================
import os
import sys
sys.path.append(os.getcwd())
import argparse
import time
from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
from hydra.utils import get_class
from omegaconf import OmegaConf
from tqdm import tqdm
from f5_tts.eval.utils_eval import (
get_inference_prompt,
get_librispeech_test_clean_metainfo,
get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
# Module-level state shared by main(); created at import time.
accelerator = Accelerator()
# Each process addresses the CUDA device whose index equals its accelerate process index.
device = f"cuda:{accelerator.process_index}"
use_ema = True
target_rms = 0.1
# Two directories above the installed f5_tts package; used as the base for data/, ckpts/ and results/ paths.
rel_path = str(files("f5_tts").joinpath("../../"))
def main():
    """Run batch inference over a test set and save one wav per utterance.

    Parses CLI arguments, loads ``configs/{expname}.yaml``, builds the prompts,
    vocoder and CFM model, loads the checkpoint, then splits the prompts across
    accelerate processes and writes the generated audio under ``results/``.

    Raises:
        ValueError: If ``--testset`` is not one of the supported names, if the
            configured ``mel_spec_type`` has no known local vocoder path, or if
            no checkpoint file can be found.
    """
    parser = argparse.ArgumentParser(description="batch inference")

    parser.add_argument("-s", "--seed", default=None, type=int)
    parser.add_argument("-n", "--expname", required=True)
    parser.add_argument("-c", "--ckptstep", default=1250000, type=int)

    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
    parser.add_argument("-o", "--odemethod", default="euler")
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

    parser.add_argument("-t", "--testset", required=True)
    parser.add_argument(
        "-p", "--librispeech_test_clean_path", default=f"{rel_path}/data/LibriSpeech/test-clean", type=str
    )

    parser.add_argument("--local", action="store_true", help="Use local vocoder checkpoint directory")

    args = parser.parse_args()

    seed = args.seed
    exp_name = args.expname
    ckpt_step = args.ckptstep

    nfe_step = args.nfestep
    ode_method = args.odemethod
    sway_sampling_coef = args.swaysampling

    testset = args.testset

    infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
    cfg_strength = 2.0
    speed = 1.0
    use_truth_duration = False
    no_ref_audio = False

    model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch

    dataset_name = model_cfg.datasets.name
    tokenizer = model_cfg.model.tokenizer

    mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
    target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
    n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
    hop_length = model_cfg.model.mel_spec.hop_length
    win_length = model_cfg.model.mel_spec.win_length
    n_fft = model_cfg.model.mel_spec.n_fft

    if testset == "ls_pc_test_clean":
        metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
        librispeech_test_clean_path = args.librispeech_test_clean_path
        metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)

    elif testset == "seedtts_test_zh":
        metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    elif testset == "seedtts_test_en":
        metalst = rel_path + "/data/seedtts_testset/en/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    else:
        # Fail here with a clear message instead of a NameError on `metainfo` further down.
        raise ValueError(
            f"Unknown testset '{testset}'. Expected one of: ls_pc_test_clean, seedtts_test_zh, seedtts_test_en."
        )

    # path to save generated wavs
    output_dir = (
        f"{rel_path}/"
        f"results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
        f"_cfg{cfg_strength}_speed{speed}"
        f"{'_gt-dur' if use_truth_duration else ''}"
        f"{'_no-ref-audio' if no_ref_audio else ''}"
    )

    # -------------------------------------------------#

    prompts_all = get_inference_prompt(
        metainfo,
        speed=speed,
        tokenizer=tokenizer,
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
        mel_spec_type=mel_spec_type,
        target_rms=target_rms,
        use_truth_duration=use_truth_duration,
        infer_batch_size=infer_batch_size,
    )

    # Vocoder model
    local = args.local
    if mel_spec_type == "vocos":
        vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    elif mel_spec_type == "bigvgan":
        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
    else:
        # Without this branch `vocoder_local_path` would be unbound below.
        raise ValueError(f"Unsupported mel_spec_type '{mel_spec_type}'. Expected 'vocos' or 'bigvgan'.")
    vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)

    # Tokenizer
    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

    # Model
    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # Look under ckpts/{exp_name}/ first, then under the config's save_dir; prefer .pt over .safetensors.
    ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
    if os.path.exists(ckpt_prefix + ".pt"):
        ckpt_path = ckpt_prefix + ".pt"
    elif os.path.exists(ckpt_prefix + ".safetensors"):
        ckpt_path = ckpt_prefix + ".safetensors"
    else:
        print("Loading from self-organized training checkpoints rather than released pretrained.")
        ckpt_prefix = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}"
        if os.path.exists(ckpt_prefix + ".pt"):
            ckpt_path = ckpt_prefix + ".pt"
        elif os.path.exists(ckpt_prefix + ".safetensors"):
            ckpt_path = ckpt_prefix + ".safetensors"
        else:
            raise ValueError("The checkpoint does not exist or cannot be found in given location.")

    # bigvgan checkpoints are loaded in float32; otherwise load_checkpoint picks the dtype.
    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

    # Only the main process creates the output directory; the others wait at the barrier below.
    if not os.path.exists(output_dir) and accelerator.is_main_process:
        os.makedirs(output_dir)

    # start batch inference
    accelerator.wait_for_everyone()
    start = time.time()

    with accelerator.split_between_processes(prompts_all) as prompts:
        for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
            utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
            ref_mels = ref_mels.to(device)
            ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
            total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

            # Inference
            with torch.inference_mode():
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=final_text_list,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                    no_ref_audio=no_ref_audio,
                    seed=seed,
                )
                # Final result
                for i, gen in enumerate(generated):
                    # Keep only the frames after the reference prompt, up to the total length.
                    gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                    gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
                    if mel_spec_type == "vocos":
                        generated_wave = vocoder.decode(gen_mel_spec).cpu()
                    elif mel_spec_type == "bigvgan":
                        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

                    # Scale the output back down when the reference was quieter than target_rms.
                    if ref_rms_list[i] < target_rms:
                        generated_wave = generated_wave * ref_rms_list[i] / target_rms
                    torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        timediff = time.time() - start
        print(f"Done batch inference in {timediff / 60:.2f} minutes.")
# Script entry point: run batch inference when executed directly.
if __name__ == "__main__":
    main()
================================================
FILE: src/f5_tts/eval/eval_infer_batch.sh
================================================
#!/bin/bash
# Batch inference / evaluation driver: configuration and argument parsing.
set -e
export PYTHONWARNINGS="ignore::UserWarning,ignore::FutureWarning"

# Configuration parameters
MODEL_NAME="F5TTS_v1_Base"
SEEDS=(0 1 2)
CKPTSTEPS=(1250000)
TASKS=("seedtts_test_zh" "seedtts_test_en" "ls_pc_test_clean")
LS_TEST_CLEAN_PATH="data/LibriSpeech/test-clean"
GPUS="[0,1,2,3,4,5,6,7]"
OFFLINE_MODE=false

# Parse arguments
# OFFLINE_MODE=true adds --local to the python commands that take $LOCAL.
# The variable is quoted so the test stays well-formed even if it is ever
# empty, matching the quoting used for INFER_ONLY below.
if [ "$OFFLINE_MODE" = true ]; then
    LOCAL="--local"
else
    LOCAL=""
fi
INFER_ONLY=false
# The only accepted flag is --infer-only; anything else aborts with exit code 1.
while [[ $# -gt 0 ]]; do
    case "$1" in
        --infer-only)
            INFER_ONLY=true
            shift
            ;;
        *)
            echo "======== Unknown parameter: $1"
            exit 1
            ;;
    esac
done

echo "======== Starting F5-TTS batch evaluation task..."
if [ "$INFER_ONLY" = true ]; then
    echo "======== Mode: Execute infer tasks only"
else
    echo "======== Mode: Execute full pipeline (infer + eval)"
fi
# Function: Execute eval tasks
execute_eval_tasks() {
    # Evaluate one generated-audio directory with WER, SIM and UTMOS.
    #   $1 - checkpoint step
    #   $2 - seed
    #   $3 - task name (seedtts_test_zh | seedtts_test_en | ls_pc_test_clean)
    # Reads the globals MODEL_NAME, GPUS, LS_TEST_CLEAN_PATH and LOCAL.
    # A task name outside the three above only prints the start/completed banners.
    local ckptstep=$1
    local seed=$2
    local task_name=$3
    local gen_wav_dir="results/${MODEL_NAME}_${ckptstep}/${task_name}/seed${seed}_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0"
    local -a metric_cmd=()   # interpreter + script computing WER / SIM for this task
    local -a metric_args=()  # options that follow "-e <metric>" on the command line
    local metric

    echo ">>>>>>>> Starting eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"

    if [[ $task_name == seedtts_test_zh || $task_name == seedtts_test_en ]]; then
        # The language code is the suffix after the last underscore (zh / en)
        metric_cmd=(python src/f5_tts/eval/eval_seedtts_testset.py)
        metric_args=(-l "${task_name##*_}" -g "$gen_wav_dir" -n "$GPUS")
    elif [[ $task_name == ls_pc_test_clean ]]; then
        metric_cmd=(python src/f5_tts/eval/eval_librispeech_test_clean.py)
        metric_args=(-g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH")
    fi

    if (( ${#metric_cmd[@]} > 0 )); then
        for metric in wer sim; do
            # $LOCAL is intentionally unquoted so that an empty value adds no argument
            "${metric_cmd[@]}" -e "$metric" "${metric_args[@]}" $LOCAL
        done
        python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
    fi

    echo ">>>>>>>> Completed eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
}
# Main execution loop
# For each checkpoint step and seed: run the infer tasks one after another in the
# foreground and, unless --infer-only was given, start the matching eval task in the
# background as soon as its infer task has finished.
for ckptstep in "${CKPTSTEPS[@]}"; do
    echo "======== Processing ckptstep: ${ckptstep}"

    for seed in "${SEEDS[@]}"; do
        echo "-------- Processing seed: ${seed}"

        # PIDs of the background eval jobs of the current seed (stays empty in infer-only mode)
        eval_pids=()

        # Execute each infer task sequentially
        for task in "${TASKS[@]}"; do
            # Build the command once so that the logged line is exactly what gets executed
            # ($LOCAL is intentionally unquoted: an empty value adds no argument)
            infer_cmd=(accelerate launch src/f5_tts/eval/eval_infer_batch.py -s "${seed}" -n "${MODEL_NAME}" -t "${task}" -c "${ckptstep}" -p "${LS_TEST_CLEAN_PATH}" $LOCAL)
            echo ">>>>>>>> Executing infer task: ${infer_cmd[*]}"

            # Execute infer task (foreground execution, wait for completion)
            "${infer_cmd[@]}"

            # If not infer-only mode, launch corresponding eval task
            if [ "$INFER_ONLY" = false ]; then
                # Launch corresponding eval task (background execution, non-blocking for next infer)
                execute_eval_tasks "$ckptstep" "$seed" "$task" &
                eval_pids+=("$!")
            fi
        done

        # If not infer-only mode, wait for all eval tasks of current seed to complete
        if [ "$INFER_ONLY" = false ]; then
            echo ">>>>>>>> All infer tasks for seed ${seed} completed, waiting for corresponding eval tasks to finish..."
            for pid in "${eval_pids[@]}"; do
                # With `set -e`, a failed eval job makes `wait` return non-zero and aborts the script
                wait "$pid"
            done
            echo "-------- All eval tasks for seed ${seed} completed"
        else
            echo "-------- All infer tasks for seed ${seed} completed"
        fi
    done

    echo "======== Completed ckptstep: ${ckptstep}"
    echo
done

echo "======== All tasks completed!"
================================================
FILE: src/f5_tts/eval/eval_infer_batch_example.sh
================================================
#!/bin/bash
# Example invocations of the batch inference and evaluation scripts.
# Flags used below: -s seed, -n model name, -t task, -c checkpoint step,
# -p LibriSpeech test-clean directory (given for the ls_pc_test_clean task).

# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 -p data/LibriSpeech/test-clean

# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 -p data/LibriSpeech/test-clean

# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
# (WER, then speaker similarity, then UTMOS, all on the same generated-audio directory)
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0

# etc.
================================================
FILE: src/f5_tts/eval/eval_librispeech_test_clean.py
================================================
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
import argparse
import ast
import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))
def get_args():
    """Parse the command-line options of the LibriSpeech test-clean evaluation script.

    Returns:
        argparse.Namespace with ``eval_task``, ``lang``, ``gen_wav_dir``,
        ``librispeech_test_clean_path``, ``gpu_nums`` (still a string) and ``local``.
    """
    # (flags, keyword options) for every argument, registered in this order
    option_table = (
        (("-e", "--eval_task"), {"type": str, "default": "wer", "choices": ["sim", "wer"]}),
        (("-l", "--lang"), {"type": str, "default": "en"}),
        (("-g", "--gen_wav_dir"), {"type": str, "required": True}),
        (("-p", "--librispeech_test_clean_path"), {"type": str, "required": True}),
        (
            ("-n", "--gpu_nums"),
            {
                "type": str,
                "default": "8",
                "help": "Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])",
            },
        ),
        (("--local",), {"action": "store_true", "help": "Use local custom checkpoint directory"}),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def parse_gpu_nums(gpu_nums_str):
    """Convert the ``--gpu_nums`` option into an explicit list of GPU indices.

    Args:
        gpu_nums_str: Either a GPU count such as ``"8"`` (expanded to ``[0, ..., 7]``)
            or a bracketed list such as ``"[0,1,2,3]"`` (returned as given).

    Returns:
        A non-empty list of non-negative integers.

    Raises:
        argparse.ArgumentTypeError: If the string is neither form, or if it would yield
            an empty list or entries that are not non-negative integers.
    """
    message = f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
    spec = gpu_nums_str.strip()
    try:
        if spec.startswith("[") and spec.endswith("]"):
            gpu_list = ast.literal_eval(spec)
        else:
            gpu_list = list(range(int(spec)))
    except (ValueError, SyntaxError, TypeError):
        raise argparse.ArgumentTypeError(message) from None

    # An empty list would later request a pool with zero workers, and non-integer
    # entries cannot be used as device indices, so both are rejected up front.
    is_valid = (
        isinstance(gpu_list, list)
        and len(gpu_list) > 0
        and all(isinstance(g, int) and not isinstance(g, bool) and g >= 0 for g in gpu_list)
    )
    if not is_valid:
        raise argparse.ArgumentTypeError(message)
    return gpu_list
def main():
    """Run the WER or SIM evaluation on generated LibriSpeech-PC test-clean audio.

    The test set is split across the GPUs given by ``--gpu_nums``, one worker process
    handles each split, and every per-utterance result is written as a JSON line to
    ``<gen_wav_dir>/_<eval_task>_results.jsonl``. The mean metric is appended as a final
    plain-text line (so the last lines of the file are not JSON) and printed.

    Raises:
        ValueError: If ``eval_task`` is neither ``"wer"`` nor ``"sim"``.
    """
    args = get_args()
    eval_task = args.eval_task
    lang = args.lang
    librispeech_test_clean_path = args.librispeech_test_clean_path  # test-clean path
    gen_wav_dir = args.gen_wav_dir
    metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"

    gpus = parse_gpu_nums(args.gpu_nums)
    test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)

    ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
    ## leading to a low similarity for the ground truth in some cases.
    # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth

    if args.local:  # use local custom checkpoint dir
        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
    else:
        asr_ckpt_dir = ""  # auto download to cache dir
    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

    # --------------------------------------------------------------------------

    # Pick the worker and build one argument tuple per (GPU, test-set split)
    if eval_task == "wer":
        worker = run_asr_wer
        worker_args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
    elif eval_task == "sim":
        worker = run_sim
        worker_args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
    else:
        raise ValueError(f"Unknown metric type: {eval_task}")

    full_results = []
    with mp.Pool(processes=len(gpus)) as pool:
        for split_results in pool.map(worker, worker_args):
            full_results.extend(split_results)

    metrics = []
    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
    # Explicit UTF-8: lines are dumped with ensure_ascii=False, so transcripts may contain
    # characters that the platform default encoding cannot represent.
    with open(result_path, "w", encoding="utf-8") as f:
        for line in full_results:
            metrics.append(line[eval_task])
            f.write(json.dumps(line, ensure_ascii=False) + "\n")
        metric = round(np.mean(metrics), 5)
        f.write(f"\n{eval_task.upper()}: {metric}\n")

    print(f"\nTotal {len(metrics)} samples")
    print(f"{eval_task.upper()}: {metric}")
    print(f"{eval_task.upper()} results saved to {result_path}")
if __name__ == "__main__":
main()
================================================
FILE: src/f5_tts/eval/eval_seedtts_testset.py
================================================
# Evaluate with Seed-TTS testset
import argparse
import ast
import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))
def get_args():
    """Parse the command-line options of the Seed-TTS test-set evaluation script.

    Returns:
        argparse.Namespace with ``eval_task``, ``lang``, ``gen_wav_dir``,
        ``gpu_nums`` (still a string) and ``local``.
    """
    # (flags, keyword options) for every argument, registered in this order
    option_table = (
        (("-e", "--eval_task"), {"type": str, "default": "wer", "choices": ["sim", "wer"]}),
        (("-l", "--lang"), {"type": str, "default": "en", "choices": ["zh", "en"]}),
        (("-g", "--gen_wav_dir"), {"type": str, "required": True}),
        (
            ("-n", "--gpu_nums"),
            {
                "type": str,
                "default": "8",
                "help": "Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])",
            },
        ),
        (("--local",), {"action": "store_true", "help": "Use local custom checkpoint directory"}),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def parse_gpu_nums(gpu_nums_str):
    """Convert the ``--gpu_nums`` option into an explicit list of GPU indices.

    Args:
        gpu_nums_str: Either a GPU count such as ``"8"`` (expanded to ``[0, ..., 7]``)
            or a bracketed list such as ``"[0,1,2,3]"`` (returned as given).

    Returns:
        A non-empty list of non-negative integers.

    Raises:
        argparse.ArgumentTypeError: If the string is neither form, or if it would yield
            an empty list or entries that are not non-negative integers.
    """
    message = f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
    spec = gpu_nums_str.strip()
    try:
        if spec.startswith("[") and spec.endswith("]"):
            gpu_list = ast.literal_eval(spec)
        else:
            gpu_list = list(range(int(spec)))
    except (ValueError, SyntaxError, TypeError):
        raise argparse.ArgumentTypeError(message) from None

    # An empty list would later request a pool with zero workers, and non-integer
    # entries cannot be used as device indices, so both are rejected up front.
    is_valid = (
        isinstance(gpu_list, list)
        and len(gpu_list) > 0
        and all(isinstance(g, int) and not isinstance(g, bool) and g >= 0 for g in gpu_list)
    )
    if not is_valid:
        raise argparse.ArgumentTypeError(message)
    return gpu_list
def main():
    """Run the WER or SIM evaluation on generated Seed-TTS test-set audio.

    The test set is split across the GPUs given by ``--gpu_nums``, one worker process
    handles each split, and every per-utterance result is written as a JSON line to
    ``<gen_wav_dir>/_<eval_task>_results.jsonl``. The mean metric is appended as a final
    plain-text line (so the last lines of the file are not JSON) and printed.

    Raises:
        ValueError: If ``eval_task`` is neither ``"wer"`` nor ``"sim"``.
    """
    args = get_args()
    eval_task = args.eval_task
    lang = args.lang
    gen_wav_dir = args.gen_wav_dir
    metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset

    # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
    #       zh 1.254 seems a result of 4 workers wer_seed_tts
    gpus = parse_gpu_nums(args.gpu_nums)
    test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)

    if args.local:  # use local custom checkpoint dir
        if lang == "zh":
            asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
        elif lang == "en":
            asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
    else:
        asr_ckpt_dir = ""  # auto download to cache dir
    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

    # --------------------------------------------------------------------------

    # Pick the worker and build one argument tuple per (GPU, test-set split)
    if eval_task == "wer":
        worker = run_asr_wer
        worker_args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
    elif eval_task == "sim":
        worker = run_sim
        worker_args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
    else:
        raise ValueError(f"Unknown metric type: {eval_task}")

    full_results = []
    with mp.Pool(processes=len(gpus)) as pool:
        for split_results in pool.map(worker, worker_args):
            full_results.extend(split_results)

    metrics = []
    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
    # Explicit UTF-8: lines are dumped with ensure_ascii=False, so Chinese transcripts would
    # fail to encode on platforms whose default text encoding is not UTF-8.
    with open(result_path, "w", encoding="utf-8") as f:
        for line in full_results:
            metrics.append(line[eval_task])
            f.write(json.dumps(line, ensure_ascii=False) + "\n")
        metric = round(np.mean(metrics), 5)
        f.write(f"\n{eval_task.upper()}: {metric}\n")

    print(f"\nTotal {len(metrics)} samples")
    print(f"{eval_task.upper()}: {metric}")
    print(f"{eval_task.upper()} results saved to {result_path}")
if __name__ == "__main__":
main()
================================================
FILE: src/f5_tts/eval/eval_utmos.py
================================================
import argparse
import json
from pathlib import Path
import librosa
import torch
from tqdm import tqdm
def main():
    """Score every audio file under ``--audio_dir`` with the UTMOS predictor.

    Files matching ``*.<ext>`` are searched recursively. One JSON line per file
    (``wav`` stem and ``utmos`` score) is written to ``<audio_dir>/_utmos_results.jsonl``,
    followed by a plain-text line with the average score, which is also printed.
    The average is reported as 0 when no audio file is found.
    """
    parser = argparse.ArgumentParser(description="UTMOS Evaluation")
    parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.")
    parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
    args = parser.parse_args()

    # Device preference: CUDA, then XPU, then CPU
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.xpu.is_available():
        device = "xpu"
    else:
        device = "cpu"

    predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device)

    audio_dir = Path(args.audio_dir)
    audio_paths = list(audio_dir.rglob(f"*.{args.ext}"))
    utmos_result_path = audio_dir / "_utmos_results.jsonl"

    score_total = 0
    with open(utmos_result_path, "w", encoding="utf-8") as f:
        for audio_path in tqdm(audio_paths, desc="Processing"):
            # sr=None keeps the file's own sampling rate, which is passed on to the predictor
            wav, sr = librosa.load(audio_path, sr=None, mono=True)
            wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
            utmos = predictor(wav_tensor, sr).item()
            score_total += utmos
            record = {"wav": str(audio_path.stem), "utmos": utmos}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

        num_files = len(audio_paths)
        avg_score = score_total / num_files if num_files > 0 else 0
        f.write(f"\nUTMOS: {avg_score:.4f}\n")

    print(f"UTMOS: {avg_score:.4f}")
    print(f"UTMOS results saved to {utmos_result_path}")
if __name__ == "__main__":
main()
================================================
FILE: src/f5_tts/eval/utils_eval.py
================================================
import math
import os
import random
import string
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_seedtts_testset_metainfo(metalst):
    """Read a Seed-TTS style meta list into ``(utt, prompt_text, prompt_wav, gt_text, gt_wav)`` tuples.

    Each non-blank line holds 4 or 5 ``|``-separated fields. With 4 fields the ground-truth
    wav defaults to ``<meta dir>/wavs/<utt>.wav``. A relative ``prompt_wav`` is resolved
    against the directory of the meta list.

    Args:
        metalst: Path of the meta list file (read as UTF-8).

    Returns:
        List of 5-tuples, one per non-blank line, in file order.

    Raises:
        ValueError: If a non-blank line does not have 4 or 5 fields.
    """
    meta_dir = os.path.dirname(metalst)
    metainfo = []
    with open(metalst, encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no entry
                continue
            fields = line.split("|")
            if len(fields) == 5:
                utt, prompt_text, prompt_wav, gt_text, gt_wav = fields
            elif len(fields) == 4:
                utt, prompt_text, prompt_wav, gt_text = fields
                gt_wav = os.path.join(meta_dir, "wavs", utt + ".wav")
            else:
                # Fail loudly instead of silently reusing the values of the previous line
                raise ValueError(
                    f"{metalst}:{line_no}: expected 4 or 5 '|'-separated fields, got {len(fields)}: {line!r}"
                )
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(meta_dir, prompt_wav)
            metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
    return metainfo
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
    """Read the LibriSpeech cross-sentence meta list into inference tuples.

    Each non-blank line holds six tab-separated fields:
    ``ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt``. Utterance ids have the form
    ``<speaker>-<chapter>-<index>`` and map to
    ``<librispeech_test_clean_path>/<speaker>/<chapter>/<utt>.flac``.

    Args:
        metalst: Path of the meta list file (read as UTF-8).
        librispeech_test_clean_path: Root directory of LibriSpeech test-clean.

    Returns:
        List of ``(gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav)`` tuples in file order;
        the text to generate is prefixed with a single space.

    Raises:
        ValueError: If a non-blank line or an utterance id does not have the expected
            number of fields.
    """
    metainfo = []
    with open(metalst, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no entry
                continue
            ref_utt, _ref_dur, ref_txt, gen_utt, _gen_dur, gen_txt = line.split("\t")

            # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
            ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
            ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

            # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
            gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
            gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")

            metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
    return metainfo
# padded to max length mel batch
def padded_mel_batch(ref_mels):
    """Zero-pad mels to a common length, stack them and swap the last two axes.

    Each tensor in ``ref_mels`` is right-padded along its last axis up to the largest
    last-axis size in the list; the padded tensors are stacked on a new leading axis and
    the result is permuted with ``(0, 2, 1)``.
    (assumes every mel is 2-D with matching first dimension — TODO confirm)
    """
    frame_counts = [mel.shape[-1] for mel in ref_mels]
    longest = torch.LongTensor(frame_counts).amax()
    stacked = torch.stack(
        [F.pad(mel, (0, longest - n_frames), value=0) for mel, n_frames in zip(ref_mels, frame_counts)]
    )
    return stacked.permute(0, 2, 1)
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_inference_prompt(
    metainfo,
    speed=1.0,
    tokenizer="pinyin",
    polyphone=True,
    target_sample_rate=24000,
    n_fft=1024,
    win_length=1024,
    n_mel_channels=100,
    hop_length=256,
    mel_spec_type="vocos",
    target_rms=0.1,
    use_truth_duration=False,
    infer_batch_size=1,
    num_buckets=200,
    min_secs=3,
    max_secs=40,
):
    """Build length-bucketed inference batches from test-set meta information.

    For every entry the prompt audio is loaded, loudness-normalized, resampled and turned
    into a mel spectrogram; the total (prompt + generated) length in mel frames is then
    estimated and the entry is placed into a bucket of similar total length. A bucket is
    emitted as one batch as soon as its accumulated total frames reach ``infer_batch_size``.

    Args:
        metainfo: Iterable of ``(utt, prompt_text, prompt_wav, gt_text, gt_wav)`` tuples.
        speed: Divides the estimated length of the generated part.
        tokenizer: ``"pinyin"`` converts the text with ``convert_char_to_pinyin``; any other
            value keeps the raw text.
        polyphone: Forwarded to ``convert_char_to_pinyin``.
        target_sample_rate, n_fft, win_length, n_mel_channels, hop_length, mel_spec_type:
            Forwarded to ``MelSpec``; ``target_sample_rate`` and ``hop_length`` are also used
            to convert between seconds and mel frames.
        target_rms: Prompts whose RMS is below this value are scaled up to it.
        use_truth_duration: If true, the generated length is taken from the ground-truth wav
            instead of being estimated from the text lengths.
        infer_batch_size: Threshold, in accumulated total mel frames, at which a bucket is
            flushed into a batch. Must be greater than 0.
        num_buckets: Number of length buckets.
        min_secs, max_secs: Allowed range of the total duration in seconds.

    Returns:
        List of batches, each a tuple ``(utts, ref_rms_list, padded_ref_mels, ref_mel_lens,
        total_mel_lens, final_text_list)``, shuffled with a fixed seed.

    Raises:
        AssertionError: If a prompt has 5000 samples or fewer, if ``infer_batch_size`` is not
            positive, or if a total duration falls outside ``[min_secs, max_secs]``.
    """
    prompts_all = []

    # Allowed total length expressed in mel frames
    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length

    # Per-bucket running sum of total mel frames, plus six parallel per-bucket lists
    batch_accum = [0] * num_buckets
    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
        [[] for _ in range(num_buckets)] for _ in range(6)
    )

    mel_spectrogram = MelSpec(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
        # Audio
        ref_audio, ref_sr = torchaudio.load(prompt_wav)
        # RMS is measured before any scaling; the unscaled value is what gets stored per utterance
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        if ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio)

        # Text
        # If the last prompt character is a single byte in UTF-8, add a space before the text to generate
        if len(prompt_text[-1].encode("utf-8")) == 1:
            prompt_text = prompt_text + " "
        text = [prompt_text + gt_text]
        if tokenizer == "pinyin":
            text_list = convert_char_to_pinyin(text, polyphone=polyphone)
        else:
            text_list = text

        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        # Duration, mel frame length
        ref_mel_len = ref_mel.shape[-1]

        if use_truth_duration:
            # Generated length = ground-truth audio length in frames, scaled by speed
            gt_audio, gt_sr = torchaudio.load(gt_wav)
            if gt_sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
                gt_audio = resampler(gt_audio)
            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)

            # # test vocoder resynthesis
            # ref_audio = gt_audio
        else:
            # Generated length estimated from the prompt's frames-per-UTF-8-byte ratio
            ref_text_len = len(prompt_text.encode("utf-8"))
            gen_text_len = len(gt_text.encode("utf-8"))
            total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)

        # deal with batch
        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
        assert min_tokens <= total_mel_len <= max_tokens, (
            f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
        )
        # Linear mapping of the total length onto bucket indices 0 .. num_buckets - 1
        bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)

        utts[bucket_i].append(utt)
        ref_rms_list[bucket_i].append(ref_rms)
        ref_mels[bucket_i].append(ref_mel)
        ref_mel_lens[bucket_i].append(ref_mel_len)
        total_mel_lens[bucket_i].append(total_mel_len)
        final_text_list[bucket_i].extend(text_list)

        batch_accum[bucket_i] += total_mel_len

        # Flush the bucket into a batch once it holds enough frames, then reset it
        if batch_accum[bucket_i] >= infer_batch_size:
            # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
            batch_accum[bucket_i] = 0
            (
                utts[bucket_i],
                ref_rms_list[bucket_i],
                ref_mels[bucket_i],
                ref_mel_lens[bucket_i],
                total_mel_lens[bucket_i],
                final_text_list[bucket_i],
            ) = [], [], [], [], [], []

    # add residual
    # Buckets that never reached the threshold are emitted as (smaller) batches
    for bucket_i, bucket_frames in enumerate(batch_accum):
        if bucket_frames > 0:
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
    # not only leave easy work for last workers
    # NOTE: this reseeds the global `random` module state with a fixed value
    random.seed(666)
    random.shuffle(prompts_all)

    return prompts_all
# get wav_res_ref_text of seed-tts test metalst
# https://github.com/BytedanceSpeech/seed-tts-eval
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
    """Collect the Seed-TTS evaluation triples and split them across GPUs.

    Each non-blank meta line holds 4 or 5 ``|``-separated fields
    (``utt|prompt_text|prompt_wav|gt_text[|gt_wav]``). Entries whose generated file
    ``<gen_wav_dir>/<utt>.wav`` does not exist are skipped. A relative ``prompt_wav`` is
    resolved against the directory of the meta list.

    Args:
        metalst: Path of the meta list file (read as UTF-8).
        gen_wav_dir: Directory holding the generated ``<utt>.wav`` files.
        gpus: List of GPU ids; one split is produced per id.

    Returns:
        List of ``(gpu_id, [(gen_wav, prompt_wav, gt_text), ...])`` pairs, one per GPU.
        Later splits may be shorter than earlier ones, or empty.

    Raises:
        ValueError: If a non-blank line does not have 4 or 5 fields.
    """
    meta_dir = os.path.dirname(metalst)
    with open(metalst, encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) carry no entry
        lines = [(line_no, line.strip()) for line_no, line in enumerate(f, start=1) if line.strip()]

    test_set_ = []
    for line_no, line in tqdm(lines):
        fields = line.split("|")
        if len(fields) not in (4, 5):
            # Fail loudly instead of silently reusing the values of the previous line
            raise ValueError(
                f"{metalst}:{line_no}: expected 4 or 5 '|'-separated fields, got {len(fields)}: {line!r}"
            )
        utt, _prompt_text, prompt_wav, gt_text = fields[:4]

        gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
        if not os.path.exists(gen_wav):
            continue
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(meta_dir, prompt_wav)

        test_set_.append((gen_wav, prompt_wav, gt_text))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    # Contiguous chunks of equal size; the last chunks absorb the remainder
    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set
# get librispeech test-clean cross sentence test
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
    """Collect the LibriSpeech cross-sentence evaluation triples and split them across GPUs.

    Each non-blank meta line holds six tab-separated fields:
    ``ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt``. Utterance ids have the form
    ``<speaker>-<chapter>-<index>`` and map to
    ``<librispeech_test_clean_path>/<speaker>/<chapter>/<utt>.flac``.

    Args:
        metalst: Path of the meta list file (read as UTF-8).
        gen_wav_dir: Directory holding the generated ``<gen_utt>.wav`` files.
        gpus: List of GPU ids; one split is produced per id.
        librispeech_test_clean_path: Root directory of LibriSpeech test-clean.
        eval_ground_truth: If true, evaluate the original ``gen_utt`` recordings instead of
            the generated files.

    Returns:
        List of ``(gpu_id, [(gen_wav, ref_wav, gen_txt), ...])`` pairs, one per GPU.
        Later splits may be shorter than earlier ones, or empty.

    Raises:
        FileNotFoundError: If a generated wav is missing (only when ``eval_ground_truth`` is false).
        ValueError: If a non-blank line or an utterance id does not have the expected
            number of fields.
    """
    with open(metalst, encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) carry no entry
        lines = [line.strip() for line in f if line.strip()]

    test_set_ = []
    for line in tqdm(lines):
        ref_utt, _ref_dur, _ref_txt, gen_utt, _gen_dur, gen_txt = line.split("\t")

        if eval_ground_truth:
            gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
            gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
        else:
            gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
            if not os.path.exists(gen_wav):
                raise FileNotFoundError(f"Generated wav not found: {gen_utt}")

        ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

        test_set_.append((gen_wav, ref_wav, gen_txt))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    # Contiguous chunks of equal size; the last chunks absorb the remainder
    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set
# load asr model
def load_asr_model(lang, ckpt_dir=""):
    """Load the ASR model used for WER evaluation.

    Args:
        lang: ``"zh"`` loads funasr ``paraformer-zh``; ``"en"`` loads faster-whisper.
        ckpt_dir: For ``"zh"``, directory joined with ``"paraformer-zh"`` to form the model
            argument. For ``"en"``, the model path; ``"large-v3"`` is used when empty.

    Returns:
        The loaded model object (funasr ``AutoModel`` or faster-whisper ``WhisperModel``).

    Raises:
        NotImplementedError: If ``lang`` is neither ``"zh"`` nor ``"en"``.
    """
    if lang == "zh":
        from funasr import AutoModel

        model = AutoModel(
            model=os.path.join(ckpt_dir, "paraformer-zh"),
            # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
            # punc_model = os.path.join(ckpt_dir, "ct-punc"),
            # spk_model = os.path.join(ckpt_dir, "cam++"),
            disable_update=True,
        )  # following seed-tts setting
    elif lang == "en":
        from faster_whisper import WhisperModel

        model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
        model = WhisperModel(model_size, device="cuda", compute_type="float16")
    else:
        # Without this branch an unsupported language fell through to `return model`
        # and failed with an unrelated UnboundLocalError.
        raise NotImplementedError(
            "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
        )
    return model
# WER Evaluation, the way Seed-TTS does
def run_asr_wer(args):
    """Worker that transcribes generated audio and computes a per-utterance WER.

    Args:
        args: Tuple ``(rank, lang, test_set, ckpt_dir)`` where ``rank`` is the GPU id for
            this worker, ``lang`` is ``"zh"`` or ``"en"``, ``test_set`` is an iterable of
            ``(gen_wav, prompt_wav, truth)`` triples and ``ckpt_dir`` is forwarded to
            ``load_asr_model``.

    Returns:
        List of dicts with keys ``"wav"`` (file stem), ``"truth"`` and ``"hypo"`` (text
        before normalization) and ``"wer"``.

    Raises:
        NotImplementedError: If ``lang`` is neither ``"zh"`` nor ``"en"``.
    """
    rank, lang, test_set, ckpt_dir = args

    # Bind this worker to its GPU: via torch for zh, via CUDA_VISIBLE_DEVICES for en
    if lang == "zh":
        import zhconv

        torch.cuda.set_device(rank)
    elif lang == "en":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    else:
        raise NotImplementedError(
            "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
        )

    asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)

    from zhon.hanzi import punctuation

    # Chinese punctuation plus ASCII punctuation, all removed before scoring
    punctuation_all = punctuation + string.punctuation
    wer_results = []

    from jiwer import process_words

    for gen_wav, prompt_wav, truth in tqdm(test_set):
        if lang == "zh":
            res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
            hypo = res[0]["text"]
            hypo = zhconv.convert(hypo, "zh-cn")
        elif lang == "en":
            segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
            hypo = ""
            for segment in segments:
                hypo = hypo + " " + segment.text

        # Keep the unnormalized texts for the result record
        raw_truth = truth
        raw_hypo = hypo

        for x in punctuation_all:
            truth = truth.replace(x, "")
            hypo = hypo.replace(x, "")

        # NOTE(review): as written, both arguments are a single space, which makes these two
        # calls no-ops; presumably meant to collapse double spaces — confirm against the file
        truth = truth.replace(" ", " ")
        hypo = hypo.replace(" ", " ")

        if lang == "zh":
            # Space-separate every character so that each character is scored as one word
            truth = " ".join([x for x in truth])
            hypo = " ".join([x for x in hypo])
        elif lang == "en":
            truth = truth.lower()
            hypo = hypo.lower()

        measures = process_words(truth, hypo)
        wer = measures.wer

        # ref_list = truth.split(" ")
        # subs = measures.substitutions / len(ref_list)
        # dele = measures.deletions / len(ref_list)
        # inse = measures.insertions / len(ref_list)

        wer_results.append(
            {
                "wav": Path(gen_wav).stem,
                "truth": raw_truth,
                "hypo": raw_hypo,
                "wer": wer,
            }
        )

    return wer_results
# SIM Evaluation
def run_sim(args):
    """Worker that scores speaker similarity between generated and prompt audio.

    Args:
        args: Tuple ``(rank, test_set, ckpt_dir)`` where ``rank`` selects ``cuda:<rank>``,
            ``test_set`` is an iterable of ``(gen_wav, prompt_wav, truth)`` triples and
            ``ckpt_dir`` is the path of the checkpoint file whose ``"model"`` entry is loaded.

    Returns:
        List of dicts with keys ``"wav"`` (stem of the generated file) and ``"sim"``
        (cosine similarity of the two embeddings).
    """
    rank, test_set, ckpt_dir = args
    device = f"cuda:{rank}"

    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict["model"], strict=False)

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model = model.cuda(device)
    model.eval()

    def load_at_16k(path):
        # Load a waveform, move it to the GPU when one is used, and resample to 16 kHz if needed
        wav, sr = torchaudio.load(path)
        if use_gpu:
            wav = wav.cuda(device)
        if sr == 16000:
            return wav
        resample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        if use_gpu:
            resample = resample.cuda(device)
        return resample(wav)

    sim_results = []
    for gen_wav, prompt_wav, _truth in tqdm(test_set):
        wav1 = load_at_16k(gen_wav)
        wav2 = load_at_16k(prompt_wav)

        with torch.no_grad():
            emb1 = model(wav1)
            emb2 = model(wav2)

        sim = F.cosine_similarity(emb1, emb2)[0].item()
        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
        sim_results.append({"wav": Path(gen_wav).stem, "sim": sim})

    return sim_results
================================================
FILE: src/f5_tts/infer/README.md
================================================
# Inference
The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be automatically downloaded when running inference scripts.
**More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.**
Currently supports up to **30s for a single** generation, which is the **total length** (same logic if `fix_duration` is used) including both prompt and output audio. However, `infer_cli` and `infer_gradio` will automatically do chunked generation for longer text. Long reference audio will be **clipped to ~12s**.
To avoid possible inference failures, make sure you have read through the following instructions.
- Use reference audio <12s and leave proper silence space (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- Uppercased letters (best with form like K.F.C.) will be uttered letter by letter, and lowercased letters used for common words.
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses.
- If an English punctuation mark ends a sentence, make sure there is a space " " after it. Otherwise it will not be regarded as a sentence boundary when chunking.
- Preprocess numbers to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), check for FFmpeg installation.
- Try turning off `use_ema` if using an early-stage finetuned checkpoint (one that has gone through only a few updates).
## Gradio App
Currently supported features:
- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct
- [Custom inference with more language support](SHARED.md)
The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.
The script will load model checkpoints from Hugging Face. You can also manually download the files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS models are loaded at startup; the ASR model is loaded for transcription if `ref_text` is not provided, and the LLM is loaded if Voice Chat is used.
More flags options:
```bash
# Automatically launch the interface in the default web browser
f5-tts_infer-gradio --inbrowser
# Set the root path of the application, if it's not served from the root ("/") of the domain
# For example, if the application is served at "https://example.com/myapp"
f5-tts_infer-gradio --root_path "/myapp"
```
Could also be used as a component for larger application:
```python
import gradio as gr
from f5_tts.infer.infer_gradio import app
with gr.Blocks() as main_app:
gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")
# ... other Gradio components
app.render()
main_app.launch()
```
## CLI Inference
The CLI command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, which is a command-line tool for inference.
The script will load model checkpoints from Huggingface. You can also manually download files and use `--ckpt_file` to specify the model you want to load, or directly update in `infer_cli.py`.
To use a different vocabulary, pass your `vocab.txt` file with `--vocab_file`.
Basic inference can be run with flags:
```bash
# Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
f5-tts_infer-cli \
--model F5TTS_v1_Base \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."
# Use BigVGAN as vocoder. Currently only support F5TTS_Base.
f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
# Use custom path checkpoint, e.g.
f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors
# More instructions
f5-tts_infer-cli --help
```
And a `.toml` file would help with more flexible usage.
```bash
f5-tts_infer-cli -c custom.toml
```
For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:
```toml
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
# File with text to generate. Ignores the text above.
gen_file = ""
remove_silence = false
output_dir = "tests"
```
You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.
```toml
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
gen_text = ""
# File with text to generate. Ignores the text above.
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"
[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""
[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""
```
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
## API Usage
```python
from importlib.resources import files
from f5_tts.api import F5TTS
f5tts = F5TTS()
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,
)
```
Check [api.py](../api.py) for more details.
## TensorRT-LLM Deployment
See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.
## Socket Real-time Service
Real-time voice output with chunk stream:
```bash
# Start socket server
python src/f5_tts/socket_server.py
# If PyAudio not installed
sudo apt-get install portaudio19-dev
pip install pyaudio
# Communicate with socket client
python src/f5_tts/socket_client.py
```
## Speech Editing
To test speech editing capabilities, use the following command:
```bash
python src/f5_tts/infer/speech_edit.py
```
================================================
FILE: src/f5_tts/infer/SHARED.md
================================================
# Shared Model Cards
### **Prerequisites of using**
- This document serves as a quick lookup table for community training/finetuning results, with support for various languages.
- The models in this repository are open source and are based on voluntary contributions from contributors.
- The use of models must be conditioned on respect for the respective creators. The convenience brought comes from their efforts.
### **Welcome to share here**
- Have a pretrained/finetuned result: a model checkpoint (preferably pruned to facilitate inference, i.e. keeping only `ema_model_state_dict`) and the corresponding vocab file (for tokenization).
- Host a public [huggingface model repository](https://huggingface.co/new) and upload the model related files.
- Make a pull request adding a model card to the current page, i.e. `src/f5_tts/infer/SHARED.md`.
### Supported Languages
- [Multilingual](#multilingual)
- [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
- [Arabic](#arabic)
- [F5-TTS Small @ ar & en @ SILMA AI](#f5-tts-small--ar--en--silma-ai)
- [English](#english)
- [Finnish](#finnish)
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
- [French](#french)
- [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
- [German](#german)
- [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
- [Hindi](#hindi)
- [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
- [Italian](#italian)
- [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79)
- [Japanese](#japanese)
- [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica)
- [Latvian](#latvian)
- [F5-TTS Base @ lv @ RaivisDejus](#f5-tts-base--lv--raivisdejus)
- [Mandarin](#mandarin)
- [Russian](#russian)
- [F5-TTS Base @ ru @ HotDro4illa](#f5-tts-base--ru--hotdro4illa)
- [Spanish](#spanish)
- [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar)
## Multilingual
#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
```bash
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
# A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
```
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
```bash
Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
*Other info, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
## Arabic
#### F5-TTS Small @ ar & en @ SILMA AI
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Small|[ckpt & vocab](https://huggingface.co/silma-ai/silma-tts)| Tens of thousands EN/AR |Apache-2.0|
- Pretrained by [SILMA.AI](https://silma.ai)
- [GitHub repo](https://github.com/SILMA-AI/silma-tts), Inference code
## English
## Finnish
#### F5-TTS Base @ fi @ AsmoKoskinen
|Model|🤗Hugging Face|Data|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|
```bash
Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
## French
#### F5-TTS Base @ fr @ RASPIAUDIO
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
```bash
Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
- [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
## German
#### F5-TTS Base @ de @ hvoss-techfak
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|
```bash
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)
## Hindi
#### F5-TTS Small @ hi @ SPRINGLab
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
```bash
Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Authors: SPRING Lab, Indian Institute of Technology, Madras
- Website: https://asr.iitm.ac.in/
## Italian
#### F5-TTS Base @ it @ alien79
|Model|🤗Hugging Face|Data|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
```bash
Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
Vocab: hf://alien79/F5-TTS-italian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Trained by [Mithril Man](https://github.com/MithrilMan)
- Model details on [hf project home](https://huggingface.co/alien79/F5-TTS-italian)
- Open to collaborations to further improve the model
## Japanese
#### F5-TTS Base @ ja @ Jmica
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_21999120)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
```bash
Model: hf://Jmica/F5TTS/JA_21999120/model_21999120.pt
Vocab: hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
## Latvian
#### F5-TTS Base @ lv @ RaivisDejus
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/RaivisDejus/F5-TTS-Latvian)|[Common voice](https://datacollective.mozillafoundation.org/datasets/cmj8u3pec00flnxxbntvfb4as)|cc0-1.0|
```bash
Model: hf://RaivisDejus/F5-TTS-Latvian/model.safetensors
Vocab: hf://RaivisDejus/F5-TTS-Latvian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
## Mandarin
## Russian
#### F5-TTS Base @ ru @ HotDro4illa
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hotstone228/F5-TTS-Russian)|[Common voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0)|cc-by-nc-4.0|
```bash
Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
- Any improvements are welcome
## Spanish
#### F5-TTS Base @ es @ jpgallegoar
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
- @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model.
================================================
FILE: src/f5_tts/infer/examples/basic/basic.toml
================================================
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
# File with text to generate. Ignores the text above.
gen_file = ""
remove_silence = false
output_dir = "tests"
output_file = "infer_cli_basic.wav"
================================================
FILE: src/f5_tts/infer/examples/multi/story.toml
================================================
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
gen_text = ""
# File with text to generate. Ignores the text above.
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"
output_file = "infer_cli_story.wav"
[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""
speed = 0.8 # will ignore global speed
[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""
================================================
FILE: src/f5_tts/infer/examples/multi/story.txt
================================================
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] "My poor dear friend, you live here no better than the ants! Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land." [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] "Goodbye," [main] said he, [country] "I'm off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace."
================================================
FILE: src/f5_tts/infer/examples/vocab.txt
================================================
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
_
a
a1
ai1
ai2
ai3
ai4
an1
an3
an4
ang1
ang2
ang4
ao1
ao2
ao3
ao4
b
ba
ba1
ba2
ba3
ba4
bai1
bai2
bai3
bai4
ban1
ban2
ban3
ban4
bang1
bang2
bang3
bang4
bao1
bao2
bao3
bao4
bei
bei1
bei2
bei3
bei4
ben1
ben2
ben3
ben4
beng
beng1
beng2
beng3
beng4
bi1
bi2
bi3
bi4
bian1
bian2
bian3
bian4
biao1
biao2
biao3
bie1
bie2
bie3
bie4
bin1
bin4
bing1
bing2
bing3
bing4
bo
bo1
bo2
bo3
bo4
bu2
bu3
bu4
c
ca1
cai1
cai2
cai3
cai4
can1
can2
can3
can4
cang1
cang2
cao1
cao2
cao3
ce4
cen1
cen2
ceng1
ceng2
ceng4
cha1
cha2
cha3
cha4
chai1
chai2
chan1
chan2
chan3
chan4
chang1
chang2
chang3
chang4
chao1
chao2
chao3
che1
che2
che3
che4
chen1
chen2
chen3
chen4
cheng1
cheng2
cheng3
cheng4
chi1
chi2
chi3
chi4
chong1
chong2
chong3
chong4
chou1
chou2
chou3
chou4
chu1
chu2
chu3
chu4
chua1
chuai1
chuai2
chuai3
chuai4
chuan1
chuan2
chuan3
chuan4
chuang1
chuang2
chuang3
chuang4
chui1
chui2
chun1
chun2
chun3
chuo1
chuo4
ci1
ci2
ci3
ci4
cong1
cong2
cou4
cu1
cu4
cuan1
cuan2
cuan4
cui1
cui3
cui4
cun1
cun2
cun4
cuo1
cuo2
cuo4
d
da
da1
da2
da3
da4
dai1
dai2
dai3
dai4
dan1
dan2
dan3
dan4
dang1
dang2
dang3
dang4
dao1
dao2
dao3
dao4
de
de1
de2
dei3
den4
deng1
deng2
deng3
deng4
di1
di2
di3
di4
dia3
dian1
dian2
dian3
dian4
diao1
diao3
diao4
die1
die2
die4
ding1
ding2
ding3
ding4
diu1
dong1
dong3
dong4
dou1
dou2
dou3
dou4
du1
du2
du3
du4
duan1
duan2
duan3
duan4
dui1
dui4
dun1
dun3
dun4
duo1
duo2
duo3
duo4
e
e1
e2
e3
e4
ei2
en1
en4
er
er2
er3
er4
f
fa1
fa2
fa3
fa4
fan1
fan2
fan3
fan4
fang1
fang2
fang3
fang4
fei1
fei2
fei3
fei4
fen1
fen2
fen3
fen4
feng1
feng2
feng3
feng4
fo2
fou2
fou3
fu1
fu2
fu3
fu4
g
ga1
ga2
ga3
ga4
gai1
gai2
gai3
gai4
gan1
gan2
gan3
gan4
gang1
gang2
gang3
gang4
gao1
gao2
gao3
gao4
ge1
ge2
ge3
ge4
gei2
gei3
gen1
gen2
gen3
gen4
geng1
geng3
geng4
gong1
gong3
gong4
gou1
gou2
gou3
gou4
gu
gu1
gu2
gu3
gu4
gua1
gua2
gua3
gua4
guai1
guai2
guai3
guai4
guan1
guan2
guan3
guan4
guang1
guang2
guang3
guang4
gui1
gui2
gui3
gui4
gun3
gun4
guo1
guo2
guo3
guo4
h
ha1
ha2
ha3
hai1
hai2
hai3
hai4
han1
han2
han3
han4
hang1
hang2
hang4
hao1
hao2
hao3
hao4
he1
he2
he4
hei1
hen2
hen3
hen4
heng1
heng2
heng4
hong1
hong2
hong3
hong4
hou1
hou2
hou3
hou4
hu1
hu2
hu3
hu4
hua1
hua2
hua4
huai2
huai4
huan1
huan2
huan3
huan4
huang1
huang2
huang3
huang4
hui1
hui2
hui3
hui4
hun1
hun2
hun4
huo
huo1
huo2
huo3
huo4
i
j
ji1
ji2
ji3
ji4
jia
jia1
jia2
jia3
jia4
jian1
jian2
jian3
jian4
jiang1
jiang2
jiang3
jiang4
jiao1
jiao2
jiao3
jiao4
jie1
jie2
jie3
jie4
jin1
jin2
jin3
jin4
jing1
jing2
jing3
jing4
jiong3
jiu1
jiu2
jiu3
jiu4
ju1
ju2
ju3
ju4
juan1
juan2
juan3
juan4
jue1
jue2
jue4
jun1
jun4
k
ka1
ka2
ka3
kai1
kai2
kai3
kai4
kan1
kan2
kan3
kan4
kang1
kang2
kang4
kao1
kao2
kao3
kao4
ke1
ke2
ke3
ke4
ken3
keng1
kong1
kong3
kong4
kou1
kou2
kou3
kou4
ku1
ku2
ku3
ku4
kua1
kua3
kua4
kuai3
kuai4
kuan1
kuan2
kuan3
kuang1
kuang2
kuang4
kui1
kui2
kui3
kui4
kun1
kun3
kun4
kuo4
l
la
la1
la2
la3
la4
lai2
lai4
lan2
lan3
lan4
lang1
lang2
lang3
lang4
lao1
lao2
lao3
lao4
le
le1
le4
lei
lei1
lei2
lei3
lei4
leng1
leng2
leng3
leng4
li
li1
li2
li3
li4
lia3
lian2
lian3
lian4
liang2
liang3
liang4
liao1
liao2
liao3
liao4
lie1
lie2
lie3
lie4
lin1
lin2
lin3
lin4
ling2
ling3
ling4
liu1
liu2
liu3
liu4
long1
long2
long3
long4
lou1
lou2
lou3
lou4
lu1
lu2
lu3
lu4
luan2
luan3
luan4
lun1
lun2
lun4
luo1
luo2
luo3
luo4
lv2
lv3
lv4
lve3
lve4
m
ma
ma1
ma2
ma3
ma4
mai2
mai3
mai4
man1
man2
man3
man4
mang2
mang3
mao1
mao2
mao3
mao4
me
mei2
mei3
mei4
men
men1
men2
men4
meng
meng1
meng2
meng3
meng4
mi1
mi2
mi3
mi4
mian2
mian3
mian4
miao1
miao2
miao3
miao4
mie1
mie4
min2
min3
ming2
ming3
ming4
miu4
mo1
mo2
mo3
mo4
mou1
mou2
mou3
mu2
mu3
mu4
n
n2
na1
na2
na3
na4
nai2
nai3
nai4
nan1
nan2
nan3
nan4
nang1
nang2
nang3
nao1
nao2
nao3
nao4
ne
ne2
ne4
nei3
nei4
nen4
neng2
ni1
ni2
ni3
ni4
nian1
nian2
nian3
nian4
niang2
niang4
niao2
niao3
niao4
nie1
nie4
nin2
ning2
ning3
ning4
niu1
niu2
niu3
niu4
nong2
nong4
nou4
nu2
nu3
nu4
nuan3
nuo2
nuo4
nv2
nv3
nve4
o
o1
o2
ou1
ou2
ou3
ou4
p
pa1
pa2
pa4
pai1
pai2
pai3
pai4
pan1
pan2
pan4
pang1
pang2
pang4
pao1
pao2
pao3
pao4
pei1
pei2
pei4
pen1
pen2
pen4
peng1
peng2
peng3
peng4
pi1
pi2
pi3
pi4
pian1
pian2
pian4
piao1
piao2
piao3
piao4
pie1
pie2
pie3
pin1
pin2
pin3
pin4
ping1
ping2
po1
po2
po3
po4
pou1
pu1
pu2
pu3
pu4
q
qi1
qi2
qi3
qi4
qia1
qia3
qia4
qian1
qian2
qian3
qian4
qiang1
qiang2
qiang3
qiang4
qiao1
qiao2
qiao3
qiao4
qie1
qie2
qie3
qie4
qin1
qin2
qin3
qin4
qing1
qing2
qing3
qing4
qiong1
qiong2
qiu1
qiu2
qiu3
qu1
qu2
qu3
qu4
quan1
quan2
quan3
quan4
que1
que2
que4
qun2
r
ran2
ran3
rang1
rang2
rang3
rang4
rao2
rao3
rao4
re2
re3
re4
ren2
ren3
ren4
reng1
reng2
ri4
rong1
rong2
rong3
rou2
rou4
ru2
ru3
ru4
ruan2
ruan3
rui3
rui4
run4
ruo4
s
sa1
sa2
sa3
sa4
sai1
sai4
san1
san2
san3
san4
sang1
sang3
sang4
sao1
sao2
sao3
sao4
se4
sen1
seng1
sha1
sha2
sha3
sha4
shai1
shai2
shai3
shai4
shan1
shan3
shan4
shang
shang1
shang3
shang4
shao1
shao2
shao3
shao4
she1
she2
she3
she4
shei2
shen1
shen2
shen3
shen4
sheng1
sheng2
sheng3
sheng4
shi
shi1
shi2
shi3
shi4
shou1
shou2
shou3
shou4
shu1
shu2
shu3
shu4
shua1
shua2
shua3
shua4
shuai1
shuai3
shuai4
shuan1
shuan4
shuang1
shuang3
shui2
shui3
shui4
shun3
shun4
shuo1
shuo4
si1
si2
si3
si4
song1
song3
song4
sou1
sou3
sou4
su1
su2
su4
suan1
suan4
sui1
sui2
sui3
sui4
sun1
sun3
suo
suo1
suo2
suo3
t
ta1
ta2
ta3
ta4
tai1
tai2
tai4
tan1
tan2
tan3
tan4
tang1
tang2
tang3
tang4
tao1
tao2
tao3
tao4
te4
teng2
ti1
ti2
ti3
ti4
tian1
tian2
tian3
tiao1
tiao2
tiao3
tiao4
tie1
tie2
tie3
tie4
ting1
ting2
ting3
tong1
tong2
tong3
tong4
tou
tou1
tou2
tou4
tu1
tu2
tu3
tu4
tuan1
tuan2
tui1
tui2
tui3
tui4
tun1
tun2
tun4
tuo1
tuo2
tuo3
tuo4
u
v
w
wa
wa1
wa2
wa3
wa4
wai1
wai3
wai4
wan1
wan2
wan3
wan4
wang1
wang2
wang3
wang4
wei1
wei2
wei3
wei4
wen1
wen2
wen3
wen4
weng1
weng4
wo1
wo2
wo3
wo4
wu1
wu2
wu3
wu4
x
xi1
xi2
xi3
xi4
xia1
xia2
xia4
xian1
xian2
xian3
xian4
xiang1
xiang2
xiang3
xiang4
xiao1
xiao2
xiao3
xiao4
xie1
xie2
xie3
xie4
xin1
xin2
xin4
xing1
xing2
xing3
xing4
xiong1
xiong2
xiu1
xiu3
xiu4
xu
xu1
xu2
xu3
xu4
xuan1
xuan2
xuan3
xuan4
xue1
xue2
xue3
xue4
xun1
xun2
xun4
y
ya
ya1
ya2
ya3
ya4
yan1
yan2
yan3
yan4
yang1
yang2
yang3
yang4
yao1
yao2
yao3
yao4
ye1
ye2
ye3
ye4
yi
yi1
yi2
yi3
yi4
yin1
yin2
yin3
yin4
ying1
ying2
ying3
ying4
yo1
yong1
yong2
yong3
yong4
you1
you2
you3
you4
yu1
yu2
yu3
yu4
yuan1
yuan2
yuan3
yuan4
yue1
yue4
yun1
yun2
yun3
yun4
z
za1
za2
za3
zai1
zai3
zai4
zan1
zan2
zan3
zan4
zang1
zang4
zao1
zao2
zao3
zao4
ze2
ze4
zei2
zen3
zeng1
zeng4
zha1
zha2
zha3
zha4
zhai1
zhai2
zhai3
zhai4
zhan1
zhan2
zhan3
zhan4
zhang1
zhang2
zhang3
zhang4
zhao1
zhao2
zhao3
zhao4
zhe
zhe1
zhe2
zhe3
zhe4
zhen1
zhen2
zhen3
zhen4
zheng1
zheng2
zheng3
zheng4
zhi1
zhi2
zhi3
zhi4
zhong1
zhong2
zhong3
zhong4
zhou1
zhou2
zhou3
zhou4
zhu1
zhu2
zhu3
zhu4
zhua1
zhua2
zhua3
zhuai1
zhuai3
zhuai4
zhuan1
zhuan2
zhuan3
zhuan4
zhuang1
zhuang4
zhui1
zhui4
zhun1
zhun2
zhun3
zhuo1
zhuo2
zi
zi1
zi2
zi3
zi4
zong1
zong2
zong3
zong4
zou1
zou2
zou3
zou4
zu1
zu2
zu3
zuan1
zuan3
zuan4
zui2
zui3
zui4
zun1
zuo
zuo1
zuo2
zuo3
zuo4
{
~
¡
¢
£
¥
§
¨
©
«
®
¯
°
±
²
³
´
µ
·
¹
º
»
¼
½
¾
¿
À
Á
Â
Ã
Ä
Å
Æ
Ç
È
É
Ê
Í
Î
Ñ
Ó
Ö
×
Ø
Ú
Ü
Ý
Þ
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ð
ñ
ò
ó
ô
õ
ö
ø
ù
ú
û
ü
ý
Ā
ā
ă
ą
ć
Č
č
Đ
đ
ē
ė
ę
ě
ĝ
ğ
ħ
ī
į
İ
ı
Ł
ł
ń
ņ
ň
ŋ
Ō
ō
ő
œ
ř
Ś
ś
Ş
ş
Š
š
Ť
ť
ũ
ū
ź
Ż
ż
Ž
ž
ơ
ư
ǎ
ǐ
ǒ
ǔ
ǚ
ș
ț
ɑ
ɔ
ɕ
ə
ɛ
ɜ
ɡ
ɣ
ɪ
ɫ
ɴ
ɹ
ɾ
ʃ
ʊ
ʌ
ʒ
ʔ
ʰ
ʷ
ʻ
ʾ
ʿ
ˈ
ː
˙
˜
ˢ
́
̅
Α
Β
Δ
Ε
Θ
Κ
Λ
Μ
Ξ
Π
Σ
Τ
Φ
Χ
Ψ
Ω
ά
έ
ή
ί
α
β
γ
δ
ε
ζ
η
θ
ι
κ
λ
μ
ν
ξ
ο
π
ρ
ς
σ
τ
υ
φ
χ
ψ
ω
ϊ
ό
ύ
ώ
ϕ
ϵ
Ё
А
Б
В
Г
Д
Е
Ж
З
И
Й
К
Л
М
Н
О
П
Р
С
Т
У
Ф
Х
Ц
Ч
Ш
Щ
Ы
Ь
Э
Ю
Я
а
б
в
г
д
е
ж
з
и
й
к
л
м
н
о
п
р
с
т
у
ф
х
ц
ч
ш
щ
ъ
ы
ь
э
ю
я
ё
і
ְ
ִ
ֵ
ֶ
ַ
ָ
ֹ
ּ
־
ׁ
א
ב
ג
ד
ה
ו
ז
ח
ט
י
כ
ל
ם
מ
ן
נ
ס
ע
פ
ק
ר
ש
ת
أ
ب
ة
ت
ج
ح
د
ر
ز
س
ص
ط
ع
ق
ك
ل
م
ن
ه
و
ي
َ
ُ
ِ
ْ
ก
ข
ง
จ
ต
ท
น
ป
ย
ร
ว
ส
ห
อ
ฮ
ั
า
ี
ึ
โ
ใ
ไ
่
้
์
ḍ
Ḥ
ḥ
ṁ
ṃ
ṅ
ṇ
Ṛ
ṛ
Ṣ
ṣ
Ṭ
ṭ
ạ
ả
Ấ
ấ
ầ
ậ
ắ
ằ
ẻ
ẽ
ế
ề
ể
ễ
ệ
ị
ọ
ỏ
ố
ồ
ộ
ớ
ờ
ở
ụ
ủ
ứ
ữ
ἀ
ἁ
Ἀ
ἐ
ἔ
ἰ
ἱ
ὀ
ὁ
ὐ
ὲ
ὸ
ᾶ
᾽
ῆ
ῇ
ῶ
‑
‒
–
—
―
‖
†
‡
•
…
‧
′
″
⁄
⁰
⁴
⁵
⁶
⁷
⁸
⁹
₁
₂
₃
€
₱
₹
₽
℃
ℏ
ℓ
№
ℝ
™
⅓
⅔
⅛
→
∂
∈
∑
−
∗
√
∞
∫
≈
≠
≡
≤
≥
⋅
⋯
█
♪
⟨
⟩
、
。
《
》
「
」
【
】
あ
う
え
お
か
が
き
ぎ
く
ぐ
け
げ
こ
ご
さ
し
じ
す
ず
せ
ぜ
そ
ぞ
た
だ
ち
っ
つ
で
と
ど
な
に
ね
の
は
ば
ひ
ぶ
へ
べ
ま
み
む
め
も
ゃ
や
ゆ
ょ
よ
ら
り
る
れ
ろ
わ
を
ん
ァ
ア
ィ
イ
ウ
ェ
エ
オ
カ
ガ
キ
ク
ケ
ゲ
コ
ゴ
サ
ザ
シ
ジ
ス
ズ
セ
ゾ
タ
ダ
チ
ッ
ツ
テ
デ
ト
ド
ナ
ニ
ネ
ノ
バ
パ
ビ
ピ
フ
プ
ヘ
ベ
ペ
ホ
ボ
ポ
マ
ミ
ム
メ
モ
ャ
ヤ
ュ
ユ
ョ
ヨ
ラ
リ
ル
レ
ロ
ワ
ン
・
ー
ㄋ
ㄍ
ㄎ
ㄏ
ㄓ
ㄕ
ㄚ
ㄜ
ㄟ
ㄤ
ㄥ
ㄧ
ㄱ
ㄴ
ㄷ
ㄹ
ㅁ
ㅂ
ㅅ
ㅈ
ㅍ
ㅎ
ㅏ
ㅓ
ㅗ
ㅜ
ㅡ
ㅣ
㗎
가
각
간
갈
감
갑
갓
갔
강
같
개
거
건
걸
겁
것
겉
게
겠
겨
결
겼
경
계
고
곤
골
곱
공
과
관
광
교
구
국
굴
귀
귄
그
근
글
금
기
긴
길
까
깍
깔
깜
깨
께
꼬
꼭
꽃
꾸
꿔
끔
끗
끝
끼
나
난
날
남
납
내
냐
냥
너
넘
넣
네
녁
년
녕
노
녹
놀
누
눈
느
는
늘
니
님
닙
다
닥
단
달
닭
당
대
더
덕
던
덥
데
도
독
동
돼
됐
되
된
될
두
둑
둥
드
들
등
디
따
딱
딸
땅
때
떤
떨
떻
또
똑
뚱
뛰
뜻
띠
라
락
란
람
랍
랑
래
랜
러
런
럼
렇
레
려
력
렵
렸
로
록
롬
루
르
른
를
름
릉
리
릴
림
마
막
만
많
말
맑
맙
맛
매
머
먹
멍
메
면
명
몇
모
목
몸
못
무
문
물
뭐
뭘
미
민
밌
밑
바
박
밖
반
받
발
밤
밥
방
배
백
밸
뱀
버
번
벌
벚
베
벼
벽
별
병
보
복
본
볼
봐
봤
부
분
불
비
빔
빛
빠
빨
뼈
뽀
뿅
쁘
사
산
살
삼
샀
상
새
색
생
서
선
설
섭
섰
성
세
셔
션
셨
소
속
손
송
수
숙
순
술
숫
숭
숲
쉬
쉽
스
슨
습
슷
시
식
신
실
싫
심
십
싶
싸
써
쓰
쓴
씌
씨
씩
씬
아
악
안
않
알
야
약
얀
양
얘
어
언
얼
엄
업
없
었
엉
에
여
역
연
염
엽
영
옆
예
옛
오
온
올
옷
옹
와
왔
왜
요
욕
용
우
운
울
웃
워
원
월
웠
위
윙
유
육
윤
으
은
을
음
응
의
이
익
인
일
읽
임
입
있
자
작
잔
잖
잘
잡
잤
장
재
저
전
점
정
제
져
졌
조
족
좀
종
좋
죠
주
준
줄
중
줘
즈
즐
즘
지
진
집
짜
짝
쩌
쪼
쪽
쫌
쭈
쯔
찌
찍
차
착
찾
책
처
천
철
체
쳐
쳤
초
촌
추
출
춤
춥
춰
치
친
칠
침
칩
칼
커
켓
코
콩
쿠
퀴
크
큰
큽
키
킨
타
태
터
턴
털
테
토
통
투
트
특
튼
틀
티
팀
파
팔
패
페
펜
펭
평
포
폭
표
품
풍
프
플
피
필
하
학
한
할
함
합
항
해
햇
했
행
허
험
형
혜
호
혼
홀
화
회
획
후
휴
흐
흔
희
히
힘
ﷺ
ﷻ
!
,
?
�
𠮶
================================================
FILE: src/f5_tts/infer/infer_cli.py
================================================
import argparse
import codecs
import os
import re
from datetime import datetime
from importlib.resources import files
from pathlib import Path
import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from unidecode import unidecode
from f5_tts.infer.utils_infer import (
cfg_strength,
cross_fade_duration,
device,
fix_duration,
infer_process,
load_model,
load_vocoder,
mel_spec_type,
nfe_step,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
speed,
sway_sampling_coef,
target_rms,
)
def _build_parser():
    """Assemble the argparse parser for the E2/F5 TTS inference CLI.

    Apart from --config, no option declares a default: an omitted option keeps
    argparse's resting value (None, or False/True for the store_true/store_false
    switches) so that the value from the TOML config file can be used instead.
    Options are registered in the same order as they appear in --help.
    """
    cli = argparse.ArgumentParser(
        prog="python3 infer-cli.py",
        description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
        epilog="Specify options above to override one or more settings from config.",
    )

    # The only option with a default: the example config shipped inside the package.
    default_config = os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml")
    cli.add_argument(
        "-c",
        "--config",
        default=default_config,
        type=str,
        help="The configuration file, default see infer/examples/basic/basic.toml",
    )

    # Note. Not to provide default value here in order to read default from config file
    # String options that have both a short and a long flag.
    string_options = (
        ("-m", "--model", "The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc."),
        ("-mc", "--model_cfg", "The path to F5-TTS model config file .yaml"),
        ("-p", "--ckpt_file", "The path to model checkpoint .pt, leave blank to use default"),
        ("-v", "--vocab_file", "The path to vocab file .txt, leave blank to use default"),
        ("-r", "--ref_audio", "The reference audio file."),
        ("-s", "--ref_text", "The transcript/subtitle for the reference audio"),
        ("-t", "--gen_text", "The text to make model synthesize a speech"),
        ("-f", "--gen_file", "The file with text to generate, will ignore --gen_text"),
        ("-o", "--output_dir", "The path to output folder"),
        ("-w", "--output_file", "The name of output file"),
    )
    for short_flag, long_flag, help_text in string_options:
        cli.add_argument(short_flag, long_flag, type=str, help=help_text)

    # Boolean switches; --no_legacy_text is store_false, so its attribute is True when absent.
    switches = (
        ("--save_chunk", "store_true", "To save each audio chunks during inference"),
        (
            "--no_legacy_text",
            "store_false",
            "Not to use lossy ASCII transliterations of unicode text in saved file names.",
        ),
        ("--remove_silence", "store_true", "To remove long silence found in ouput"),
        (
            "--load_vocoder_from_local",
            "store_true",
            "To load vocoder from local dir, default to ../checkpoints/vocos-mel-24khz",
        ),
    )
    for long_flag, action, help_text in switches:
        cli.add_argument(long_flag, action=action, help=help_text)

    cli.add_argument(
        "--vocoder_name",
        choices=["vocos", "bigvgan"],
        type=str,
        help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
    )

    # Numeric tuning knobs; the help text echoes the defaults imported from utils_infer.
    numeric_options = (
        ("--target_rms", float, f"Target output speech loudness normalization value, default {target_rms}"),
        (
            "--cross_fade_duration",
            float,
            f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
        ),
        ("--nfe_step", int, f"The number of function evaluation (denoising steps), default {nfe_step}"),
        ("--cfg_strength", float, f"Classifier-free guidance strength, default {cfg_strength}"),
        ("--sway_sampling_coef", float, f"Sway Sampling coefficient, default {sway_sampling_coef}"),
        ("--speed", float, f"The speed of the generated audio, default {speed}"),
        (
            "--fix_duration",
            float,
            f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
        ),
    )
    for long_flag, value_type, help_text in numeric_options:
        cli.add_argument(long_flag, type=value_type, help=help_text)

    cli.add_argument("--device", type=str, help="Specify the device to run on")
    return cli


parser = _build_parser()
args = parser.parse_args()
# config file
# Read inside a context manager so the TOML file handle is closed right after parsing.
with open(args.config, "rb") as config_fp:
    config = tomli.load(config_fp)


def _cli_or_config(cli_value, key, default):
    """Pick a setting with the precedence: CLI flag > config file entry > default.

    Args:
        cli_value: Value parsed by argparse; None when the flag was not given.
        key: Name of the entry to look up in the loaded TOML config.
        default: Fallback used when neither the CLI nor the config provides a value.

    Returns:
        `cli_value` when it is not None, otherwise `config.get(key, default)`.

    The test is `is not None` rather than truthiness so that an explicit falsy
    CLI value (e.g. `--cfg_strength 0` or `--cross_fade_duration 0`) is honored
    instead of being silently replaced by the config/default value.
    """
    return cli_value if cli_value is not None else config.get(key, default)


# command-line interface parameters

# String settings: an empty CLI string is treated like "not given".
model = args.model or config.get("model", "F5TTS_v1_Base")
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
vocab_file = args.vocab_file or config.get("vocab_file", "")

ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
# An explicit empty --ref_text "" must survive: it requests automatic transcription.
ref_text = _cli_or_config(args.ref_text, "ref_text", "Some call me nature, others call me mother nature.")
gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
gen_file = args.gen_file or config.get("gen_file", "")

output_dir = args.output_dir or config.get("output_dir", "tests")
output_file = args.output_file or config.get(
    "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
)

save_chunk = args.save_chunk or config.get("save_chunk", False)
# --no_legacy_text is a store_false arg: args.no_legacy_text is True unless the flag is passed.
# Legacy (transliterated) names stay on only if neither the flag nor the config asks to disable them.
use_legacy_text = args.no_legacy_text and not config.get("no_legacy_text", False)
if save_chunk and use_legacy_text:
    print(
        "\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
    )

remove_silence = args.remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)

vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
# Numeric settings: 0 / 0.0 is a legitimate explicit value, so do not use `or` here.
target_rms = _cli_or_config(args.target_rms, "target_rms", target_rms)
cross_fade_duration = _cli_or_config(args.cross_fade_duration, "cross_fade_duration", cross_fade_duration)
nfe_step = _cli_or_config(args.nfe_step, "nfe_step", nfe_step)
cfg_strength = _cli_or_config(args.cfg_strength, "cfg_strength", cfg_strength)
sway_sampling_coef = _cli_or_config(args.sway_sampling_coef, "sway_sampling_coef", sway_sampling_coef)
speed = _cli_or_config(args.speed, "speed", speed)
fix_duration = _cli_or_config(args.fix_duration, "fix_duration", fix_duration)
device = args.device or config.get("device", device)
# patches for pip pkg user
def _resolve_packaged_example(path):
    """Re-root a path under the installed f5_tts package when it points at the bundled examples.

    Paths that do not contain "infer/examples/" are returned unchanged.
    """
    if "infer/examples/" not in path:
        return path
    return str(files("f5_tts").joinpath(path))


ref_audio = _resolve_packaged_example(ref_audio)
gen_file = _resolve_packaged_example(gen_file)
# Each extra voice declared under [voices.<name>] carries its own reference audio path.
for voice_settings in config.get("voices", {}).values():
    voice_settings["ref_audio"] = _resolve_packaged_example(voice_settings["ref_audio"])
# ignore gen_text if gen_file provided
if gen_file:
    # Use a context manager so the file handle is closed as soon as the text is read,
    # instead of being left open until garbage collection.
    with codecs.open(gen_file, "r", "utf-8") as gen_fp:
        gen_text = gen_fp.read()
# output path
# Final concatenated waveform is written to <output_dir>/<output_file>.
wave_path = Path(output_dir) / output_file
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
if save_chunk:
    # Per-chunk audio goes to a sibling folder named after the output file stem.
    output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
    # exist_ok avoids the check-then-create race and fails fast if the path is a regular file.
    os.makedirs(output_chunk_dir, exist_ok=True)
# load vocoder
# Local checkpoint directory that is passed to load_vocoder as local_path.
if vocoder_name == "vocos":
    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
else:
    # argparse `choices` only guards the CLI flag; a value coming from the config file
    # would otherwise leave vocoder_local_path unbound and crash with a NameError below.
    raise ValueError(f"Unsupported vocoder_name {vocoder_name!r}, expected 'vocos' or 'bigvgan'.")

vocoder = load_vocoder(
    vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
)
# load TTS model
# Architecture config: CLI flag first, then the config file entry, then the packaged configs/<model>.yaml.
model_cfg_path = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
model_cfg = OmegaConf.load(model_cfg_path)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

# Coordinates of the default pretrained checkpoint on the Hugging Face hub.
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

if model == "F5TTS_Base":
    # override for previous models
    if vocoder_name == "vocos":
        ckpt_step = 1200000
    elif vocoder_name == "bigvgan":
        model = "F5TTS_Base_bigvgan"
        ckpt_type = "pt"
else:
    # Every model other than F5TTS_Base must be paired with the vocoder named in its config.
    assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
    if model == "E2TTS_Base":
        repo_name = "E2-TTS"
        ckpt_step = 1200000

# No checkpoint given: fetch the default one; hf:// URLs are resolved through cached_path.
if not ckpt_file:
    default_ckpt_url = f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"
    ckpt_file = str(cached_path(default_ckpt_url))
elif ckpt_file.startswith("hf://"):
    ckpt_file = str(cached_path(ckpt_file))

if vocab_file.startswith("hf://"):
    vocab_file = str(cached_path(vocab_file))

print(f"Using {model}...")
ema_model = load_model(
    model_cls,
    model_arc,
    ckpt_file,
    mel_spec_type=vocoder_name,
    vocab_file=vocab_file,
    device=device,
)
# inference process
def main():
    """Run multi-voice CLI inference and write the resulting audio to disk.

    Uses the module-level settings prepared above (``config``, ``gen_text``,
    ``ema_model``, ``vocoder``, output paths, sampling options).

    The text to generate is split at ``[voice]`` tags; each chunk is
    synthesized with the reference audio/text of the tagged voice (falling
    back to ``main`` when the tag is missing or unknown).  Chunks are
    optionally saved individually, then concatenated and written to
    ``wave_path``; silence removal is applied to that file when requested.
    """
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice

    # Preprocess every reference clip once, storing the results in place.
    for voice in voices:
        print("Voice:", voice)
        print("ref_audio ", voices[voice]["ref_audio"])
        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
            voices[voice]["ref_audio"], voices[voice]["ref_text"]
        )
        print("ref_audio_", voices[voice]["ref_audio"], "\n\n")

    generated_audio_segments = []
    # Zero-width split: every "[voice]" tag starts a new chunk and stays in it.
    reg1 = r"(?=\[\w+\])"
    chunks = re.split(reg1, gen_text)
    reg2 = r"\[(\w+)\]"
    for text in chunks:
        if not text.strip():
            continue
        match = re.match(reg2, text)
        if match:
            voice = match[1]
        else:
            print("No voice tag found, using main.")
            voice = "main"
        if voice not in voices:
            print(f"Voice {voice} not found, using main.")
            voice = "main"
        text = re.sub(reg2, "", text)
        ref_audio_ = voices[voice]["ref_audio"]
        ref_text_ = voices[voice]["ref_text"]
        # A per-voice "speed" entry overrides the global speed.
        local_speed = voices[voice].get("speed", speed)
        gen_text_ = text.strip()
        print(f"Voice: {voice}")
        audio_segment, final_sample_rate, spectrogram = infer_process(
            ref_audio_,
            ref_text_,
            gen_text_,
            ema_model,
            vocoder,
            mel_spec_type=vocoder_name,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=local_speed,
            fix_duration=fix_duration,
            device=device,
        )
        generated_audio_segments.append(audio_segment)

        if save_chunk:
            if len(gen_text_) > 200:
                gen_text_ = gen_text_[:200] + " ... "
            if use_legacy_text:
                gen_text_ = unidecode(gen_text_)
            # The chunk text becomes part of a file name: replace path
            # separators, characters that are illegal on common filesystems
            # and control characters (e.g. newlines) so the write cannot fail
            # or land in an unintended directory.
            chunk_label = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", gen_text_)
            sf.write(
                os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{chunk_label}.wav"),
                audio_segment,
                final_sample_rate,
            )

    if generated_audio_segments:
        final_wave = np.concatenate(generated_audio_segments)

        os.makedirs(output_dir, exist_ok=True)

        # Write by path instead of holding our own open handle on the file
        # while soundfile and the silence remover (re)write the same path.
        out_path = str(wave_path)
        sf.write(out_path, final_wave, final_sample_rate)
        # Remove silence
        if remove_silence:
            remove_silence_for_generated_wav(out_path)
        print(out_path)
if __name__ == "__main__":
main()
================================================
FILE: src/f5_tts/infer/infer_gradio.py
================================================
# ruff: noqa: E402
# Above allows ruff to ignore E402: module level import not at top of file
import gc
import json
import os
import re
import tempfile
from collections import OrderedDict
from functools import lru_cache
from importlib.resources import files
import click
import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
import spaces
USING_SPACES = True
except ImportError:
USING_SPACES = False
def gpu_decorator(func):
    """Wrap ``func`` with ``spaces.GPU`` when the ``spaces`` package was importable.

    When ``USING_SPACES`` is false the function is returned unchanged.
    """
    return spaces.GPU(func) if USING_SPACES else func
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
tempfile_kwargs,
)
from f5_tts.model import DiT, UNetT
# Name of the model selected by default in the UI.
DEFAULT_TTS_MODEL = "F5-TTS_v1"
# Currently selected model; reassigned by switch_tts_model / set_custom_model.
# Either a model name string or a ("Custom", ckpt, vocab, cfg) tuple.
tts_model_choice = DEFAULT_TTS_MODEL
# Default model description as three strings:
#   [0] checkpoint location, [1] vocab location, [2] JSON-encoded model kwargs.
# Also serves as the fallback for the "Custom" model inputs.
DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]
# load models
# The vocoder is created once at import time and shared by every infer() call.
vocoder = load_vocoder()
def load_f5tts():
    """Fetch the default F5-TTS checkpoint and load it as a ``DiT`` model.

    Checkpoint location and architecture both come from
    ``DEFAULT_TTS_MODEL_CFG`` (entries 0 and 2).
    """
    checkpoint = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
    arch_kwargs = json.loads(DEFAULT_TTS_MODEL_CFG[2])
    return load_model(DiT, arch_kwargs, checkpoint)
def load_e2tts():
    """Fetch the E2-TTS base checkpoint and load it as a ``UNetT`` model."""
    checkpoint = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
    arch_kwargs = {
        "dim": 1024,
        "depth": 24,
        "heads": 16,
        "ff_mult": 4,
        "text_mask_padding": False,
        "pe_attn_head": 1,
    }
    return load_model(UNetT, arch_kwargs, checkpoint)
def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
    """Load a user-supplied checkpoint as a ``DiT`` model.

    Args:
        ckpt_path: Local path or ``hf://`` URL of the checkpoint.
        vocab_path: Local path or ``hf://`` URL of the vocab file.
        model_cfg: Model kwargs as a dict or a JSON string; ``None`` selects
            the default architecture from ``DEFAULT_TTS_MODEL_CFG``.

    Returns:
        Whatever ``load_model`` returns for the resolved inputs.
    """
    ckpt_path, vocab_path = ckpt_path.strip(), vocab_path.strip()

    def _localize(location):
        # hf:// URLs are downloaded/cached; anything else is used verbatim.
        return str(cached_path(location)) if location.startswith("hf://") else location

    ckpt_path = _localize(ckpt_path)
    vocab_path = _localize(vocab_path)

    if model_cfg is None:
        model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
    elif isinstance(model_cfg, str):
        model_cfg = json.loads(model_cfg)
    return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
# The default model is always loaded at import time.
F5TTS_ema_model = load_f5tts()
# E2-TTS is loaded eagerly only on Spaces; otherwise infer() loads it on first use.
E2TTS_ema_model = load_e2tts() if USING_SPACES else None
# Cached custom model plus the checkpoint path it was loaded from, so infer()
# reloads only when that path changes.
custom_ema_model, pre_custom_path = None, ""
# Chat LLM and its tokenizer; populated by load_chat_model().
chat_model_state = None
chat_tokenizer_state = None
@gpu_decorator
def chat_model_inference(messages, model, tokenizer):
    """Produce the chat model's reply to ``messages`` as plain text.

    The conversation is rendered with the tokenizer's chat template, passed to
    ``model.generate`` and only the newly generated tokens are decoded.
    """
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    sequences = model.generate(**model_inputs, max_new_tokens=512, temperature=0.7, top_p=0.95)

    # Strip the prompt tokens from the front of every returned sequence.
    completions = []
    for prompt_ids, full_ids in zip(model_inputs.input_ids, sequences):
        completions.append(full_ids[len(prompt_ids) :])

    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
    return decoded[0]
@gpu_decorator
def load_text_from_file(file):
    """Return a Gradio update holding the stripped UTF-8 contents of ``file``.

    A falsy ``file`` yields an update with an empty string.
    """
    if not file:
        return gr.update(value="")
    with open(file, "r", encoding="utf-8") as handle:
        contents = handle.read().strip()
    return gr.update(value=contents)
@lru_cache(maxsize=1000)  # NOTE. need to ensure params of infer() hashable
@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    remove_silence,
    seed,
    cross_fade_duration=0.15,
    nfe_step=32,
    speed=1,
    show_info=gr.Info,
):
    """Synthesize ``gen_text`` in the voice of the reference audio.

    Args:
        ref_audio_orig: Path of the reference audio clip.
        ref_text: Transcript of the reference clip (forwarded to
            ``preprocess_ref_audio_text``).
        gen_text: Text to synthesize.
        model: ``DEFAULT_TTS_MODEL``, ``"E2-TTS"`` or a
            ``("Custom", ckpt_path, vocab_path, model_cfg)`` tuple.
        remove_silence: When true, post-process the result with
            ``remove_silence_for_generated_wav``.
        seed: Torch seed; values outside ``[0, 2**31 - 1]`` are replaced by a
            random seed.
        cross_fade_duration, nfe_step, speed: Forwarded to ``infer_process``.
        show_info: Callable used for progress/info messages.

    Returns:
        A 4-tuple ``((sample_rate, wave), spectrogram_path, ref_text,
        used_seed)``.  When the reference audio or the text is missing, the
        first two items are no-op ``gr.update()`` objects; the tuple still has
        four items because every caller unpacks four values.

    Raises:
        ValueError: If ``model`` is not one of the supported choices.
    """
    global E2TTS_ema_model, custom_ema_model, pre_custom_path

    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        # Keep the arity identical to the success path (callers unpack 4 values).
        return gr.update(), gr.update(), ref_text, seed

    # Set inference seed
    if seed < 0 or seed > 2**31 - 1:
        gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
        seed = np.random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)
    used_seed = seed

    if not gen_text.strip():
        gr.Warning("Please enter text to generate or upload a text file.")
        return gr.update(), gr.update(), ref_text, used_seed

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    if model == DEFAULT_TTS_MODEL:
        ema_model = F5TTS_ema_model
    elif model == "E2-TTS":
        # Loaded lazily on first use when not running on Spaces.
        if E2TTS_ema_model is None:
            show_info("Loading E2-TTS model...")
            E2TTS_ema_model = load_e2tts()
        ema_model = E2TTS_ema_model
    elif isinstance(model, tuple) and model[0] == "Custom":
        assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
        # Reload only when the requested checkpoint path changed.
        if pre_custom_path != model[1]:
            show_info("Loading Custom TTS model...")
            custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3])
            pre_custom_path = model[1]
        ema_model = custom_ema_model
    else:
        # Previously this fell through and failed later with an opaque
        # UnboundLocalError on ema_model.
        raise ValueError(f"Unknown TTS model choice: {model!r}")

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    # Remove silence
    if remove_silence:
        # The temporary file is only used to reserve a path; the wave is
        # written, processed and re-read by path, then removed.
        with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
            temp_path = f.name
        try:
            sf.write(temp_path, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(temp_path)
            final_wave, _ = torchaudio.load(temp_path)
        finally:
            os.unlink(temp_path)
        final_wave = final_wave.squeeze().cpu().numpy()

    # Save the spectrogram
    with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
with gr.Blocks() as app_tts:
gr.Markdown("# Batched TTS")
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
with gr.Row():
gen_text_input = gr.Textbox(
label="Text to Generate",
lines=10,
max_lines=40,
scale=4,
)
gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
generate_btn = gr.Button("Synthesize", variant="primary")
with gr.Accordion("Advanced Settings", open=True) as adv_settn:
with gr.Row():
ref_text_input = gr.Textbox(
label="Reference Text",
info="Leave blank to automatically transcribe the reference audio. If you enter text or upload a file, it will override automatic transcription.",
lines=2,
scale=4,
)
ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize Seed",
info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
value=True,
scale=3,
)
seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)
with gr.Column(scale=4):
remove_silence = gr.Checkbox(
label="Remove Silences",
info="If undesired long silence(s) produced, turn on to automatically detect and crop.",
value=False,
)
speed_slider = gr.Slider(
label="Speed",
minimum=0.3,
maximum=2.0,
value=1.0,
step=0.1,
info="Adjust the speed of the audio.",
)
nfe_slider = gr.Slider(
label="NFE Steps",
minimum=4,
maximum=64,
value=32,
step=2,
info="Set the number of denoising steps.",
)
cross_fade_duration_slider = gr.Slider(
label="Cross-Fade Duration (s)",
minimum=0.0,
maximum=1.0,
value=0.15,
step=0.01,
info="Set the duration of the cross-fade between audio clips.",
)
def collapse_accordion():
    """Return a closed ``gr.Accordion`` (used as the output of the load event)."""
    return gr.Accordion(open=False)
# Workaround for https://github.com/SWivid/F5-TTS/issues/1239#issuecomment-3677987413
# i.e. create gr.Accordion(open=True) by default, then collapse it manually once the Blocks app has loaded
app_tts.load(
fn=collapse_accordion,
inputs=None,
outputs=adv_settn,
)
audio_output = gr.Audio(label="Synthesized Audio")
spectrogram_output = gr.Image(label="Spectrogram")
@gpu_decorator
def basic_tts(
    ref_audio_input,
    ref_text_input,
    gen_text_input,
    remove_silence,
    randomize_seed,
    seed_input,
    cross_fade_duration_slider,
    nfe_slider,
    speed_slider,
):
    """Click handler of the basic TTS tab.

    Draws a fresh seed when ``randomize_seed`` is set, runs ``infer`` with the
    currently selected model and passes its four results straight through.
    """
    seed_input = np.random.randint(0, 2**31 - 1) if randomize_seed else seed_input

    # NOTE: infer() is lru_cached, so the positional/keyword split of this
    # call is kept exactly as is.
    result = infer(
        ref_audio_input,
        ref_text_input,
        gen_text_input,
        tts_model_choice,
        remove_silence,
        seed=seed_input,
        cross_fade_duration=cross_fade_duration_slider,
        nfe_step=nfe_slider,
        speed=speed_slider,
    )
    audio_out, spectrogram_path, ref_text_out, used_seed = result
    return audio_out, spectrogram_path, ref_text_out, used_seed
gen_text_file.upload(
load_text_from_file,
inputs=[gen_text_file],
outputs=[gen_text_input],
)
ref_text_file.upload(
load_text_from_file,
inputs=[ref_text_file],
outputs=[ref_text_input],
)
ref_audio_input.clear(
lambda: [None, None],
None,
[ref_text_input, ref_text_file],
)
generate_btn.click(
basic_tts,
inputs=[
ref_audio_input,
ref_text_input,
gen_text_input,
remove_silence,
randomize_seed,
seed_input,
cross_fade_duration_slider,
nfe_slider,
speed_slider,
],
outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
)
def parse_speechtypes_text(gen_text):
    """Split a multi-style script into per-speech-type text segments.

    Speech types are announced inline either as ``{Name}`` or as a JSON object
    such as ``{"name": "Name", "seed": 1, "speed": 0.9}``.  A type stays
    active until the next tag; text before the first tag uses ``Regular``.

    Args:
        gen_text: The script to parse (``None`` is treated as empty).

    Returns:
        A list of dicts, one per non-empty text span, each holding the keys
        ``name``, ``seed``, ``speed`` and ``text``.  Keys missing from a JSON
        tag fall back to ``"Regular"``, ``-1`` and ``1.0`` respectively, so
        consumers can always index all four keys.
    """
    defaults = {"name": "Regular", "seed": -1, "speed": 1.0}

    # Pattern to find {str} or {"name": str, "seed": int, "speed": float}
    pattern = r"(\{.*?\})"

    # With a capturing group, re.split alternates text (even index) and tag
    # (odd index) tokens.
    tokens = re.split(pattern, gen_text or "")

    segments = []
    current_type = dict(defaults)
    for position, token in enumerate(tokens):
        token = token.strip()
        if position % 2 == 0:
            # This is text
            if token:
                segments.append({**current_type, "text": token})
            continue

        # This is type
        try:
            parsed = json.loads(token)
        except json.decoder.JSONDecodeError:
            parsed = {"name": token[1:-1]}  # plain {Name}: drop the braces
        # Fill in whatever the tag did not specify.
        current_type = {**defaults, **parsed}

    return segments
with gr.Blocks() as app_multistyle:
# New section for multistyle generation
gr.Markdown(
"""
# Multiple Speech-Type Generation
This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, or upload a .txt file with the same format. The system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
"""
)
with gr.Row():
gr.Markdown(
"""
**Example Input:**
{Regular} Hello, I'd like to order a sandwich please.
{Surprised} What do you mean you're out of bread?
{Sad} I really wanted a sandwich though...
{Angry} You know what, darn you and your little shop!
{Whisper} I'll just go back home and cry now.
{Shouting} Why me?!
"""
)
gr.Markdown(
"""
**Example Input 2:**
{"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please.
{"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread.
{"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though...
{"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
"""
)
gr.Markdown(
'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
)
# Regular speech type (mandatory)
with gr.Row(variant="compact") as regular_row:
with gr.Column(scale=1, min_width=160):
regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
regular_insert = gr.Button("Insert Label", variant="secondary")
with gr.Column(scale=3):
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
with gr.Column(scale=3):
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
with gr.Row():
regular_seed_slider = gr.Slider(
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
)
regular_speed_slider = gr.Slider(
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
)
with gr.Column(scale=1, min_width=160):
regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
# Component registries for all speech types (max 100); index 0 is the mandatory Regular type
max_speech_types = 100
speech_type_rows = [regular_row]
speech_type_names = [regular_name]
speech_type_audios = [regular_audio]
speech_type_ref_texts = [regular_ref_text]
speech_type_ref_text_files = [regular_ref_text_file]
speech_type_seeds = [regular_seed_slider]
speech_type_speeds = [regular_speed_slider]
speech_type_delete_btns = [None]
speech_type_insert_btns = [regular_insert]
# Additional speech types (99 more)
for i in range(max_speech_types - 1):
with gr.Row(variant="compact", visible=False) as row:
with gr.Column(scale=1, min_width=160):
name_input = gr.Textbox(label="Speech Type Name")
insert_btn = gr.Button("Insert Label", variant="secondary")
delete_btn = gr.Button("Delete Type", variant="stop")
with gr.Column(scale=3):
audio_input = gr.Audio(label="Reference Audio", type="filepath")
with gr.Column(scale=3):
ref_text_input = gr.Textbox(label="Reference Text", lines=4)
with gr.Row():
seed_input = gr.Slider(
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random"
)
speed_input = gr.Slider(
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
)
with gr.Column(scale=1, min_width=160):
ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
speech_type_rows.append(row)
speech_type_names.append(name_input)
speech_type_audios.append(audio_input)
speech_type_ref_texts.append(ref_text_input)
speech_type_ref_text_files.append(ref_text_file_input)
speech_type_seeds.append(seed_input)
speech_type_speeds.append(speed_input)
speech_type_delete_btns.append(delete_btn)
speech_type_insert_btns.append(insert_btn)
# Global logic for all speech types
for i in range(max_speech_types):
speech_type_audios[i].clear(
lambda: [None, None],
None,
[speech_type_ref_texts[i], speech_type_ref_text_files[i]],
)
speech_type_ref_text_files[i].upload(
load_text_from_file,
inputs=[speech_type_ref_text_files[i]],
outputs=[speech_type_ref_texts[i]],
)
# Button to add speech type
add_speech_type_btn = gr.Button("Add Speech Type")
# Keep track of autoincrement of speech types, no roll back
speech_type_count = 1
# Function to add a speech type
def add_speech_type_fn():
    """Reveal the next hidden speech-type row.

    Returns one update per row; only the newly revealed row gets
    ``visible=True``.  Once all rows have been handed out a warning is shown
    and nothing changes.
    """
    global speech_type_count

    row_updates = [gr.update() for _ in range(max_speech_types)]
    if speech_type_count >= max_speech_types:
        gr.Warning("Exhausted maximum number of speech types. Consider restart the app.")
        return row_updates

    row_updates[speech_type_count] = gr.update(visible=True)
    speech_type_count += 1
    return row_updates
add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)
# Function to delete a speech type
def delete_speech_type_fn():
    """Hide a speech-type row and clear its name, audio, ref text and ref-text file."""
    hide_row = gr.update(visible=False)
    name = audio = ref_text = ref_text_file = None
    return hide_row, name, audio, ref_text, ref_text_file
# Wire the delete button of every additional speech type (index 0, Regular, has none)
for i in range(1, len(speech_type_delete_btns)):
speech_type_delete_btns[i].click(
delete_speech_type_fn,
outputs=[
speech_type_rows[i],
speech_type_names[i],
speech_type_audios[i],
speech_type_ref_texts[i],
speech_type_ref_text_files[i],
],
)
# Text input for the prompt
with gr.Row():
gen_text_input_multistyle = gr.Textbox(
label="Text to Generate",
lines=10,
max_lines=40,
scale=4,
placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
)
gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
def make_insert_speech_type_fn(index):
    """Build the "Insert Label" click handler for one speech-type row.

    The returned handler appends a JSON speech-type tag (name, seed, speed)
    followed by a space to the current script text.  ``index`` is accepted for
    the caller's convenience but is not used by the handler.
    """

    def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
        existing = current_text or ""
        if not speech_type_name:
            gr.Warning("Please enter speech type name before insert.")
            return existing
        label = json.dumps(
            {
                "name": speech_type_name,
                "seed": speech_type_seed,
                "speed": speech_type_speed,
            }
        )
        return f"{existing}{label} "

    return insert_speech_type_fn
for i, insert_btn in enumerate(speech_type_insert_btns):
insert_fn = make_insert_speech_type_fn(i)
insert_btn.click(
insert_fn,
inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
outputs=gen_text_input_multistyle,
)
with gr.Accordion("Advanced Settings", open=True):
with gr.Row():
with gr.Column():
show_cherrypick_multistyle = gr.Checkbox(
label="Show Cherry-pick Interface",
info="Turn on to show interface, picking seeds from previous generations.",
value=False,
)
with gr.Column():
remove_silence_multistyle = gr.Checkbox(
label="Remove Silences",
info="Turn on to automatically detect and crop long silences.",
value=True,
)
# Generate button
generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
# Output audio
audio_output_multistyle = gr.Audio(label="Synthesized Audio")
# Used seed gallery
cherrypick_interface_multistyle = gr.Textbox(
label="Cherry-pick Interface",
lines=10,
max_lines=40,
buttons=["copy"], # show_copy_button=True if gradio<6.0
interactive=False,
visible=False,
)
# Logic control to show/hide the cherrypick interface
show_cherrypick_multistyle.change(
lambda is_visible: gr.update(visible=is_visible),
show_cherrypick_multistyle,
cherrypick_interface_multistyle,
)
# Function to load text to generate from file
gen_text_file_multistyle.upload(
load_text_from_file,
inputs=[gen_text_file_multistyle],
outputs=[gen_text_input_multistyle],
)
@gpu_decorator
def generate_multistyle_speech(
    gen_text,
    *args,
):
    """Synthesize a script that switches between several speech types.

    ``args`` is the flattened component input: ``max_speech_types`` names,
    then as many reference audios, then as many reference texts, then the
    remove-silence flag.

    Returns a list ``[audio] + ref_texts + [meta]`` with one ref text per
    speech-type slot.  ``audio`` is ``(sample_rate, wave)`` and ``meta`` lists
    the tag (with the seed actually used) and text of every segment; both are
    ``None`` when nothing could be generated.
    """
    slots = max_speech_types
    names = args[:slots]
    audios = args[slots : 2 * slots]
    ref_texts = args[2 * slots : 3 * slots]
    remove_silence = args[3 * slots]

    # One entry per slot, in slot order.  Slots lacking a name or an audio get
    # a positional placeholder key so the ref-text outputs stay aligned.
    speech_types = OrderedDict()
    for slot, (slot_name, slot_audio, slot_ref_text) in enumerate(zip(names, audios, ref_texts)):
        if slot_name and slot_audio:
            speech_types[slot_name] = {"audio": slot_audio, "ref_text": slot_ref_text}
        else:
            speech_types[f"@{slot}@"] = {"audio": "", "ref_text": ""}

    def _pack(audio, meta):
        # Output layout expected by the click handler wiring.
        return [audio] + [entry["ref_text"] for entry in speech_types.values()] + [meta]

    generated_audio_segments = []
    inference_meta_data = ""

    for segment in parse_speechtypes_text(gen_text):
        name = segment["name"]
        seed_input = segment["seed"]
        speed = segment["speed"]
        text = segment["text"]

        if name in speech_types:
            current_type_name = name
        else:
            gr.Warning(f"Type {name} is not available, will use Regular as default.")
            current_type_name = "Regular"

        try:
            ref_audio = speech_types[current_type_name]["audio"]
        except KeyError:
            gr.Warning(f"Please provide reference audio for type {current_type_name}.")
            return _pack(None, None)

        ref_text = speech_types[current_type_name].get("ref_text", "")

        if seed_input == -1:
            seed_input = np.random.randint(0, 2**31 - 1)

        # Generate or retrieve speech for this segment
        audio_out, _, ref_text_out, used_seed = infer(
            ref_audio,
            ref_text,
            text,
            tts_model_choice,
            remove_silence,
            seed=seed_input,
            cross_fade_duration=0,
            speed=speed,
            show_info=print,  # no pull to top when generating
        )
        sr, audio_data = audio_out

        generated_audio_segments.append(audio_data)
        # Keep the (possibly auto-transcribed) reference text for later segments.
        speech_types[current_type_name]["ref_text"] = ref_text_out
        inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"

    if not generated_audio_segments:
        gr.Warning("No audio generated.")
        return _pack(None, None)

    # Concatenate all audio segments
    final_audio_data = np.concatenate(generated_audio_segments)
    return _pack((sr, final_audio_data), inference_meta_data)
generate_multistyle_btn.click(
generate_multistyle_speech,
inputs=[
gen_text_input_multistyle,
]
+ speech_type_names
+ speech_type_audios
+ speech_type_ref_texts
+ [
remove_silence_multistyle,
],
outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
)
# Validation function to disable Generate button if speech types are missing
def validate_speech_types(gen_text, regular_name, *args):
    """Enable the generate button only if every type used in the script is defined.

    ``regular_name`` and ``args`` are the current speech-type names; empty
    names are ignored.
    """
    defined = {type_name for type_name in (regular_name, *args) if type_name}
    referenced = {segment["name"] for segment in parse_speechtypes_text(gen_text)}

    # Interactive exactly when no referenced type is missing.
    return gr.update(interactive=referenced <= defined)
gen_text_input_multistyle.change(
validate_speech_types,
inputs=[gen_text_input_multistyle, regular_name] + speech_type_names,
outputs=generate_multistyle_btn,
)
with gr.Blocks() as app_chat:
gr.Markdown(
"""
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript (via text or .txt file).
2. Load the chat model.
3. Record your message through your microphone or type it.
4. The AI will respond using the reference voice.
"""
)
chat_model_name_list = [
"Qwen/Qwen2.5-3B-Instruct",
"microsoft/Phi-4-mini-instruct",
]
@gpu_decorator
def load_chat_model(chat_model_name):
    """Load the named HuggingFace chat model and tokenizer into module state.

    A previously loaded model is dropped first (followed by a garbage
    collection and a CUDA cache flush).  Returns two updates: hide the load
    button and show the chat interface.
    """
    global chat_model_state, chat_tokenizer_state

    notify = gr.Info

    already_loaded = chat_model_state is not None
    if already_loaded:
        chat_model_state = None
        chat_tokenizer_state = None
        gc.collect()
        torch.cuda.empty_cache()

    notify(f"Loading chat model: {chat_model_name}")
    chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto")
    chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
    notify(f"Chat model {chat_model_name} loaded successfully!")

    hide_button = gr.update(visible=False)
    show_chat = gr.update(visible=True)
    return hide_button, show_chat
if USING_SPACES:
load_chat_model(chat_model_name_list[0])
chat_model_name_input = gr.Dropdown(
choices=chat_model_name_list,
value=chat_model_name_list[0],
label="Chat Model Name",
info="Enter the name of a HuggingFace chat model",
allow_custom_value=not USING_SPACES,
)
load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES)
chat_interface_container = gr.Column(visible=USING_SPACES)
chat_model_name_input.change(
lambda: gr.update(visible=True),
None,
load_chat_model_btn,
show_progress="hidden",
)
load_chat_model_btn.click(
load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
)
with chat_interface_container:
with gr.Row():
with gr.Column():
ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
with gr.Column():
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
ref_text_chat = gr.Textbox(
label="Reference Text",
info="Optional: Leave blank to auto-transcribe",
lines=2,
scale=3,
)
ref_text_file_chat = gr.File(
label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
)
with gr.Row():
randomize_seed_chat = gr.Checkbox(
label="Randomize Seed",
value=True,
info="Uncheck to use the seed specified.",
scale=3,
)
seed_input_chat = gr.Number(show_label=False, value=0, precision=0, scale=1)
remove_silence_chat = gr.Checkbox(
label="Remove Silences",
value=True,
)
system_prompt_chat = gr.Textbox(
label="System Prompt",
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
lines=2,
)
chatbot_interface = gr.Chatbot(
label="Conversation"
) # type="messages" hard-coded and no need to pass in since gradio 6.0
with gr.Row():
with gr.Column():
audio_input_chat = gr.Microphone(
label="Speak your message",
type="filepath",
)
audio_output_chat = gr.Audio(autoplay=True)
with gr.Column():
text_input_chat = gr.Textbox(
label="Type your message",
lines=1,
)
send_btn_chat = gr.Button("Send Message")
clear_btn_chat = gr.Button("Clear Conversation")
# Modify process_audio_input to generate user input
@gpu_decorator
def process_audio_input(conv_state, audio_path, text):
    """Append the user's turn (from a recording or typed text) to the conversation.

    With a recording, the message is the text returned by
    ``preprocess_ref_audio_text`` for it.  Nothing is appended when the
    resulting message is blank.
    """
    if audio_path:
        text = preprocess_ref_audio_text(audio_path, text)[1]
    elif not text.strip():
        # Neither a recording nor typed text.
        return conv_state

    if text.strip():
        conv_state.append({"role": "user", "content": text})
    return conv_state
# Use model and tokenizer from state to get text response
@gpu_decorator
def generate_text_response(conv_state, system_prompt):
    """Ask the chat model for a reply and append it as the assistant turn.

    Turns whose content is a one-element list of a text part are flattened to
    a plain string (in place) before the model is called.
    """
    for turn in conv_state:
        content = turn["content"]
        if isinstance(content, list):
            assert len(turn["content"]) == 1 and turn["content"][0]["type"] == "text"
            turn["content"] = content[0]["text"]

    history = [{"role": "system", "content": system_prompt}] + conv_state
    reply = chat_model_inference(history, chat_model_state, chat_tokenizer_state)

    conv_state.append({"role": "assistant", "content": reply})
    return conv_state
@gpu_decorator
def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input):
    """Generate TTS audio for the latest assistant reply.

    Args:
        conv_state: Conversation as a list of ``{"role", "content"}`` dicts.
            ``content`` may be a plain string or a list whose first item is a
            ``{"type": "text", "text": ...}`` part; both are accepted, matching
            ``generate_text_response``.
        ref_audio: Path of the reference voice clip.
        ref_text: Transcript of the reference clip.
        remove_silence: Forwarded to ``infer``.
        randomize_seed: Draw a fresh random seed when true.
        seed_input: Seed to use when ``randomize_seed`` is false.

    Returns:
        ``(audio, ref_text, seed)``.  ``audio`` is ``None`` when there is no
        conversation, no reference audio, or no non-empty assistant reply at
        the end of the conversation.
    """
    if not conv_state or not ref_audio:
        return None, ref_text, seed_input

    last_turn = conv_state[-1]
    # Check the role before touching the content so a trailing non-assistant
    # turn is skipped regardless of how its content is shaped.
    if last_turn["role"] != "assistant":
        return None, ref_text, seed_input

    content = last_turn["content"]
    if isinstance(content, str):
        last_ai_response = content
    else:
        last_ai_response = content[0]["text"] if content else ""
    if not last_ai_response:
        return None, ref_text, seed_input

    if randomize_seed:
        seed_input = np.random.randint(0, 2**31 - 1)

    audio_result, _, ref_text_out, used_seed = infer(
        ref_audio,
        ref_text,
        last_ai_response,
        tts_model_choice,
        remove_silence,
        seed=seed_input,
        cross_fade_duration=0.15,
        speed=1.0,
        show_info=print,  # show_info=print no pull to top when generating
    )
    return audio_result, ref_text_out, used_seed
def clear_conversation():
    """Return an empty chat history and ``None`` for the audio output."""
    fresh_history = []
    cleared_audio = None
    return fresh_history, cleared_audio
ref_text_file_chat.upload(
load_text_from_file,
inputs=[ref_text_file_chat],
outputs=[ref_text_chat],
)
for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]:
user_operation(
process_audio_input,
inputs=[chatbot_interface, audio_input_chat, text_input_chat],
outputs=[chatbot_interface],
).then(
generate_text_response,
inputs=[chatbot_interface, system_prompt_chat],
outputs=[chatbot_interface],
).then(
generate_audio_response,
inputs=[
chatbot_interface,
ref_audio_chat,
ref_text_chat,
remove_silence_chat,
randomize_seed_chat,
seed_input_chat,
],
outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
).then(
lambda: [None, None],
None,
[audio_input_chat, text_input_chat],
)
# Handle clear button or system prompt change and reset conversation
for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]:
user_operation(
clear_conversation,
outputs=[chatbot_interface, audio_output_chat],
)
with gr.Blocks() as app_credits:
gr.Markdown("""
# Credits
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
""")
with gr.Blocks() as app:
gr.Markdown(
f"""
# F5-TTS Demo Space
This is {"a local web UI for [F5-TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
The checkpoints currently support English and Chinese.
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 12s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<12s). Ensure the audio is fully uploaded before generating.**
"""
)
last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
def load_last_used_custom():
    """Return the last used custom model info as ``[ckpt, vocab, model_cfg]``.

    The three values are read from the ``last_used_custom`` cache file, one
    per line.  ``DEFAULT_TTS_MODEL_CFG`` is returned when the file does not
    exist (its parent directory is created for a later write) or when it
    holds fewer than three lines; extra lines are ignored.  Callers index and
    unpack exactly three entries, so a malformed file must not leak through.
    """
    try:
        with open(last_used_custom, "r", encoding="utf-8") as f:
            custom = [line.strip() for line in f]
    except FileNotFoundError:
        last_used_custom.parent.mkdir(parents=True, exist_ok=True)
        return DEFAULT_TTS_MODEL_CFG

    if len(custom) < 3:
        # Truncated or hand-edited cache file: fall back to the defaults.
        return DEFAULT_TTS_MODEL_CFG
    return custom[:3]
def switch_tts_model(new_choice):
    """Record the selected TTS model and toggle the custom-model inputs.

    For ``"Custom"`` the last used checkpoint, vocab and config are restored
    into the three (now visible) inputs; for any other choice the inputs are
    hidden.
    """
    global tts_model_choice

    if new_choice != "Custom":
        tts_model_choice = new_choice
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)

    # override in case webpage is refreshed
    custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
    tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
    ckpt_update = gr.update(visible=True, value=custom_ckpt_path)
    vocab_update = gr.update(visible=True, value=custom_vocab_path)
    cfg_update = gr.update(visible=True, value=custom_model_cfg)
    return ckpt_update, vocab_update, cfg_update
def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
global tts_model_choice
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
with open(last_used_custom, "w", encoding="utf-8") as f:
f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
with gr.Row():
if not USING_SPACES:
choose_tts_model = gr.Radio(
choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
)
else:
choose_tts_model = gr.Radio(
choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
)
custom_ckpt_path = gr.Dropdown(
choices=[DEFAULT_TTS_MODEL_CFG[0]],
value=load_last_used_custom()[0],
allow_custom_value=True,
label="Model: local_path | hf://user_id/repo_id/model_ckpt",
visible=False,
)
custom_vocab_path = gr.Dropdown(
choices=[DEFAULT_TTS_MODEL_CFG[1]],
value=load_last_used_custom()[1],
allow_custom_value=True,
label="Vocab: local_path | hf://user_id/repo_id/vocab_file",
visible=False,
)
custom_model_cfg = gr.Dropdown(
choices=[
DEFAULT_TTS_MODEL_CFG[2],
json.dumps(
dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
text_mask_padding=False,
conv_layers=4,
pe_attn_head=1,
)
),
json.dumps(
dict(
dim=768,
depth=18,
heads=12,
ff_mult=2,
text_dim=512,
text_mask_padding=False,
conv_layers=4,
pe_attn_head=1,
)
),
],
value=load_last_used_custom()[2],
allow_custom_value=True,
label="Config: in a dictionary form",
visible=False,
)
choose_tts_model.change(
switch_tts_model,
inputs=[choose_tts_model],
outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
show_progress="hidden",
)
custom_ckpt_path.change(
set_custom_model,
inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
show_progress="hidden",
)
custom_vocab_path.change(
set_custom_model,
inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
show_progress="hidden",
)
custom_model_cfg.change(
set_custom_model,
inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
show_progress="hidden",
)
gr.TabbedInterface(
[app_tts, app_multistyle, app_chat, app_credits],
["Basic-TTS", "Multi-Speech", "Voice-Chat", "Credits"],
)
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option(
    "--share",
    "-s",
    default=False,
    is_flag=True,
    help="Share the app via Gradio share link",
)
# A plain on-flag that defaults to True can never be switched off; the paired
# "--no-api" form makes disabling possible while "--api" / "-a" keep working.
@click.option("--api/--no-api", "-a", default=True, is_flag=True, help="Allow API access")
@click.option(
    "--root_path",
    "-r",
    default=None,
    type=str,
    help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
)
@click.option(
    "--inbrowser",
    "-i",
    is_flag=True,
    default=False,
    help="Automatically launch the interface in the default web browser",
)
def main(port, host, share, api, root_path, inbrowser):
    """Launch the Gradio app with the options given on the command line.

    Args:
        port: Server port passed to ``launch`` (None leaves the choice to Gradio).
        host: Server name passed to ``launch`` (None leaves the choice to Gradio).
        share: Whether to request a Gradio share link.
        api: Passed to ``queue(api_open=...)``.
        root_path: Mount point when served behind a reverse proxy.
        inbrowser: Whether to open the UI in the default browser.
    """
    # `app` is only read here, so no `global` declaration is needed.
    print("Starting app...")
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        root_path=root_path,
        inbrowser=inbrowser,
    )
if __name__ == "__main__":
    # On Spaces the app is launched with default queue settings;
    # everywhere else the click CLI parses the launch options.
    if USING_SPACES:
        app.queue().launch()
    else:
        main()
================================================
FILE: src/f5_tts/infer/speech_edit.py
================================================
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
# Pick the first available torch backend, in order: CUDA, XPU, MPS, then CPU.
device = (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
# ---------------------- infer setting ---------------------- #
# Sampling parameters; most are passed to model.sample() further down.
seed = None # int | None
exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
ckpt_step = 1250000
nfe_step = 32 # 16, 32
cfg_strength = 2.0
ode_method = "euler" # euler | midpoint
sway_sampling_coef = -1.0
speed = 1.0
target_rms = 0.1
# Load the experiment's YAML config shipped in the package, then pull out the
# backbone class, its architecture kwargs, the tokenizer and the mel settings.
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft
# The checkpoint is resolved from an hf:// URL via cached_path; the commented
# line below is the alternative for a checkpoint stored in the local repo.
# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
# Directory that receives speech_edit_out.png / speech_edit_out.wav.
output_dir = "tests"
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
# [write the origin_text into a file, e.g. tests/test_edit.txt]
# ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
# [result will be saved at same path of audio file]
# [--language "zho" for Chinese, "eng" for English]
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
# The audio to edit, its original transcript, and the transcript after editing.
audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
# One [start, end] pair per region to regenerate; fix_duration (when not None)
# is consumed in the same order, one entry per region.
parts_to_edit = [
[1.42, 2.44],
[4.04, 4.9],
] # start_ends of "nature" & "mother nature", in seconds
fix_duration = [
1.2,
1,
] # fix duration for "optimist" & "realist", in seconds
# Alternative (Chinese) examples, kept commented out:
# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
# target_text = "对,那就是你,万人敬仰的太白金星。"
# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
# fix_duration = None # use origin text duration
# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
# target_text = "对,这就是你,万人敬仰的李白金星。"
# parts_to_edit = [[1.500, 2.784], [4.083, 6.760]]
# fix_duration = [1.284, 2.677]
# -------------------------------------------------#
use_ema = True

# exist_ok avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(output_dir, exist_ok=True)

# Vocoder model
local = False
if mel_spec_type == "vocos":
    vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
elif mel_spec_type == "bigvgan":
    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
else:
    # Fail with a clear message instead of a NameError on vocoder_local_path below.
    raise ValueError(f"Unsupported mel_spec_type: {mel_spec_type!r} (expected 'vocos' or 'bigvgan')")
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)

# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model: the configured backbone wrapped in CFM, using the mel settings read from the config
model = CFM(
    transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
    mel_spec_kwargs=dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    ),
    odeint_kwargs=dict(
        method=ode_method,
    ),
    vocab_char_map=vocab_char_map,
).to(device)

# bigvgan runs in float32; otherwise load_checkpoint chooses the dtype itself
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
# Audio
audio, sr = torchaudio.load(audio_to_edit)
if audio.shape[0] > 1:
    # down-mix multi-channel input to mono
    audio = audio.mean(dim=0, keepdim=True)
rms = audio.square().mean().sqrt()
if rms < target_rms:
    # bring quiet input up to the target level (undone again on the output)
    audio = audio * target_rms / rms
if sr != target_sample_rate:
    audio = torchaudio.transforms.Resample(sr, target_sample_rate)(audio)

# Compute the mel spectrogram of the clean original audio first, so that no
# mel window straddles a boundary between zeros and real audio.
audio = audio.to(device)
with torch.inference_mode():
    original_mel = model.mel_spec(audio).permute(0, 2, 1)  # (batch, n_mel, n_frames) -> (batch, n_frames, n_mel)


def _seconds_to_frames(seconds):
    # frames are the authoritative unit for everything below
    return round(seconds * target_sample_rate / hop_length)


# Assemble mel_cond and edit_mask at frame level: kept regions copy the original
# mel frames (mask True), edited regions are zero frames (mask False).
cond_pieces = [torch.zeros(1, 0, n_mel_channels, device=device)]
mask_pieces = [torch.zeros(1, 0, dtype=torch.bool, device=device)]
offset_frame = 0
for part_idx, (start, end) in enumerate(parts_to_edit):
    part_dur_sec = end - start if fix_duration is None else fix_duration[part_idx]
    start_frame = _seconds_to_frames(start)
    end_frame = _seconds_to_frames(end)
    part_dur_frames = _seconds_to_frames(part_dur_sec)
    keep_frames = start_frame - offset_frame  # frames kept before this edit region

    cond_pieces.append(original_mel[:, offset_frame:start_frame, :])
    cond_pieces.append(torch.zeros(1, part_dur_frames, n_mel_channels, device=device))
    mask_pieces.append(torch.ones(1, keep_frames, dtype=torch.bool, device=device))
    mask_pieces.append(torch.zeros(1, part_dur_frames, dtype=torch.bool, device=device))

    offset_frame = end_frame

# Everything after the last edit region is kept as-is.
cond_pieces.append(original_mel[:, offset_frame:, :])
mel_cond = torch.cat(cond_pieces, dim=1)
edit_mask = torch.cat(mask_pieces, dim=-1)
edit_mask = F.pad(edit_mask, (0, mel_cond.shape[1] - edit_mask.shape[-1]), value=True)
# Text
text_list = [target_text]
if tokenizer == "pinyin":
    final_text_list = convert_char_to_pinyin(text_list)
else:
    # Pass the list of strings through unchanged. Wrapping it again gave
    # [[target_text]], i.e. one item whose single "token" is the whole sentence.
    # NOTE(review): assumes model.sample() takes list[str] for non-pinyin
    # tokenizers, as the other inference paths pass it; confirm against CFM.sample.
    final_text_list = text_list
print(f"text : {text_list}")
print(f"pinyin: {final_text_list}")
# Total length to generate, in mel frames (taken from mel_cond, not the raw audio)
duration = mel_cond.shape[1]

# Inference: the mel condition is passed directly, not a waveform
with torch.inference_mode():
    generated, trajectory = model.sample(
        cond=mel_cond,
        text=final_text_list,
        duration=duration,
        steps=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        seed=seed,
        edit_mask=edit_mask,
    )
    print(f"Generated mel: {generated.shape}")

    # Final result: mel -> waveform through the selected vocoder
    gen_mel_spec = generated.to(torch.float32).permute(0, 2, 1)
    if mel_spec_type == "vocos":
        generated_wave = vocoder.decode(gen_mel_spec).cpu()
    elif mel_spec_type == "bigvgan":
        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

    # Undo the input gain so the output matches the original loudness
    if rms < target_rms:
        generated_wave = generated_wave * rms / target_rms

    save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
    print(f"Generated wav: {generated_wave.shape}")
================================================
FILE: src/f5_tts/infer/utils_infer.py
================================================
# A unified script for inference process
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
import os
import sys
from concurrent.futures import ThreadPoolExecutor
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
import hashlib
import re
import tempfile
from importlib.resources import files
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
import numpy as np
import torch
import torchaudio
import tqdm
from huggingface_hub import hf_hub_download
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
# Process-wide caches keyed by the MD5 of the reference audio file:
# preprocessed audio path and ASR transcription respectively.
_ref_audio_cache = {}
_ref_text_cache = {}
# Pick the first available torch backend, in order: CUDA, XPU, MPS, then CPU.
device = (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
# Keyword passed to NamedTemporaryFile; the accepted name depends on the Python version.
tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
# -----------------------------------------
# Mel-spectrogram settings handed to CFM in load_model()
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"
# Default inference settings, used as keyword defaults by the functions below
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32 # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None
# -----------------------------------------
# chunk text into smaller pieces
def chunk_text(text, max_chars=135):
    """
    Break text into chunks at punctuation boundaries, packing sentences greedily.

    Sizes are measured in UTF-8 bytes, not characters. A single sentence that
    is larger than the budget is emitted as its own chunk rather than split.

    Args:
        text (str): The text to be split.
        max_chars (int): Size budget per chunk, in UTF-8 bytes.
    Returns:
        List[str]: The chunks, each stripped of surrounding whitespace.
    """

    def _with_separator(piece):
        # A piece ending in a single-byte character gets a trailing space so the
        # next piece does not run into it; multi-byte endings are joined directly.
        if piece and len(piece[-1].encode("utf-8")) == 1:
            return piece + " "
        return piece

    # Split after ASCII punctuation followed by whitespace, or right after full-width punctuation
    pieces = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)

    result = []
    buffer = ""
    for piece in pieces:
        if len(buffer.encode("utf-8")) + len(piece.encode("utf-8")) <= max_chars:
            buffer += _with_separator(piece)
            continue
        # Budget exceeded: flush what we have and start a new chunk with this piece
        if buffer:
            result.append(buffer.strip())
        buffer = _with_separator(piece)

    if buffer:
        result.append(buffer.strip())
    return result
# load vocoder
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
    """Load a vocoder in eval mode on the given device.

    Args:
        vocoder_name: "vocos" or "bigvgan".
        is_local: Load from ``local_path`` instead of downloading from Hugging Face.
        local_path: Directory holding the local vocoder files (used when ``is_local``).
        device: Device the vocoder is moved to.
        hf_cache_dir: Cache directory passed to the Hugging Face download calls.

    Returns:
        The loaded vocoder module.

    Raises:
        ImportError: If "bigvgan" is requested but the BigVGAN submodule cannot be imported.
        ValueError: If ``vocoder_name`` is not one of the supported names.
    """
    if vocoder_name == "vocos":
        # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
        if is_local:
            print(f"Load vocos from local path {local_path}")
            config_path = f"{local_path}/config.yaml"
            model_path = f"{local_path}/pytorch_model.bin"
        else:
            print("Download Vocos from huggingface charactr/vocos-mel-24khz")
            repo_id = "charactr/vocos-mel-24khz"
            config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
            model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
        vocoder = Vocos.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        from vocos.feature_extractors import EncodecFeatures

        if isinstance(vocoder.feature_extractor, EncodecFeatures):
            # Fill in the encodec weights from the freshly built feature extractor
            # so load_state_dict() below finds every key it expects.
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in vocoder.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        vocoder.load_state_dict(state_dict)
        vocoder = vocoder.eval().to(device)
    elif vocoder_name == "bigvgan":
        try:
            from third_party.BigVGAN import bigvgan
        except ImportError as err:
            # Previously this only printed and then crashed with a NameError on `bigvgan`.
            raise ImportError(
                "You need to follow the README to init submodule and change the BigVGAN source code."
            ) from err
        if is_local:
            # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
            vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
        else:
            vocoder = bigvgan.BigVGAN.from_pretrained(
                "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir
            )
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(device)
    else:
        # Previously fell through to `return vocoder` and raised UnboundLocalError.
        raise ValueError(f"Unsupported vocoder_name: {vocoder_name!r} (expected 'vocos' or 'bigvgan')")
    return vocoder
# load asr pipeline
asr_pipe = None


def initialize_asr_pipeline(device: str = device, dtype=None):
    """Create the Whisper ASR pipeline and store it in the module-level ``asr_pipe``.

    When ``dtype`` is None, float16 is used on CUDA devices whose reported
    major capability is at least 7 and whose device name does not end with
    "[ZLUDA]"; float32 is used otherwise.
    """
    global asr_pipe
    if dtype is None:
        half_precision_ok = (
            "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 7
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
        )
        dtype = torch.float16 if half_precision_ok else torch.float32
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=dtype,
        device=device,
    )
# transcribe
def transcribe(ref_audio, language=None):
    """Transcribe ``ref_audio`` with the ASR pipeline, creating the pipeline on first use.

    Args:
        ref_audio: Audio input handed straight to the ASR pipeline.
        language: Optional language forwarded to generation; omitted when falsy.

    Returns:
        The recognised text with surrounding whitespace stripped.
    """
    global asr_pipe
    if asr_pipe is None:
        initialize_asr_pipeline(device=device)

    generate_kwargs = {"task": "transcribe"}
    if language:
        generate_kwargs["language"] = language

    output = asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs=generate_kwargs,
        return_timestamps=False,
    )
    return output["text"].strip()
# load model checkpoint for inference
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
    """Load weights from ``ckpt_path`` into ``model`` and return it on ``device``.

    Args:
        model: Module to receive the weights; it is cast to ``dtype`` first.
        ckpt_path: Path to a ``.safetensors`` file or a ``torch.save`` checkpoint.
        device: Device string used for loading and for the returned model.
        dtype: Target dtype. When None, float16 is used on CUDA devices whose
            reported major capability is at least 7 and whose name does not
            end with "[ZLUDA]"; float32 otherwise.
        use_ema: Load the EMA weights (True) or the plain model weights (False).

    Returns:
        The model with weights loaded, moved to ``device``.
    """
    if dtype is None:
        half_precision_ok = (
            "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 7
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
        )
        dtype = torch.float16 if half_precision_ok else torch.float32
    model = model.to(dtype)

    is_safetensors = ckpt_path.split(".")[-1] == "safetensors"
    if is_safetensors:
        from safetensors.torch import load_file

        raw = load_file(ckpt_path, device=device)
    else:
        raw = torch.load(ckpt_path, map_location=device, weights_only=True)

    if use_ema:
        # A safetensors file is itself the flat EMA state dict; other checkpoints nest it.
        state_dict = {
            name.replace("ema_model.", ""): tensor
            for name, tensor in (raw if is_safetensors else raw["ema_model_state_dict"]).items()
            if name not in ["initted", "step"]
        }
        # patch for backward compatibility, 305e3ea
        for legacy_key in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"):
            state_dict.pop(legacy_key, None)
    else:
        state_dict = raw if is_safetensors else raw["model_state_dict"]

    model.load_state_dict(state_dict)

    # Drop every reference to the loaded tensors before emptying the CUDA cache
    del raw, state_dict
    torch.cuda.empty_cache()
    return model.to(device)
# load model for inference
def load_model(
    model_cls,
    model_cfg,
    ckpt_path,
    mel_spec_type=mel_spec_type,
    vocab_file="",
    ode_method=ode_method,
    use_ema=True,
    device=device,
):
    """Build a CFM model around ``model_cls`` and load its checkpoint.

    Args:
        model_cls: Backbone class, instantiated with ``model_cfg`` as keyword arguments.
        model_cfg: Architecture keyword arguments for the backbone.
        ckpt_path: Checkpoint path handed to ``load_checkpoint``.
        mel_spec_type: Mel-spectrogram type; "bigvgan" forces float32 weights.
        vocab_file: Vocabulary file; the packaged example vocab is used when empty.
        ode_method: ODE solver method passed through ``odeint_kwargs``.
        use_ema: Whether to load the EMA weights.
        device: Device for the model.

    Returns:
        The CFM model with weights loaded.
    """
    tokenizer = "custom"
    if vocab_file == "":
        vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))

    print("\nvocab : ", vocab_file)
    print("token : ", tokenizer)
    print("model : ", ckpt_path, "\n")

    vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)

    backbone = model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
    mel_settings = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )
    model = CFM(
        transformer=backbone,
        mel_spec_kwargs=mel_settings,
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # bigvgan runs in float32; otherwise load_checkpoint chooses the dtype itself
    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
    return load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
def remove_silence_edges(audio, silence_threshold=-42):
    """Trim leading and trailing silence from ``audio``.

    Args:
        audio: Audio segment to trim.
        silence_threshold: Level in dBFS; slices louder than this count as sound.

    Returns:
        The audio segment with silent edges removed.
    """
    # Leading edge: drop everything before the first non-silent position
    leading_ms = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[leading_ms:]

    # Trailing edge: walk backwards, shaving 1 ms per silent slice
    kept_seconds = audio.duration_seconds
    for ms_slice in reversed(audio):
        if ms_slice.dBFS > silence_threshold:
            break
        kept_seconds -= 0.001

    return audio[: int(kept_seconds * 1000)]
# preprocess reference audio and text
def _clip_on_silence(aseg, min_silence_len, silence_thresh, pass_label, show_info):
    """Re-join the non-silent segments of ``aseg``, stopping before the result would exceed 12 s."""
    clipped = AudioSegment.silent(duration=0)
    segments = silence.split_on_silence(
        aseg, min_silence_len=min_silence_len, silence_thresh=silence_thresh, keep_silence=1000, seek_step=10
    )
    for segment in segments:
        # Only stop once more than 6 s is collected, so a long first segment is still taken whole
        if len(clipped) > 6000 and len(clipped + segment) > 12000:
            show_info(f"Audio is over 12s, clipping short. ({pass_label})")
            break
        clipped += segment
    return clipped


def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
    """Prepare a reference audio clip and its transcript for inference.

    The audio is clipped to at most 12 s (preferring silence boundaries), has its
    silent edges trimmed, and is written to a temporary WAV file. Results are
    cached per MD5 of the input file. When ``ref_text`` is blank, the processed
    audio is transcribed (and that transcription cached).

    Args:
        ref_audio_orig: Path to the original reference audio file.
        ref_text: Reference transcript; blank means "transcribe automatically".
        show_info: Callable used to report progress messages.

    Returns:
        Tuple ``(ref_audio, ref_text)``: path to the processed WAV file and the
        transcript, which is made to end with ". " or "。".
    """
    global _ref_audio_cache, _ref_text_cache

    show_info("Converting audio...")

    # Compute a hash of the reference audio file (used as the cache key)
    with open(ref_audio_orig, "rb") as audio_file:
        audio_hash = hashlib.md5(audio_file.read()).hexdigest()

    cached_audio = _ref_audio_cache.get(audio_hash)
    # The cache stores a temp-file path; redo the work if that file has since been
    # removed (e.g. temp-dir cleanup) instead of handing back a dangling path.
    if cached_audio is not None and os.path.exists(cached_audio):
        show_info("Using cached preprocessed reference audio...")
        ref_audio = cached_audio
    else:  # first pass, do preprocess
        # Only reserve a file name here; the audio is exported to it below.
        with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
            temp_path = f.name

        aseg = AudioSegment.from_file(ref_audio_orig)

        # 1. try to find long silence for clipping
        non_silent_wave = _clip_on_silence(aseg, 1000, -50, 1, show_info)

        # 2. try to find short silence for clipping if 1. failed
        if len(non_silent_wave) > 12000:
            non_silent_wave = _clip_on_silence(aseg, 100, -40, 2, show_info)

        aseg = non_silent_wave

        # 3. if no proper silence found for clipping
        if len(aseg) > 12000:
            aseg = aseg[:12000]
            show_info("Audio is over 12s, clipping short. (3)")

        # Trim silent edges, then append 50 ms of silence
        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(temp_path, format="wav")
        ref_audio = temp_path

        # Cache the processed reference audio
        _ref_audio_cache[audio_hash] = ref_audio

    if not ref_text.strip():
        if audio_hash in _ref_text_cache:
            # Use cached asr transcription
            show_info("Using cached reference text...")
            ref_text = _ref_text_cache[audio_hash]
        else:
            show_info("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(ref_audio)
            # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
            _ref_text_cache[audio_hash] = ref_text
    else:
        show_info("Using custom reference text...")

    # Ensure ref_text ends with a proper sentence-ending punctuation
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    print("\nref_text ", ref_text)

    return ref_audio, ref_text
# infer process: chunk text -> infer batches [i.e. infer_batch_process()]
def infer_process(
    ref_audio,
    ref_text,
    gen_text,
    model_obj,
    vocoder,
    mel_spec_type=mel_spec_type,
    show_info=print,
    progress=tqdm,
    target_rms=target_rms,
    cross_fade_duration=cross_fade_duration,
    nfe_step=nfe_step,
    cfg_strength=cfg_strength,
    sway_sampling_coef=sway_sampling_coef,
    speed=speed,
    fix_duration=fix_duration,
    device=device,
):
    """Chunk ``gen_text`` and synthesise it with ``infer_batch_process``.

    The per-chunk byte budget is derived from the reference clip: its text
    bytes per second of audio, scaled by (22 - clip seconds) and by ``speed``.

    Returns:
        The first item yielded by ``infer_batch_process``.
    """
    audio, sr = torchaudio.load(ref_audio)

    # Split the input text into batches
    ref_seconds = audio.shape[-1] / sr
    max_chars = int(len(ref_text.encode("utf-8")) / ref_seconds * (22 - ref_seconds) * speed)
    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
    for batch_idx, batch_text in enumerate(gen_text_batches):
        print(f"gen_text {batch_idx}", batch_text)
    print("\n")

    show_info(f"Generating audio in {len(gen_text_batches)} batches...")
    batch_results = infer_batch_process(
        (audio, sr),
        ref_text,
        gen_text_batches,
        model_obj,
        vocoder,
        mel_spec_type=mel_spec_type,
        progress=progress,
        target_rms=target_rms,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        speed=speed,
        fix_duration=fix_duration,
        device=device,
    )
    return next(batch_results)
# infer batches
def infer_batch_process(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    mel_spec_type="vocos",
    progress=tqdm,
    target_rms=0.1,
    cross_fade_duration=0.15,
    nfe_step=32,
    cfg_strength=2.0,
    sway_sampling_coef=-1,
    speed=1,
    fix_duration=None,
    device=None,
    streaming=False,
    chunk_size=2048,
):
    """Generate speech for each text batch, conditioned on the reference audio.

    This is a generator. With ``streaming=True`` it yields
    ``(wave_chunk, target_sample_rate)`` pieces of at most ``chunk_size``
    samples. Otherwise it yields once:
    ``(final_wave, target_sample_rate, combined_spectrogram)``, or
    ``(None, target_sample_rate, None)`` when nothing was generated.

    Args:
        ref_audio: Tuple ``(audio_tensor, sample_rate)`` of the reference clip.
        ref_text: Transcript of the reference clip.
        gen_text_batches: Texts to synthesise, one generation per item.
        cross_fade_duration: Seconds of overlap when joining batches; <= 0 concatenates.
        fix_duration: If set, total duration in seconds (reference + generated) per batch.
    """
    audio, sr = ref_audio
    if audio.shape[0] > 1:
        # down-mix multi-channel reference to mono
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        # bring a quiet reference up to the target level (undone on the output)
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        audio = torchaudio.transforms.Resample(sr, target_sample_rate)(audio)
    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    # Reference text ending in a single-byte character gets a separating space
    if len(ref_text[-1].encode("utf-8")) == 1:
        ref_text = ref_text + " "

    def process_batch(gen_text):
        # Texts under 10 bytes are generated at a fixed slow speed
        local_speed = 0.3 if len(gen_text.encode("utf-8")) < 10 else speed

        # Prepare the text
        final_text_list = convert_char_to_pinyin([ref_text + gen_text])

        ref_audio_len = audio.shape[-1] // hop_length
        if fix_duration is None:
            # Estimate total frames from the reference's frames-per-text-byte ratio
            ref_text_len = len(ref_text.encode("utf-8"))
            gen_text_len = len(gen_text.encode("utf-8"))
            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
        else:
            duration = int(fix_duration * target_sample_rate / hop_length)

        # inference
        with torch.inference_mode():
            generated, trajectory = model_obj.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )
            del trajectory

            # generated mel spectrogram: drop the reference part, then swap the last two axes
            generated = generated.to(torch.float32)
            generated = generated[:, ref_audio_len:, :].permute(0, 2, 1)

            if mel_spec_type == "vocos":
                generated_wave = vocoder.decode(generated)
            elif mel_spec_type == "bigvgan":
                generated_wave = vocoder(generated)

            if rms < target_rms:
                generated_wave = generated_wave * rms / target_rms

            # wav -> numpy
            generated_wave = generated_wave.squeeze().cpu().numpy()

            if streaming:
                for offset in range(0, len(generated_wave), chunk_size):
                    yield generated_wave[offset : offset + chunk_size], target_sample_rate
            else:
                generated_cpu = generated[0].cpu().numpy()
                del generated
                yield generated_wave, generated_cpu

    if streaming:
        batches = progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches
        for gen_text in batches:
            for chunk in process_batch(gen_text):
                yield chunk
        return

    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
        for future in progress.tqdm(futures) if progress is not None else futures:
            result = future.result()
            if result:
                generated_wave, generated_mel_spec = next(result)
                generated_waves.append(generated_wave)
                spectrograms.append(generated_mel_spec)

    if not generated_waves:
        yield None, target_sample_rate, None
        return

    if cross_fade_duration <= 0:
        # Simply concatenate
        final_wave = np.concatenate(generated_waves)
    else:
        # Combine all generated waves with cross-fading
        final_wave = generated_waves[0]
        for next_wave in generated_waves[1:]:
            # Overlap length, capped by both wave lengths
            overlap = min(int(cross_fade_duration * target_sample_rate), len(final_wave), len(next_wave))
            if overlap <= 0:
                # No overlap possible, concatenate
                final_wave = np.concatenate([final_wave, next_wave])
                continue

            # Linear fade-out of the tail blended with a linear fade-in of the head
            blended = final_wave[-overlap:] * np.linspace(1, 0, overlap) + next_wave[:overlap] * np.linspace(
                0, 1, overlap
            )
            final_wave = np.concatenate([final_wave[:-overlap], blended, next_wave[overlap:]])

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    yield final_wave, target_sample_rate, combined_spectrogram
# remove silence from generated wav
def remove_silence_for_generated_wav(filename):
    """Rewrite the audio file in place as WAV, keeping only its non-silent segments.

    Segments are split at silences of at least 1000 ms below -50 dBFS, keeping
    500 ms of silence around each segment.
    """
    source = AudioSegment.from_file(filename)
    segments = silence.split_on_silence(
        source, min_silence_len=1000, silence_thresh=-50, keep_silence=500, seek_step=10
    )

    joined = AudioSegment.silent(duration=0)
    for segment in segments:
        joined += segment

    joined.export(filename, format="wav")
# save spectrogram
def save_spectrogram(spectrogram, path):
    """Render ``spectrogram`` as an image with a colour bar and save it to ``path``."""
    figure_size = (12, 4)
    plt.figure(figsize=figure_size)
    plt.imshow(spectrogram, origin="lower", aspect="auto")
    plt.colorbar()
    plt.savefig(path)
    # Close the figure so repeated calls do not accumulate open figures
    plt.close()
================================================
FILE: src/f5_tts/model/__init__.py
================================================
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.cfm import CFM
from f5_tts.model.trainer import Trainer
# Public names re-exported by the f5_tts.model package.
__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
================================================
FILE: src/f5_tts/model/backbones/README.md
================================================
## Backbones quick introduction
### unett.py
- flat unet transformer
- structure same as in e2-tts & voicebox paper except using rotary pos emb
- possible abs pos emb & convnextv2 blocks for embedded text before concat
### dit.py
- adaln-zero dit
- embedded timestep as condition
- concatted noised_input + masked_cond + embedded_text, linear proj in
- possible abs pos emb & convnextv2 blocks for embedded text before concat
- possible long skip connection (first layer to last layer)
### mmdit.py
- stable diffusion 3 block structure
- timestep as condition
- left stream: text embedded and applied an abs pos emb
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
================================================
FILE: src/f5_tts/model/backbones/dit.py
================================================
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
# ruff: noqa: F722 F821
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
AdaLayerNorm_Final,
ConvNeXtV2Block,
ConvPositionEmbedding,
DiTBlock,
TimestepEmbedding,
precompute_freqs_cis,
)
# Text embedding
class TextEmbedding(nn.Module):
    """Embed text token ids and align them to the mel sequence length.

    Token ids are shifted by +1 so that index 0 acts as the filler token (batch
    padding arrives as -1). Optionally adds a positional embedding followed by
    ConvNeXtV2 blocks, and optionally upsamples the text features so the valid
    tokens are spread over the target audio length.
    """

    def __init__(
        self, text_num_embeds, text_dim, mask_padding=True, average_upsampling=False, conv_layers=0, conv_mult=2
    ):
        """
        Args:
            text_num_embeds: Vocabulary size (one extra slot is added for the filler token).
            text_dim: Embedding dimension.
            mask_padding: Whether filler/padding positions are zeroed around the conv blocks.
            average_upsampling: Whether to upsample text to the audio length; requires ``mask_padding``.
            conv_layers: Number of ConvNeXtV2 blocks; 0 disables the extra modeling.
            conv_mult: Hidden-size multiplier of the ConvNeXtV2 blocks.
        """
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not
        self.average_upsampling = average_upsampling  # zipvoice-style text late average upsampling (after text encoder)
        if average_upsampling:
            assert mask_padding, "text_embedding_average_upsampling requires text_mask_padding to be True"

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 8192  # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def average_upsample_text_by_mask(self, text, text_mask, target_lens):
        """Repeat each valid text frame so every sample fills its target length.

        For a sample with ``text_len`` valid tokens and target ``audio_len``,
        every token is repeated ``audio_len // text_len`` times and the last
        ``audio_len % text_len`` tokens once more. Positions beyond
        ``audio_len`` stay zero.

        Args:
            text: Text features, indexed as [batch, position, feature].
            text_mask: Boolean mask per [batch, position]; True marks a valid token.
            target_lens: Per-sample target lengths.

        Returns:
            Tensor shaped like ``text`` holding the upsampled features.
        """
        batch = text.shape[0]
        text_lens = text_mask.sum(dim=1)  # [batch]

        upsampled_text = torch.zeros_like(text)

        for i in range(batch):
            text_len = int(text_lens[i].item())
            audio_len = int(target_lens[i].item())

            if text_len == 0 or audio_len <= 0:
                continue

            valid_ind = torch.where(text_mask[i])[0]
            valid_data = text[i, valid_ind, :]  # [text_len, text_dim]

            # Per-token repeat counts, built as a tensor instead of a Python index list.
            base_repeat, remainder = divmod(audio_len, text_len)
            repeats = torch.full((text_len,), base_repeat, device=text.device, dtype=torch.long)
            if remainder:
                repeats[text_len - remainder :] += 1  # the last `remainder` tokens get one extra frame

            # The repeat counts sum to exactly audio_len.
            upsampled_text[i, :audio_len, :] = valid_data.repeat_interleave(repeats, dim=0)

        return upsampled_text

    def forward(self, text: int["b nt"], seq_len, drop_text=False):
        """Embed ``text`` and pad or trim it to the mel sequence length.

        Args:
            text: Token ids, with -1 as batch padding.
            seq_len: Target length, either one int for the whole batch or a
                per-sample tensor of lengths.
            drop_text: Replace all tokens by the filler token (cfg for text).

        Returns:
            Text features padded to the (maximum) target length.
        """
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()

        valid_pos_mask = None
        if torch.is_tensor(seq_len):
            seq_len = seq_len.to(device=text.device, dtype=torch.long)
            max_seq_len = int(seq_len.max().item())
        else:
            max_seq_len = int(seq_len)

        text = text[:, :max_seq_len]  # curtail if character tokens are more than the mel spec tokens
        text = F.pad(text, (0, max_seq_len - text.shape[1]), value=0)
        if torch.is_tensor(seq_len):
            # Per-sample lengths: positions past a sample's own length become filler
            seq_pos = torch.arange(max_seq_len, device=text.device).unsqueeze(0)
            valid_pos_mask = seq_pos < seq_len.unsqueeze(1)
            text = text.masked_fill(~valid_pos_mask, 0)

        if self.mask_padding:
            # Taken before drop_text so the mask reflects the real tokens
            text_mask = text == 0

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d
        if valid_pos_mask is not None:
            # Keep short-sample tail strictly zero (equivalent to per-sample pad_sequence(..., 0)).
            text = text.masked_fill(~valid_pos_mask.unsqueeze(-1), 0.0)

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb; for variable seq lengths, only add positions within each sample's valid range.
            freqs = self.freqs_cis[:max_seq_len, :]
            if valid_pos_mask is not None:
                freqs = freqs.unsqueeze(0) * valid_pos_mask.unsqueeze(-1).to(freqs.dtype)
            text = text + freqs

            # convnextv2 blocks
            if self.mask_padding:
                # Re-zero the padded positions before and after every block
                text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
                for block in self.text_blocks:
                    text = block(text)
                    text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
            else:
                text = self.text_blocks(text)

        if self.average_upsampling:
            if torch.is_tensor(seq_len):
                target_lens = seq_len.to(device=text.device, dtype=torch.long)
            else:
                target_lens = torch.full((text.shape[0],), int(seq_len), device=text.device, dtype=torch.long)
            text = self.average_upsample_text_by_mask(text, ~text_mask, target_lens)

        return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
    """Fuse noised audio, conditioning audio and text embedding into one sequence.

    The three inputs are concatenated on the feature axis, projected to
    ``out_dim`` and given a residual convolutional position embedding.
    """

    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        # input features = noised mel + cond mel + text embedding
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(
        self,
        x: float["b n d"],
        cond: float["b n d"],
        text_embed: float["b n d"],
        drop_audio_cond=False,
        audio_mask: bool["b n"] | None = None,
    ):
        """Return the projected input with its conv position embedding added.

        When ``drop_audio_cond`` is set the conditioning audio is replaced by
        zeros (classifier-free guidance); ``audio_mask`` is forwarded to the
        position embedding as its ``mask`` argument.
        """
        audio_cond = torch.zeros_like(cond) if drop_audio_cond else cond  # cfg for cond audio
        fused = torch.cat((x, audio_cond, text_embed), dim=-1)
        hidden = self.proj(fused)
        return self.conv_pos_embed(hidden, mask=audio_mask) + hidden
# Transformer backbone using DiT blocks
class DiT(nn.Module):
    """Transformer backbone built from a stack of DiT blocks.

    The model embeds the time step, mixes noised audio / conditioning audio /
    text into one input sequence, runs it through ``depth`` DiT blocks with
    rotary position embeddings, and projects back to ``mel_dim`` channels.
    Text embeddings can be cached between calls (see ``get_input_embed``).
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_dim=None,
        text_mask_padding=True,
        text_embedding_average_upsampling=False,
        qk_norm=None,
        conv_layers=0,
        pe_attn_head=None,
        attn_backend="torch",  # "torch" | "flash_attn"
        attn_mask_enabled=False,
        long_skip_connection=False,
        checkpoint_activations=False,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim  # text embedding width defaults to the mel channel count
        self.text_embed = TextEmbedding(
            text_num_embeds,
            text_dim,
            mask_padding=text_mask_padding,
            average_upsampling=text_embedding_average_upsampling,
            conv_layers=conv_layers,
        )
        self.text_cond, self.text_uncond = None, None  # text cache
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim=dim,
                    heads=heads,
                    dim_head=dim_head,
                    ff_mult=ff_mult,
                    dropout=dropout,
                    qk_norm=qk_norm,
                    pe_attn_head=pe_attn_head,
                    attn_backend=attn_backend,
                    attn_mask_enabled=attn_mask_enabled,
                )
                for _ in range(depth)
            ]
        )
        # optional skip from the block-stack input to its output (concat + linear)
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

        self.norm_out = AdaLayerNorm_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

        self.checkpoint_activations = checkpoint_activations

        self.initialize_weights()

    def initialize_weights(self):
        """Zero the AdaLN modulation layers and the output projection."""
        # Zero-out AdaLN layers in DiT blocks:
        for block in self.transformer_blocks:
            nn.init.constant_(block.attn_norm.linear.weight, 0)
            nn.init.constant_(block.attn_norm.linear.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.norm_out.linear.weight, 0)
        nn.init.constant_(self.norm_out.linear.bias, 0)
        nn.init.constant_(self.proj_out.weight, 0)
        nn.init.constant_(self.proj_out.bias, 0)

    def ckpt_wrapper(self, module):
        """Wrap ``module`` in a plain positional-argument callable for activation checkpointing."""

        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs

        return ckpt_forward

    def get_input_embed(
        self,
        x,  # b n d
        cond,  # b n d
        text,  # b nt
        drop_audio_cond: bool = False,
        drop_text: bool = False,
        cache: bool = True,
        audio_mask: bool["b n"] | None = None,
    ):
        """Build the block-stack input from noised audio, cond audio and text.

        With ``cache=True`` the text embedding is stored per branch
        (``text_cond`` for ``drop_text=False``, ``text_uncond`` for
        ``drop_text=True``) and reused on later calls until ``clear_cache`` is
        called. Only the slot of the requested branch is consulted, so a
        conditional-only caller also benefits from the cache.

        With ``cache=False`` the embedding is always recomputed and the cache
        is left untouched.
        """
        cached_embed = None
        if cache:
            cached_embed = self.text_uncond if drop_text else self.text_cond

        if cached_embed is None:
            if audio_mask is None:
                seq_len = x.shape[1]
            else:
                seq_len = audio_mask.sum(dim=1)  # per-sample valid speech length
            text_embed = self.text_embed(text, seq_len=seq_len, drop_text=drop_text)
            if cache:
                if drop_text:
                    self.text_uncond = text_embed
                else:
                    self.text_cond = text_embed
        else:
            text_embed = cached_embed

        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond, audio_mask=audio_mask)

        return x

    def clear_cache(self):
        """Drop both cached text embeddings."""
        self.text_cond, self.text_uncond = None, None

    def forward(
        self,
        x: float["b n d"],  # noised input audio
        cond: float["b n d"],  # masked cond audio
        text: int["b nt"],  # text
        time: float["b"] | float[""],  # time step
        mask: bool["b n"] | None = None,
        drop_audio_cond: bool = False,  # cfg for cond audio
        drop_text: bool = False,  # cfg for text
        cfg_infer: bool = False,  # cfg inference, pack cond & uncond forward
        cache: bool = False,
    ):
        """Predict the output for ``x`` at time ``time``.

        With ``cfg_infer=True`` the conditional and fully-unconditional inputs
        are stacked on the batch axis, so the result has twice the batch size
        (conditional half first); ``drop_audio_cond`` / ``drop_text`` are
        ignored in that mode.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)  # broadcast a scalar time step over the batch

        # t: conditioning time, text: text, x: noised audio + cond audio + text
        t = self.time_embed(time)
        if cfg_infer:  # pack cond & uncond forward: b n d -> 2b n d
            x_cond = self.get_input_embed(
                x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache, audio_mask=mask
            )
            x_uncond = self.get_input_embed(
                x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache, audio_mask=mask
            )
            x = torch.cat((x_cond, x_uncond), dim=0)
            t = torch.cat((t, t), dim=0)
            mask = torch.cat((mask, mask), dim=0) if mask is not None else None
        else:
            x = self.get_input_embed(
                x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache, audio_mask=mask
            )

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        for block in self.transformer_blocks:
            if self.checkpoint_activations:
                # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
                x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
            else:
                x = block(x, t, mask=mask, rope=rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output
================================================
FILE: src/f5_tts/model/backbones/mmdit.py
================================================
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
# ruff: noqa: F722 F821
from __future__ import annotations
import torch
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
AdaLayerNorm_Final,
ConvPositionEmbedding,
MMDiTBlock,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
# text embedding
class TextEmbedding(nn.Module):
    """Embed text token ids and add a precomputed position embedding.

    Token ids are shifted by one so that index 0 is the filler token (batch
    padding -1 therefore maps to 0).
    """

    def __init__(self, out_dim, text_num_embeds, mask_padding=True):
        """
        Args:
            out_dim: embedding width.
            text_num_embeds: vocabulary size (one extra row is added for the filler token).
            mask_padding: zero the output at filler / padding positions.
        """
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token

        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not

        self.precompute_max_pos = 1024
        self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)

    def forward(self, text: int["b nt"], drop_text=False) -> float["b nt d"]:
        """Return the position-encoded embedding of ``text``.

        Args:
            text: token ids, with -1 used as batch padding.
            drop_text: replace every token with the filler token (text CFG).
        """
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        if self.mask_padding:
            # taken before the CFG drop so it reflects the real padding layout
            text_mask = text == 0

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b nt -> b nt d

        # sinus pos emb
        # Build the start offsets on the same device as the input so the position
        # indices do not have to be copied from the host on every forward pass.
        batch_start = torch.zeros((text.shape[0],), dtype=torch.long, device=text.device)
        batch_text_len = text.shape[1]
        pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
        text_pos_embed = self.freqs_cis[pos_idx]

        text = text + text_pos_embed

        if self.mask_padding:
            text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)

        return text
# noised input & masked cond audio embedding
class AudioEmbedding(nn.Module):
    """Project noised audio together with its conditioning audio.

    Both inputs are concatenated on the feature axis, linearly projected to
    ``out_dim`` and given a residual convolutional position embedding.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(2 * in_dim, out_dim)  # noised + cond features
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):
        """Embed ``x`` and ``cond``; ``drop_audio_cond`` zeroes the conditioning (CFG)."""
        audio_cond = torch.zeros_like(cond) if drop_audio_cond else cond
        hidden = self.linear(torch.cat((x, audio_cond), dim=-1))
        return self.conv_pos_embed(hidden) + hidden
# Transformer backbone using MM-DiT blocks
class MMDiT(nn.Module):
    """Transformer backbone built from MM-DiT blocks.

    Audio and text are embedded separately and processed jointly by each
    block; only the audio stream is projected back to ``mel_dim`` channels.
    Text embeddings can be cached between calls (see ``get_input_embed``).
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_mask_padding=True,
        qk_norm=None,
        checkpoint_activations=False,
        attn_backend="torch",
        attn_mask_enabled=False,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
        self.text_cond, self.text_uncond = None, None  # text cache
        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        blocks = []
        for layer_idx in range(depth):
            blocks.append(
                MMDiTBlock(
                    dim=dim,
                    heads=heads,
                    dim_head=dim_head,
                    dropout=dropout,
                    ff_mult=ff_mult,
                    context_pre_only=layer_idx == depth - 1,  # flag set on the final block only
                    qk_norm=qk_norm,
                    attn_backend=attn_backend,
                    attn_mask_enabled=attn_mask_enabled,
                )
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        self.norm_out = AdaLayerNorm_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

        self.checkpoint_activations = checkpoint_activations

        self.initialize_weights()

    def initialize_weights(self):
        """Zero the AdaLN modulation layers of every block and the output layers."""
        # AdaLN layers of both the audio (x) and the context (c) stream
        for block in self.transformer_blocks:
            for norm_name in ("attn_norm_x", "attn_norm_c"):
                modulation = getattr(block, norm_name).linear
                nn.init.constant_(modulation.weight, 0)
                nn.init.constant_(modulation.bias, 0)

        # output layers
        nn.init.constant_(self.norm_out.linear.weight, 0)
        nn.init.constant_(self.norm_out.linear.bias, 0)
        nn.init.constant_(self.proj_out.weight, 0)
        nn.init.constant_(self.proj_out.bias, 0)

    def ckpt_wrapper(self, module):
        """Wrap ``module`` in a plain positional-argument callable for activation checkpointing."""

        def ckpt_forward(*inputs):
            return module(*inputs)

        return ckpt_forward

    def get_input_embed(
        self,
        x,  # b n d
        cond,  # b n d
        text,  # b nt
        drop_audio_cond: bool = False,
        drop_text: bool = False,
        cache: bool = True,
    ):
        """Return ``(audio_embedding, text_embedding)`` for one branch.

        With ``cache=True`` the text embedding of the requested branch is
        computed once, kept in ``text_cond`` / ``text_uncond`` and reused until
        ``clear_cache`` is called; with ``cache=False`` it is recomputed every
        time and nothing is stored.
        """
        if not cache:
            c = self.text_embed(text, drop_text=drop_text)
        elif drop_text:
            if self.text_uncond is None:
                self.text_uncond = self.text_embed(text, drop_text=True)
            c = self.text_uncond
        else:
            if self.text_cond is None:
                self.text_cond = self.text_embed(text, drop_text=False)
            c = self.text_cond

        audio = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)

        return audio, c

    def clear_cache(self):
        """Drop both cached text embeddings."""
        self.text_cond = None
        self.text_uncond = None

    def forward(
        self,
        x: float["b n d"],  # noised input audio
        cond: float["b n d"],  # masked cond audio
        text: int["b nt"],  # text
        time: float["b"] | float[""],  # time step
        mask: bool["b n"] | None = None,
        drop_audio_cond: bool = False,  # cfg for cond audio
        drop_text: bool = False,  # cfg for text
        cfg_infer: bool = False,  # cfg inference, pack cond & uncond forward
        cache: bool = False,
    ):
        """Predict the output for ``x`` at time ``time``.

        With ``cfg_infer=True`` the conditional and fully-unconditional inputs
        are stacked on the batch axis (conditional half first), doubling the
        batch size of the result; ``drop_audio_cond`` / ``drop_text`` are
        ignored in that mode.
        """
        batch_size = x.shape[0]
        if time.ndim == 0:
            time = time.repeat(batch_size)  # broadcast a scalar time step over the batch

        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        c_mask = (text + 1) != 0  # True = valid, False = padding (-1 tokens)

        if not cfg_infer:
            x, c = self.get_input_embed(
                x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
            )
        else:  # pack cond & uncond forward: b n d -> 2b n d
            x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
            x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
            x = torch.cat((x_cond, x_uncond), dim=0)
            c = torch.cat((c_cond, c_uncond), dim=0)
            t = torch.cat((t, t), dim=0)
            if mask is not None:
                mask = torch.cat((mask, mask), dim=0)
            c_mask = torch.cat((c_mask, c_mask), dim=0)

        audio_len = x.shape[1]
        text_len = text.shape[1]

        rope_audio = self.rotary_embed.forward_from_seq_len(audio_len)
        rope_text = self.rotary_embed.forward_from_seq_len(text_len)

        for block in self.transformer_blocks:
            if not self.checkpoint_activations:
                c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text, c_mask=c_mask)
            else:
                c, x = torch.utils.checkpoint.checkpoint(
                    self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, c_mask, use_reentrant=False
                )

        x = self.norm_out(x, t)
        return self.proj_out(x)
================================================
FILE: src/f5_tts/model/backbones/unett.py
================================================
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
# ruff: noqa: F722 F821
from __future__ import annotations
from typing import Literal
import torch
import torch.nn.functional as F
from torch import nn
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
Attention,
AttnProcessor,
ConvNeXtV2Block,
ConvPositionEmbedding,
FeedForward,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
# Text embedding
class TextEmbedding(nn.Module):
    """Embed text token ids and pad/truncate them to the mel-frame length.

    Token ids are shifted by one so that index 0 is the filler token (batch
    padding -1 therefore maps to 0). With ``conv_layers > 0`` a position
    embedding and a stack of ConvNeXtV2 blocks are applied on top.
    """

    def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
        """
        Args:
            text_num_embeds: vocabulary size (one extra row is added for the filler token).
            text_dim: embedding width.
            mask_padding: zero filler / padding positions around the conv blocks.
            conv_layers: number of ConvNeXtV2 blocks; 0 disables the extra modeling path.
            conv_mult: hidden-size multiplier handed to each ConvNeXtV2 block.
        """
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):
        """Return the text embedding laid out over ``seq_len`` positions.

        Args:
            text: token ids, with -1 used as batch padding.
            seq_len: target number of positions; longer text is cut, shorter
                text is right-padded with the filler token.
            drop_text: replace every token with the filler token (text CFG).
        """
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)
        if self.mask_padding:
            # taken before the CFG drop so it reflects the real padding layout
            text_mask = text == 0

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            # Build the start offsets on the same device as the input so the position
            # indices do not have to be copied from the host on every forward pass.
            batch_start = torch.zeros((batch,), dtype=torch.long, device=text.device)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            if self.mask_padding:
                text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
                for block in self.text_blocks:
                    text = block(text)
                    # re-zero filler positions after every block
                    text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
            else:
                text = self.text_blocks(text)

        return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
    """Fuse noised audio, conditioning audio and text embedding into one sequence.

    The three inputs are concatenated on the feature axis, projected to
    ``out_dim`` and given a residual convolutional position embedding.
    """

    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        # input features = noised mel + cond mel + text embedding
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):
        """Embed the inputs; ``drop_audio_cond`` zeroes the conditioning audio (CFG)."""
        audio_cond = torch.zeros_like(cond) if drop_audio_cond else cond
        fused = torch.cat((x, audio_cond, text_embed), dim=-1)
        hidden = self.proj(fused)
        return self.conv_pos_embed(hidden) + hidden
# Flat UNet Transformer backbone
class UNetT(nn.Module):
    """Flat UNet-style Transformer backbone.

    The first half of the layers push their input onto a stack; each layer of
    the second half pops one entry and merges it back according to
    ``skip_connect_type`` ("concat" + linear projection, "add", or "none").
    The time embedding travels through the layers as one extra token placed
    in front of the sequence and is removed again before the output projection.
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_dim=None,
        text_mask_padding=True,
        qk_norm=None,
        conv_layers=0,
        pe_attn_head=None,
        attn_backend="torch",  # "torch" | "flash_attn"
        attn_mask_enabled=False,
        skip_connect_type: Literal["add", "concat", "none"] = "concat",
    ):
        super().__init__()
        assert depth % 2 == 0, "UNet-Transformer's depth should be even."

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim  # text embedding width defaults to the mel channel count
        self.text_embed = TextEmbedding(
            text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
        )
        self.text_cond, self.text_uncond = None, None  # text cache
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        # transformer layers & skip connections
        self.dim = dim
        self.skip_connect_type = skip_connect_type
        self.depth = depth

        half_depth = depth // 2
        concat_skips = skip_connect_type == "concat"

        self.layers = nn.ModuleList([])
        for layer_idx in range(depth):
            attn_norm = RMSNorm(dim)
            attn = Attention(
                processor=AttnProcessor(
                    pe_attn_head=pe_attn_head,
                    attn_backend=attn_backend,
                    attn_mask_enabled=attn_mask_enabled,
                ),
                dim=dim,
                heads=heads,
                dim_head=dim_head,
                dropout=dropout,
                qk_norm=qk_norm,
            )

            ff_norm = RMSNorm(dim)
            ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

            # only later-half layers merge a skip, and only "concat" needs a projection
            if concat_skips and layer_idx >= half_depth:
                skip_proj = nn.Linear(dim * 2, dim, bias=False)
            else:
                skip_proj = None

            self.layers.append(nn.ModuleList([skip_proj, attn_norm, attn, ff_norm, ff]))

        self.norm_out = RMSNorm(dim)
        self.proj_out = nn.Linear(dim, mel_dim)

    def get_input_embed(
        self,
        x,  # b n d
        cond,  # b n d
        text,  # b nt
        drop_audio_cond: bool = False,
        drop_text: bool = False,
        cache: bool = True,
    ):
        """Build the layer-stack input from noised audio, cond audio and text.

        With ``cache=True`` the text embedding of the requested branch is
        computed once, kept in ``text_cond`` / ``text_uncond`` and reused until
        ``clear_cache`` is called; with ``cache=False`` it is recomputed every
        time and nothing is stored.
        """
        n_frames = x.shape[1]

        if not cache:
            text_embed = self.text_embed(text, n_frames, drop_text=drop_text)
        elif drop_text:
            if self.text_uncond is None:
                self.text_uncond = self.text_embed(text, n_frames, drop_text=True)
            text_embed = self.text_uncond
        else:
            if self.text_cond is None:
                self.text_cond = self.text_embed(text, n_frames, drop_text=False)
            text_embed = self.text_cond

        return self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

    def clear_cache(self):
        """Drop both cached text embeddings."""
        self.text_cond = None
        self.text_uncond = None

    def forward(
        self,
        x: float["b n d"],  # noised input audio
        cond: float["b n d"],  # masked cond audio
        text: int["b nt"],  # text
        time: float["b"] | float[""],  # time step
        mask: bool["b n"] | None = None,
        drop_audio_cond: bool = False,  # cfg for cond audio
        drop_text: bool = False,  # cfg for text
        cfg_infer: bool = False,  # cfg inference, pack cond & uncond forward
        cache: bool = False,
    ):
        """Predict the output for ``x`` at time ``time``.

        With ``cfg_infer=True`` the conditional and fully-unconditional inputs
        are stacked on the batch axis (conditional half first), doubling the
        batch size of the result; ``drop_audio_cond`` / ``drop_text`` are
        ignored in that mode.
        """
        batch_size, n_frames = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch_size)  # broadcast a scalar time step over the batch

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)

        if not cfg_infer:
            x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
        else:  # pack cond & uncond forward: b n d -> 2b n d
            x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
            x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
            x = torch.cat((x_cond, x_uncond), dim=0)
            t = torch.cat((t, t), dim=0)
            if mask is not None:
                mask = torch.cat((mask, mask), dim=0)

        # put time t in front of input x as an extra token, [b n d] -> [b n+1 d]
        x = torch.cat([t.unsqueeze(1), x], dim=1)  # pack t to x
        if mask is not None:
            mask = F.pad(mask, (1, 0), value=1)  # the time token is always attended to

        rope = self.rotary_embed.forward_from_seq_len(n_frames + 1)

        # flat unet transformer
        merge_mode = self.skip_connect_type
        half_depth = self.depth // 2
        skips = []
        for layer_idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
            if layer_idx < half_depth:
                # first half: remember the layer input
                skips.append(x)
            else:
                # later half: merge the most recently stored input back in
                skip = skips.pop()
                if merge_mode == "concat":
                    x = maybe_skip_proj(torch.cat((x, skip), dim=-1))
                elif merge_mode == "add":
                    x = x + skip

            # attention and feedforward blocks
            x = attn(attn_norm(x), rope=rope, mask=mask) + x
            x = ff(ff_norm(x)) + x

        assert len(skips) == 0

        x = self.norm_out(x)[:, 1:, :]  # unpack t from x

        return self.proj_out(x)
================================================
FILE: src/f5_tts/model/cfm.py
================================================
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
# ruff: noqa: F722 F821
from __future__ import annotations
from random import random
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (
default,
exists,
get_epss_timesteps,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
mask_from_frac_lengths,
)
class CFM(nn.Module):
    """Conditional flow matching wrapper around a transformer backbone.

    ``forward`` computes the flow-matching training loss on a randomly masked
    span of the input; ``sample`` integrates the learned flow with an ODE
    solver to generate mel frames conditioned on reference audio and text.
    """

    def __init__(
        self,
        transformer: nn.Module,
        sigma=0.0,
        odeint_kwargs: dict = dict(
            # atol = 1e-5,
            # rtol = 1e-5,
            method="euler"  # 'midpoint'
        ),
        audio_drop_prob=0.3,
        cond_drop_prob=0.2,
        num_channels=None,
        mel_spec_module: nn.Module | None = None,
        mel_spec_kwargs: dict = dict(),
        frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
        vocab_char_map: dict[str, int] | None = None,
    ):
        """
        Args:
            transformer: backbone; must expose ``dim`` and ``clear_cache()``.
            sigma: stored as ``self.sigma`` (not read elsewhere in this class).
            odeint_kwargs: keyword arguments forwarded to ``odeint`` in ``sample``.
            audio_drop_prob: probability of dropping the audio condition during training.
            cond_drop_prob: probability of dropping both audio and text conditions during training.
            num_channels: mel channel count; defaults to the mel module's ``n_mel_channels``.
            mel_spec_module: mel-spectrogram module; built from ``mel_spec_kwargs`` when omitted.
            mel_spec_kwargs: constructor arguments for the default ``MelSpec``.
            frac_lengths_mask: (low, high) range of the fraction of frames masked per sample in training.
            vocab_char_map: optional character -> index map used to tokenize string input.
        """
        super().__init__()

        self.frac_lengths_mask = frac_lengths_mask

        # mel spec
        self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
        num_channels = default(num_channels, self.mel_spec.n_mel_channels)
        self.num_channels = num_channels

        # classifier-free guidance
        self.audio_drop_prob = audio_drop_prob
        self.cond_drop_prob = cond_drop_prob

        # transformer
        self.transformer = transformer
        dim = transformer.dim
        self.dim = dim

        # conditional flow related
        self.sigma = sigma

        # sampling related
        # Copy so that editing this instance's settings can never leak into the
        # shared default dict (or into the caller's dict).
        self.odeint_kwargs = dict(odeint_kwargs)

        # vocab map for tokenization
        self.vocab_char_map = vocab_char_map

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @torch.no_grad()
    def sample(
        self,
        cond: float["b n d"] | float["b nw"],
        text: int["b nt"] | list[str],
        duration: int | int["b"],
        *,
        lens: int["b"] | None = None,
        steps=32,
        cfg_strength=1.0,
        sway_sampling_coef=None,
        seed: int | None = None,
        max_duration=65536,
        vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None,
        use_epss=True,
        no_ref_audio=False,
        duplicate_test=False,
        t_inter=0.1,
        edit_mask=None,
    ):
        """Generate output by integrating the flow from noise.

        Puts the module in eval mode. ``cond`` may be a raw waveform (2-D) or a
        mel spectrogram; ``text`` may be token ids or a list of strings.
        ``duration`` is raised to at least one frame more than the longer of
        the text and the reference length, and clamped to ``max_duration``.
        Frames selected by ``lens`` (and ``edit_mask``) are copied from the
        reference into the result.

        Returns:
            ``(out, trajectory)`` where ``out`` is the final state (passed
            through ``vocoder`` when one is given) and ``trajectory`` is the
            full ``odeint`` result.
        """
        self.eval()
        # raw wave

        if cond.ndim == 2:
            cond = self.mel_spec(cond)
            cond = cond.permute(0, 2, 1)
            assert cond.shape[-1] == self.num_channels

        cond = cond.to(next(self.parameters()).dtype)

        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)

        # text

        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        # duration

        cond_mask = lens_to_mask(lens)
        if edit_mask is not None:
            cond_mask = cond_mask & edit_mask

        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device=device, dtype=torch.long)

        duration = torch.maximum(
            torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
        )  # duration at least text/audio prompt length plus one token, so something is generated
        duration = duration.clamp(max=max_duration)
        max_duration = duration.amax()

        # duplicate test corner for inner time step observation
        if duplicate_test:
            test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)

        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
        if no_ref_audio:
            cond = torch.zeros_like(cond)

        cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
        cond_mask = cond_mask.unsqueeze(-1)
        step_cond = torch.where(
            cond_mask, cond, torch.zeros_like(cond)
        )  # allow direct control (cut cond audio) with lens passed in

        if batch > 1:
            mask = lens_to_mask(duration)
        else:  # save memory and speed up, as single inference need no mask currently
            mask = None

        # neural ode

        def fn(t, x):
            # at each step, conditioning is fixed
            # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))

            # predict flow (cond)
            if cfg_strength < 1e-5:
                pred = self.transformer(
                    x=x,
                    cond=step_cond,
                    text=text,
                    time=t,
                    mask=mask,
                    drop_audio_cond=False,
                    drop_text=False,
                    cache=True,
                )
                return pred

            # predict flow (cond and uncond), for classifier-free guidance
            pred_cfg = self.transformer(
                x=x,
                cond=step_cond,
                text=text,
                time=t,
                mask=mask,
                cfg_infer=True,
                cache=True,
            )
            pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
            return pred + (pred - null_pred) * cfg_strength

        # noise input
        # to make sure batch inference result is same with different batch size, and for sure single inference
        # still some difference maybe due to convolutional layers
        y0 = []
        for dur in duration:
            if exists(seed):
                torch.manual_seed(seed)  # re-seed per sample so noise does not depend on batch position
            y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
        y0 = pad_sequence(y0, padding_value=0, batch_first=True)

        t_start = 0

        # duplicate test corner for inner time step observation
        if duplicate_test:
            t_start = t_inter
            y0 = (1 - t_start) * y0 + t_start * test_cond
            steps = int(steps * (1 - t_start))

        if t_start == 0 and use_epss:  # use Empirically Pruned Step Sampling for low NFE
            t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
        else:
            t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

        try:
            trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
        finally:
            # fn() runs the transformer with cache=True; always drop that cache, even when
            # integration fails, so a later call cannot pick up stale text embeddings.
            self.transformer.clear_cache()

        sampled = trajectory[-1]
        out = sampled
        out = torch.where(cond_mask, cond, out)  # keep the reference frames as given

        if exists(vocoder):
            out = out.permute(0, 2, 1)
            out = vocoder(out)

        return out, trajectory

    def forward(
        self,
        inp: float["b n d"] | float["b nw"],  # mel or raw wave
        text: int["b nt"] | list[str],
        *,
        lens: int["b"] | None = None,
        noise_scheduler: str | None = None,
    ):
        """Compute the flow-matching loss for one training batch.

        A random span of each sample is masked out of the conditioning audio;
        the loss is the mean squared error between the predicted and the true
        flow on exactly those masked frames. ``noise_scheduler`` is currently
        unused.

        Returns:
            ``(loss, cond, pred)``: scalar loss, the conditioning audio that
            was fed to the transformer, and the transformer prediction.
        """
        # handle raw wave
        if inp.ndim == 2:
            inp = self.mel_spec(inp)
            inp = inp.permute(0, 2, 1)
            assert inp.shape[-1] == self.num_channels

        batch, seq_len = inp.shape[:2]
        dtype = inp.dtype
        device = self.device

        # handle text as string
        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        # lens and mask
        if not exists(lens):  # if lens not acquired by trainer from collate_fn
            lens = torch.full((batch,), seq_len, device=device)
        mask = lens_to_mask(lens, length=seq_len)

        # get a random span to mask out for training conditionally
        frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

        # never train on padded frames
        rand_span_mask &= mask

        # mel is x1
        x1 = inp

        # x0 is gaussian noise
        x0 = torch.randn_like(x1)

        # time step
        time = torch.rand((batch,), dtype=dtype, device=self.device)
        # TODO. noise_scheduler

        # sample xt (φ_t(x) in the paper)
        t = time.unsqueeze(-1).unsqueeze(-1)
        xt = (1 - t) * x0 + t * x1
        flow = x1 - x0

        # only predict what is within the random mask span for infilling
        cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)

        # transformer and cfg training with a drop rate
        drop_audio_cond = random() < self.audio_drop_prob  # p_drop in voicebox paper
        if random() < self.cond_drop_prob:  # p_uncond in voicebox paper
            drop_audio_cond = True
            drop_text = True
        else:
            drop_text = False

        # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
        pred = self.transformer(
            x=xt, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
        )

        # flow matching loss
        loss = F.mse_loss(pred, flow, reduction="none")
        loss = loss[rand_span_mask]

        return loss.mean(), cond, pred
================================================
FILE: src/f5_tts/model/dataset.py
================================================
import json
from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from datasets import Dataset as Dataset_
from datasets import load_from_disk
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import default
class HFDataset(Dataset):
    """Dataset over rows that carry decoded audio and its transcript.

    Each row is read as ``row["audio"]["array"]``,
    ``row["audio"]["sampling_rate"]`` and ``row["text"]``; items are returned
    as a mel spectrogram plus the text.
    """

    def __init__(
        self,
        hf_dataset: Dataset,
        target_sample_rate=24_000,
        n_mel_channels=100,
        hop_length=256,
        n_fft=1024,
        win_length=1024,
        mel_spec_type="vocos",
    ):
        self.data = hf_dataset
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length

        self.mel_spectrogram = MelSpec(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        )

    def get_frame_len(self, index):
        """Return the (fractional) number of mel frames of row ``index`` at the target sample rate."""
        row = self.data[index]
        audio = row["audio"]["array"]
        sample_rate = row["audio"]["sampling_rate"]
        return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``{"mel_spec", "text"}`` for row ``index``.

        Rows shorter than 0.3 s or longer than 30 s are skipped: the next row
        (wrapping around) is used instead.

        Raises:
            RuntimeError: if a full pass over the dataset finds no row within
                the accepted duration range.
        """
        num_rows = len(self.data)

        # Scan forward iteratively (at most one full pass) instead of recursing, so a long
        # run of out-of-range rows cannot exhaust the interpreter's recursion limit.
        for _ in range(max(num_rows, 1)):
            row = self.data[index]
            audio = row["audio"]["array"]
            sample_rate = row["audio"]["sampling_rate"]
            duration = audio.shape[-1] / sample_rate

            if not (duration > 30 or duration < 0.3):
                break  # valid

            index = (index + 1) % num_rows
        else:
            raise RuntimeError("No sample with a duration between 0.3s and 30s found in the dataset")

        audio_tensor = torch.from_numpy(audio).float()

        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            audio_tensor = resampler(audio_tensor)

        audio_tensor = audio_tensor.unsqueeze(0)  # 't -> 1 t'

        mel_spec = self.mel_spectrogram(audio_tensor)

        mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'

        text = row["text"]

        return dict(
            mel_spec=mel_spec,
            text=text,
        )
class CustomDataset(Dataset):
def __init__(
self,
custom_dataset: Dataset,
durations=None,
target_sample_rate=24_000,
hop_length=256,
n_mel_channels=100,
n_fft=1024,
win_length=1024,
mel_spec_type="vocos",
preprocessed_mel=False,
mel_spec_module: nn.Module | None = None,
):
self.data = custom_dataset
self.durations = durations
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.n_fft = n_fft
self.win_length = win_length
self.mel_spec_type = mel_spec_type
self.preprocessed_mel = preprocessed_mel
if not preprocessed_mel:
self.mel_spectrogram = default(
mel_spec_module,
MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
)
def get_frame_len(self, index):
if (
self.durations is not None
): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
return self.durations[index] * self.target_sample_rate / self.hop_length
return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
def __len__(self):
return len(self.data)
def __getitem__(self, index):
while True:
row = self.data[index]
audio_path = row["audio_path"]
text = row["text"]
duration = row["duration"]
# filter by given length
if 0.3 <= duration <= 30:
break # valid
index = (index + 1) % len(self.data)
if self.preprocessed_mel:
mel_spec = torch.tensor(row["mel_spec"])
else:
audio, source_sample_rate = torchaudio.load(audio_path)
# make sure mono input
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
# resample if necessary
if source_sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
audio = resampler(audio)
# to mel spectrogram
mel_spec = self.mel_spectrogram(audio)
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
return {
"mel_spec": mel_spec,
"text": text,
}
# Dynamic Batch Sampler
class DynamicBatchSampler(Sampler[list[int]]):
    """Batch sampler that groups indices by frame budget instead of a fixed batch size.

    Behaviour:
    1. The number of sequences per batch varies so that the summed frame length of a
       batch stays within ``frames_threshold`` (and within ``max_samples`` when non-zero).
    2. Indices are sorted by frame length first, which keeps padding inside a batch low.
    3. Batches are shuffled per epoch in a reproducible way when ``random_seed`` is given.
    """

    def __init__(
        self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
    ):
        self.sampler = sampler
        self.frames_threshold = frames_threshold
        self.max_samples = max_samples
        self.random_seed = random_seed
        self.epoch = 0

        # The wrapped sampler must expose its dataset, which in turn provides get_frame_len().
        data_source = self.sampler.data_source

        indices = [
            (idx, data_source.get_frame_len(idx))
            for idx in tqdm(
                self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
            )
        ]
        indices.sort(key=lambda elem: elem[1])

        batches = []
        current, current_frames = [], 0
        for idx, frame_len in tqdm(
            indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
        ):
            fits_budget = current_frames + frame_len <= self.frames_threshold
            below_cap = max_samples == 0 or len(current) < max_samples
            if fits_budget and below_cap:
                current.append(idx)
                current_frames += frame_len
                continue

            # The running batch is closed; flush it if it holds anything.
            if current:
                batches.append(current)

            # Start the next batch with this sample, unless it alone exceeds the budget
            # (such samples are dropped).
            if frame_len <= self.frames_threshold:
                current, current_frames = [idx], frame_len
            else:
                current, current_frames = [], 0

        if current and not drop_residual:
            batches.append(current)

        del indices
        self.batches = batches

        # Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
        self.drop_last = True

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler."""
        self.epoch = epoch

    def __iter__(self):
        # Without a seed the batches are yielded in their construction order.
        if self.random_seed is None:
            return iter(self.batches)

        # Seed with random_seed + epoch: deterministic, yet different for every epoch.
        generator = torch.Generator()
        generator.manual_seed(self.random_seed + self.epoch)
        # torch.randperm keeps the permutation reproducible across PyTorch versions.
        order = torch.randperm(len(self.batches), generator=generator).tolist()
        return iter([self.batches[i] for i in order])

    def __len__(self):
        return len(self.batches)
# Load dataset
def load_dataset(
    dataset_name: str,
    tokenizer: str = "pinyin",
    dataset_type: str = "CustomDataset",
    audio_type: str = "raw",
    mel_spec_module: nn.Module | None = None,
    mel_spec_kwargs: dict = dict(),
) -> CustomDataset | HFDataset:
    """Load a training dataset.

    Args:
        dataset_name: dataset identifier. For ``"CustomDataset"`` it is combined with
            ``tokenizer`` into the default data directory; for ``"CustomDatasetPath"`` it is
            the full path of a prepared dataset directory; for ``"HFDataset"`` it must have
            the form ``"<pre>_<post>"``.
        tokenizer: tokenizer name, only used to build the ``"CustomDataset"`` directory name.
        dataset_type: one of ``"CustomDataset"``, ``"CustomDatasetPath"``, ``"HFDataset"``.
        audio_type: ``"raw"`` or ``"mel"``; only consulted for ``"CustomDataset"``.
        mel_spec_module: optional mel extractor handed to ``CustomDataset``.
        mel_spec_kwargs: extra keyword arguments handed to ``CustomDataset`` (never mutated here).

    Returns:
        The constructed ``CustomDataset`` or ``HFDataset``.

    Raises:
        ValueError: for an unsupported ``dataset_type`` or ``audio_type``.
    """
    print("Loading dataset ...")

    if dataset_type == "CustomDataset":
        rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
        if audio_type == "raw":
            # Prefer the save_to_disk layout, fall back to a single arrow file.
            try:
                train_dataset = load_from_disk(f"{rel_data_path}/raw")
            except Exception:
                train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
            preprocessed_mel = False
        elif audio_type == "mel":
            train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
            preprocessed_mel = True
        else:
            raise ValueError(f"Unsupported audio_type: {audio_type!r} (expected 'raw' or 'mel')")

        with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
            data_dict = json.load(f)
        durations = data_dict["duration"]
        train_dataset = CustomDataset(
            train_dataset,
            durations=durations,
            preprocessed_mel=preprocessed_mel,
            mel_spec_module=mel_spec_module,
            **mel_spec_kwargs,
        )

    elif dataset_type == "CustomDatasetPath":
        try:
            train_dataset = load_from_disk(f"{dataset_name}/raw")
        except Exception:
            train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")

        with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
            data_dict = json.load(f)
        durations = data_dict["duration"]
        # This branch only ever loads raw audio, so mels are never precomputed here.
        train_dataset = CustomDataset(
            train_dataset,
            durations=durations,
            preprocessed_mel=False,
            mel_spec_module=mel_spec_module,
            **mel_spec_kwargs,
        )

    elif dataset_type == "HFDataset":
        # Imported locally under an alias: this function shadows the name `load_dataset`,
        # so calling it directly would recurse into ourselves with unsupported kwargs.
        # NOTE(review): assumes the Hugging Face `datasets` package, presumably the same
        # one that provides load_from_disk / Dataset_ in this file - confirm.
        from datasets import load_dataset as hf_load_dataset

        print(
            "Should manually modify the path of huggingface dataset to your need.\n"
            + "May also the corresponding script cuz different dataset may have different format."
        )
        pre, post = dataset_name.split("_")
        train_dataset = HFDataset(
            hf_load_dataset(
                f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))
            ),
        )

    else:
        raise ValueError(
            f"Unsupported dataset_type: {dataset_type!r} "
            "(expected 'CustomDataset', 'CustomDatasetPath' or 'HFDataset')"
        )

    return train_dataset
# collation
def collate_fn(batch):
    """Collate dataset items into one padded batch.

    Each item must provide ``"mel_spec"`` (a tensor whose last axis is time) and ``"text"``.
    Mel spectrograms are right-padded with zeros along the last axis to the longest one.

    Returns:
        dict with ``mel`` (stacked padded mels), ``mel_lengths`` (original last-axis sizes),
        ``text`` (list of the items' texts) and ``text_lengths`` (``len`` of each text).
    """
    specs = [sample["mel_spec"].squeeze(0) for sample in batch]
    mel_lengths = torch.LongTensor([spec.shape[-1] for spec in specs])
    longest = mel_lengths.amax()

    # Right-pad every spectrogram on its time axis up to the longest one in the batch.
    mel = torch.stack([F.pad(spec, (0, longest - spec.size(-1)), value=0) for spec in specs])

    texts = [sample["text"] for sample in batch]
    text_lengths = torch.LongTensor([len(entry) for entry in texts])

    return {
        "mel": mel,
        "mel_lengths": mel_lengths,  # records for padding mask
        "text": texts,
        "text_lengths": text_lengths,
    }
================================================
FILE: src/f5_tts/model/modules.py
================================================
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
# ruff: noqa: F722 F821
from __future__ import annotations
import math
import warnings
from typing import Optional
import torch
import torch.nn.functional as F
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
from f5_tts.model.utils import is_package_available
# raw wav to mel spec
mel_basis_cache = {}
hann_window_cache = {}
def get_bigvgan_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
    fmin=0,
    fmax=None,
    center=False,
):  # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
    """Compute a log-mel spectrogram following the BigVGAN recipe.

    The mel filterbank and Hann window are cached per parameter set and device in the
    module-level ``mel_basis_cache`` / ``hann_window_cache`` dictionaries.
    The waveform is unsqueezed at dim 1 for padding and squeezed back, so it is presumably
    ``(batch, samples)`` - confirm against callers.

    Returns:
        ``log(clamp(mel_basis @ magnitude, min=1e-5))``.
    """
    device = waveform.device
    key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"

    # Build the filterbank and window once per configuration/device.
    if key not in mel_basis_cache:
        filterbank = librosa_mel_fn(
            sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax
        )
        mel_basis_cache[key] = torch.from_numpy(filterbank).float().to(device)  # TODO: why they need .float()?
        hann_window_cache[key] = torch.hann_window(win_length).to(device)

    mel_basis = mel_basis_cache[key]
    window = hann_window_cache[key]

    # Manual reflect padding on both sides (the STFT itself runs with the given `center`).
    pad = (n_fft - hop_length) // 2
    padded = F.pad(waveform.unsqueeze(1), (pad, pad), mode="reflect").squeeze(1)

    stft = torch.stft(
        padded,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )

    # Magnitude with a small epsilon under the square root.
    magnitude = torch.sqrt(torch.view_as_real(stft).pow(2).sum(-1) + 1e-9)

    mel = torch.matmul(mel_basis, magnitude)
    return torch.log(torch.clamp(mel, min=1e-5))
def get_vocos_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
):
    """Compute a log-mel spectrogram with ``torchaudio.transforms.MelSpectrogram``.

    Accepts a 2-D waveform, or a 3-D one whose dim 1 is squeezed away
    ('b 1 nw -> b nw'); any other rank fails the assertion.

    Returns:
        ``log(clamp(mel, min=1e-5))`` of the magnitude (``power=1``) mel spectrogram.
    """
    extractor = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mel_channels,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(waveform.device)

    if waveform.dim() == 3:
        waveform = waveform.squeeze(1)  # 'b 1 nw -> b nw'
    assert waveform.dim() == 2

    return torch.log(torch.clamp(extractor(waveform), min=1e-5))
class MelSpec(nn.Module):
    """Module wrapper around one of the two mel extraction functions.

    Args:
        n_fft, hop_length, win_length, n_mel_channels, target_sample_rate: forwarded
            unchanged to the selected extractor on every call.
        mel_spec_type: ``"vocos"`` or ``"bigvgan"``; selects the extractor function.

    Raises:
        AssertionError: if ``mel_spec_type`` is not one of the supported backends.
    """

    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        mel_spec_type="vocos",
    ):
        super().__init__()
        # The message is a plain string so it actually ends up on the AssertionError
        # (a print(...) call here would evaluate to None).
        assert mel_spec_type in ["vocos", "bigvgan"], "We only support two extract mel backend: vocos or bigvgan"

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.target_sample_rate = target_sample_rate

        if mel_spec_type == "vocos":
            self.extractor = get_vocos_mel_spectrogram
        elif mel_spec_type == "bigvgan":
            self.extractor = get_bigvgan_mel_spectrogram

        # Non-persistent buffer whose only job is to track which device the module lives on.
        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, wav):
        """Return the mel spectrogram of ``wav``, moving this module to ``wav``'s device first."""
        if self.dummy.device != wav.device:
            self.to(wav.device)

        mel = self.extractor(
            waveform=wav,
            n_fft=self.n_fft,
            n_mel_channels=self.n_mel_channels,
            target_sample_rate=self.target_sample_rate,
            hop_length=self.hop_length,
            win_length=self.win_length,
        )

        return mel
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
    """Sinusoidal embedding: sines in the first ``dim // 2`` columns, cosines in the second."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        """Embed ``x`` after multiplying it by ``scale``.

        ``x`` is unsqueezed at dim 1, so a 1-D input of length ``b`` yields ``(b, 2 * (dim // 2))``.
        """
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device).float() * -log_step)
        angles = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
    """Two grouped 1-D convolutions with Mish activations applied along the sequence axis."""

    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0  # odd kernel so `kernel_size // 2` padding keeps the length
        half = kernel_size // 2
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=half),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=half),
            nn.Mish(),
        )
        # Positions of the convolution layers; masked positions are re-zeroed after each of them.
        self.layer_need_mask_idx = [
            idx for idx, layer in enumerate(self.conv1d) if isinstance(layer, nn.Conv1d)
        ]

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
        """Run the conv stack on ``x``; where ``mask`` is False the values are zeroed
        before the stack and again after every convolution."""
        use_mask = mask is not None
        if use_mask:
            mask = mask.unsqueeze(1)  # [B 1 N]

        hidden = x.permute(0, 2, 1)  # [B D N]
        if use_mask:
            hidden = hidden.masked_fill(~mask, 0.0)

        for idx, layer in enumerate(self.conv1d):
            hidden = layer(hidden)
            if use_mask and idx in self.layer_need_mask_idx:
                hidden = hidden.masked_fill(~mask, 0.0)

        return hidden.permute(0, 2, 1)  # [B N D]
# rotary positional embedding related
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    """Precompute rotary angles for positions ``0 .. end - 1``.

    Returns a ``(end, 2 * (dim // 2))`` tensor holding the cosines of the angles followed,
    along the last axis, by their sines.
    """
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))

    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta**exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore

    real_part = torch.cos(angles)
    imag_part = torch.sin(angles)
    return torch.cat([real_part, imag_part], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    """Return per-row position indices ``start + floor(arange(length) * scale)``.

    ``start`` is a 1-D tensor (one offset per row); ``scale`` may be a scalar or a tensor
    broadcastable to ``start``. Indices that reach ``max_pos`` are replaced by ``max_pos - 1``.
    """
    # length = length if isinstance(length, int) else length.max()
    per_row_scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
    steps = torch.arange(length, device=start.device, dtype=torch.float32)
    offsets = (steps.unsqueeze(0) * per_row_scale.unsqueeze(1)).long()
    pos = start.unsqueeze(1) + offsets
    # avoid extra long error.
    return torch.where(pos < max_pos, pos, max_pos - 1)
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
    """Global Response Normalization with a residual connection.

    ``gamma`` and ``beta`` start at zero, so a freshly built layer returns its input.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        # L2 norm over dim 1, then divided by its mean over the last dim.
        global_norm = x.norm(p=2, dim=1, keepdim=True)
        normalized = global_norm / (global_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * normalized) + self.beta + x
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
    """Residual block: depthwise conv -> LayerNorm -> Linear -> GELU -> GRN -> Linear.

    Args:
        dim: channel size of the input and output.
        intermediate_dim: width of the hidden pointwise projection.
        dilation: dilation of the kernel-7 depthwise convolution; padding is derived
            from it so the sequence length is preserved.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        kernel_size = 7
        same_padding = (dilation * (kernel_size - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=kernel_size, padding=same_padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv1d wants channels first: b n d -> b d n, and back afterwards.
        hidden = self.dwconv(x.transpose(1, 2)).transpose(1, 2)
        hidden = self.pwconv1(self.norm(hidden))
        hidden = self.grn(self.act(hidden))
        return x + self.pwconv2(hidden)
# RMSNorm
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable scale.

    Args:
        dim: size of the normalized (last) dimension.
        eps: value added to the mean square for numerical stability.

    Uses ``torch.nn.functional.rms_norm`` when this PyTorch build provides it and a
    manual float32-variance implementation otherwise.
    """

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # Feature-detect instead of parsing torch.__version__: slicing the first three
        # characters turns "2.10.0" into 2.1 and would wrongly disable the native kernel.
        self.native_rms_norm = hasattr(F, "rms_norm")

    def forward(self, x):
        if self.native_rms_norm:
            # Match the weight's half precision before calling the fused op.
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                x = x.to(self.weight.dtype)
            x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
        else:
            # Variance is accumulated in float32 regardless of the input dtype.
            variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                x = x.to(self.weight.dtype)
            x = x * self.weight

        return x
# AdaLayerNorm
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNorm(nn.Module):
    """Adaptive layer norm producing six modulation chunks from a conditioning embedding.

    ``forward`` returns the modulated input for attention together with the attention
    gate and the shift / scale / gate reserved for the later MLP modulation.
    """

    def __init__(self, dim):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        modulation = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        # Scale and shift are broadcast over the sequence axis.
        modulated = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNorm for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNorm_Final(nn.Module):
    """Adaptive layer norm for a final layer: only scale and shift, nothing for a later MLP."""

    def __init__(self, dim):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        modulation = self.linear(self.silu(emb))
        scale, shift = modulation.chunk(2, dim=1)
        # Both terms gain a sequence axis so they broadcast over every position.
        return self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
# FeedForward
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        dim: input feature size.
        dim_out: output feature size; falls back to ``dim`` when ``None``.
        mult: hidden width multiplier (hidden size is ``int(dim * mult)``).
        dropout: dropout probability between the two projections.
        approximate: ``approximate`` argument of ``nn.GELU``.
    """

    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
        super().__init__()
        hidden_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        self.ff = nn.Sequential(
            nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(approximate=approximate)),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x):
        return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
    """Container for attention projections; the computation is delegated to ``processor``.

    Args:
        processor: callable invoked as ``processor(attn, x, ...)`` from ``forward``.
        dim: feature size of the input ``x``.
        heads: number of attention heads.
        dim_head: size of each head (``inner_dim = dim_head * heads``).
        dropout: dropout probability applied after the output projection.
        context_dim: feature size of the context; when given, context projections are
            created as well (joint attention).
        context_pre_only: when true, no output projection is created for the context.
        qk_norm: ``None`` or ``"rms_norm"``.

    Raises:
        ImportError: if ``F.scaled_dot_product_attention`` is unavailable.
        ValueError: for an unknown ``qk_norm``.
    """

    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only: bool = False,
        qk_norm: Optional[str] = None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        inner_dim = self.inner_dim
        use_rms_norm = qk_norm == "rms_norm"
        has_context = self.context_dim is not None

        self.to_q = nn.Linear(dim, inner_dim)
        self.to_k = nn.Linear(dim, inner_dim)
        self.to_v = nn.Linear(dim, inner_dim)

        if qk_norm is not None and not use_rms_norm:
            raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
        self.q_norm = RMSNorm(dim_head, eps=1e-6) if use_rms_norm else None
        self.k_norm = RMSNorm(dim_head, eps=1e-6) if use_rms_norm else None

        if has_context:
            self.to_q_c = nn.Linear(context_dim, inner_dim)
            self.to_k_c = nn.Linear(context_dim, inner_dim)
            self.to_v_c = nn.Linear(context_dim, inner_dim)
            self.c_q_norm = RMSNorm(dim_head, eps=1e-6) if use_rms_norm else None
            self.c_k_norm = RMSNorm(dim_head, eps=1e-6) if use_rms_norm else None

        # Index 0 is the output projection, index 1 the dropout applied after it.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if has_context and not self.context_pre_only:
            self.to_out_c = nn.Linear(inner_dim, context_dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x
        c: float["b n d"] = None,  # context c
        mask: bool["b n"] | None = None,
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
        c_mask: bool["b nt"] | None = None,  # text mask
    ) -> torch.Tensor:
        """Dispatch to the processor; context-related arguments are forwarded only when ``c`` is given."""
        if c is None:
            return self.processor(self, x, mask=mask, rope=rope)
        return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask)
# Attention processor
if is_package_available("flash_attn"):
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input
class AttnProcessor:
    """Self-attention processor used by ``Attention`` when no context is passed.

    Args:
        pe_attn_head: number of leading heads that receive rotary embedding; ``None`` applies it to all.
        attn_backend: ``"torch"`` (``F.scaled_dot_product_attention``) or ``"flash_attn"``.
        attn_mask_enabled: whether ``mask`` is used inside the attention computation
            (the output is zeroed at masked positions whenever a mask is given, regardless).
    """

    def __init__(
        self,
        pe_attn_head: int | None = None,  # number of attention head to apply rope, None for all
        attn_backend: str = "torch",  # "torch" or "flash_attn"
        attn_mask_enabled: bool = True,
    ):
        if attn_backend == "flash_attn":
            assert is_package_available("flash_attn"), "Please install flash-attn first."
        if attn_backend == "torch" and attn_mask_enabled:
            warnings.warn(
                "attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. "
                "Please switch attn_backend to 'flash_attn'.",
                UserWarning,
            )

        self.pe_attn_head = pe_attn_head
        self.attn_backend = attn_backend
        self.attn_mask_enabled = attn_mask_enabled

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x
        mask: bool["b n"] | None = None,
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        """Run self-attention over ``x`` with the projections owned by ``attn``.

        Returns the projected attention output, zeroed wherever ``mask`` is False.
        """
        batch_size = x.shape[0]

        # `sample` projections
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # qk norm
        if attn.q_norm is not None:
            query = attn.q_norm(query)
        if attn.k_norm is not None:
            key = attn.k_norm(key)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

            if self.pe_attn_head is not None:
                # Only the first `pn` heads are rotated; the rest stay position-free.
                pn = self.pe_attn_head
                query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
                key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
            else:
                query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
                key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        if self.attn_backend == "torch":
            # mask. e.g. inference got a batch with different target durations, mask out the padding
            if self.attn_mask_enabled and mask is not None:
                attn_mask = mask
                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
                attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
            else:
                attn_mask = None
            x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
            x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        elif self.attn_backend == "flash_attn":
            query = query.transpose(1, 2)  # [b, h, n, d] -> [b, n, h, d]
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            if self.attn_mask_enabled and mask is not None:
                # Remember the padded length before unpadding: the output has to be restored
                # to it (not to the longest unmasked row) so it lines up with `mask` below,
                # matching what JointAttnProcessor does with its total sequence length.
                padded_seq_len = query.shape[1]
                query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
                key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
                value, _, _, _, _ = unpad_input(value, mask)
                x = flash_attn_varlen_func(
                    query,
                    key,
                    value,
                    q_cu_seqlens,
                    k_cu_seqlens,
                    q_max_seqlen_in_batch,
                    k_max_seqlen_in_batch,
                )
                x = pad_input(x, indices, batch_size, padded_seq_len)
                x = x.reshape(batch_size, -1, attn.heads * head_dim)
            else:
                x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
                x = x.reshape(batch_size, -1, attn.heads * head_dim)

        x = x.to(query.dtype)

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)

        return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
    """Joint attention processor: the noised input and the context attend over their concatenation.

    Args:
        attn_backend: ``"torch"`` (``F.scaled_dot_product_attention``) or ``"flash_attn"``.
        attn_mask_enabled: whether the combined audio + text mask is used inside attention
            (outputs are zeroed at masked positions whenever the masks are given, regardless).
    """

    def __init__(
        self,
        attn_backend: str = "torch",  # "torch" or "flash_attn"
        attn_mask_enabled: bool = True,
    ):
        if attn_backend == "flash_attn":
            assert is_package_available("flash_attn"), "Please install flash-attn first."
        if attn_backend == "torch" and attn_mask_enabled:
            warnings.warn(
                "attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. "
                "Please switch attn_backend to 'flash_attn'.",
                UserWarning,
            )

        self.attn_backend = attn_backend
        self.attn_mask_enabled = attn_mask_enabled

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x
        c: float["b nt d"] = None,  # context c, here text
        mask: bool["b n"] | None = None,
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
        c_mask: bool["b nt"] | None = None,  # text mask
    ) -> torch.FloatTensor:
        """Return ``(x_out, c_out)``: attention outputs for the noised input and the context."""
        audio_len = x.shape[1]
        audio_mask = mask
        batch_size = c.shape[0]
        heads = attn.heads

        # `sample` projections, then `context` projections
        query, key, value = attn.to_q(x), attn.to_k(x), attn.to_v(x)
        c_query, c_key, c_value = attn.to_q_c(c), attn.to_k_c(c), attn.to_v_c(c)

        head_dim = key.shape[-1] // heads

        def split_heads(t):
            # [b, n, h * d] -> [b, h, n, d]
            return t.view(batch_size, -1, heads, head_dim).transpose(1, 2)

        def rotate(q, k, rope_args):
            # Apply rotary embedding to one query/key pair with the optional xpos scaling.
            freqs, xpos_scale = rope_args
            if xpos_scale is not None:
                q_scale, k_scale = xpos_scale, xpos_scale**-1.0
            else:
                q_scale, k_scale = 1.0, 1.0
            return apply_rotary_pos_emb(q, freqs, q_scale), apply_rotary_pos_emb(k, freqs, k_scale)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)
        c_query, c_key, c_value = split_heads(c_query), split_heads(c_key), split_heads(c_value)

        # qk norm
        if attn.q_norm is not None:
            query = attn.q_norm(query)
        if attn.k_norm is not None:
            key = attn.k_norm(key)
        if attn.c_q_norm is not None:
            c_query = attn.c_q_norm(c_query)
        if attn.c_k_norm is not None:
            c_key = attn.c_k_norm(c_key)

        # apply rope for context and noised input independently
        if rope is not None:
            query, key = rotate(query, key, rope)
        if c_rope is not None:
            c_query, c_key = rotate(c_query, c_key, c_rope)

        # joint attention: audio tokens first, context tokens after them
        query = torch.cat([query, c_query], dim=2)
        key = torch.cat([key, c_key], dim=2)
        value = torch.cat([value, c_value], dim=2)

        use_mask = self.attn_mask_enabled and mask is not None

        # build combined mask for joint attention: audio mask + text mask
        if use_mask:
            if c_mask is None:
                # Without a text mask every context position counts as valid.
                mask = F.pad(mask, (0, c.shape[1]), value=True)
            else:
                mask = torch.cat([mask, c_mask], dim=1)

        if self.attn_backend == "torch":
            # mask. e.g. inference got a batch with different target durations, mask out the padding
            attn_mask = None
            if use_mask:
                attn_mask = mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
                attn_mask = attn_mask.expand(batch_size, heads, query.shape[-2], key.shape[-2])
            x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
            x = x.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)

        elif self.attn_backend == "flash_attn":
            query = query.transpose(1, 2)  # [b, h, n, d] -> [b, n, h, d]
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            if use_mask:
                total_seq_len = query.shape[1]
                query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
                key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
                value, _, _, _, _ = unpad_input(value, mask)
                x = flash_attn_varlen_func(
                    query,
                    key,
                    value,
                    q_cu_seqlens,
                    k_cu_seqlens,
                    q_max_seqlen_in_batch,
                    k_max_seqlen_in_batch,
                )
                # Restore the padded layout at the full joint sequence length.
                x = pad_input(x, indices, batch_size, total_seq_len)
            else:
                x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
            x = x.reshape(batch_size, -1, heads * head_dim)

        x = x.to(query.dtype)

        # Split the attention outputs back into the audio part and the context part.
        x, c = x[:, :audio_len], x[:, audio_len:]

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)
        if not attn.context_pre_only:
            c = attn.to_out_c(c)

        if audio_mask is not None:
            x = x.masked_fill(~audio_mask.unsqueeze(-1), 0.0)
        if c_mask is not None:
            c = c.masked_fill(~c_mask.unsqueeze(-1), 0.0)

        return x, c
# DiT Block
class DiTBlock(nn.Module):
    """Transformer block with adaptive layer-norm modulation driven by a time embedding.

    Attention and the feed-forward network are each applied as a gated residual update.

    Args:
        dim, heads, dim_head: attention sizes.
        ff_mult: hidden width multiplier of the feed-forward network.
        dropout: dropout probability for the attention output and the feed-forward network.
        qk_norm: forwarded to ``Attention``.
        pe_attn_head, attn_backend, attn_mask_enabled: forwarded to ``AttnProcessor``.
    """

    def __init__(
        self,
        dim,
        heads,
        dim_head,
        ff_mult=4,
        dropout=0.1,
        qk_norm=None,
        pe_attn_head=None,
        attn_backend="torch",  # "torch" or "flash_attn"
        attn_mask_enabled=True,
    ):
        super().__init__()

        self.attn_norm = AdaLayerNorm(dim)

        processor = AttnProcessor(
            pe_attn_head=pe_attn_head,
            attn_backend=attn_backend,
            attn_mask_enabled=attn_mask_enabled,
        )
        self.attn = Attention(
            processor=processor,
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            qk_norm=qk_norm,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        attn_in, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # gated residual update from attention
        x = x + gate_msa.unsqueeze(1) * self.attn(x=attn_in, mask=mask, rope=rope)

        # modulated pre-norm, then gated residual update from the feed-forward network
        ff_in = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        return x + gate_mlp.unsqueeze(1) * self.ff(ff_in)
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
    r"""
    Joint-attention block with separate streams for context and noised input.
    modified from diffusers/src/diffusers/models/attention.py

    notes.
    _c: context related. text, cond, etc. (left part in sd3 fig2.b)
    _x: noised input related. (right part)
    context_pre_only: last layer only do prenorm + modulation cuz no more ffn
    """

    def __init__(
        self,
        dim,
        heads,
        dim_head,
        ff_mult=4,
        dropout=0.1,
        context_dim=None,
        context_pre_only=False,
        qk_norm=None,
        attn_backend="torch",
        attn_mask_enabled=False,
    ):
        super().__init__()
        # The context stream defaults to the same width as the noised-input stream.
        context_dim = dim if context_dim is None else context_dim
        self.context_pre_only = context_pre_only

        if context_pre_only:
            self.attn_norm_c = AdaLayerNorm_Final(context_dim)
        else:
            self.attn_norm_c = AdaLayerNorm(context_dim)
        self.attn_norm_x = AdaLayerNorm(dim)

        processor = JointAttnProcessor(
            attn_backend=attn_backend,
            attn_mask_enabled=attn_mask_enabled,
        )
        self.attn = Attention(
            processor=processor,
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            context_dim=context_dim,
            context_pre_only=context_pre_only,
            qk_norm=qk_norm,
        )

        if context_pre_only:
            # No feed-forward for the context in the last layer.
            self.ff_norm_c = None
            self.ff_c = None
        else:
            self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
            self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(
        self, x, c, t, mask=None, rope=None, c_rope=None, c_mask=None
    ):  # x: noised input, c: context, t: time embedding
        """Return ``(c, x)``; ``c`` is ``None`` when ``context_pre_only`` is set."""
        last_layer = self.context_pre_only

        # pre-norm & modulation for attention input
        if last_layer:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)

        # attention
        x_attn_output, c_attn_output = self.attn(
            x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask
        )

        # process attention output for context c
        if last_layer:
            c = None
        else:
            c = c + c_gate_msa.unsqueeze(1) * c_attn_output
            c_ff_in = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            c = c + c_gate_mlp.unsqueeze(1) * self.ff_c(c_ff_in)

        # process attention output for input x
        x = x + x_gate_msa.unsqueeze(1) * x_attn_output
        x_ff_in = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
        x = x + x_gate_mlp.unsqueeze(1) * self.ff_x(x_ff_in)

        return c, x
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
    """Embed a batch of timesteps: sinusoidal features followed by a two-layer SiLU MLP."""

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(freq_embed_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, timestep: float["b"]):
        # Cast the sinusoidal features back to the timestep's dtype before the MLP.
        features = self.time_embed(timestep).to(timestep.dtype)
        return self.time_mlp(features)  # b d
================================================
FILE: src/f5_tts/model/trainer.py
================================================
from __future__ import annotations
import gc
import math
import os
import torch
import torchaudio
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from ema_pytorch import EMA
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm
from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
# trainer
class Trainer:
def __init__(
self,
model: CFM,
epochs,
learning_rate,
num_warmup_updates=20000,
save_per_updates=1000,
keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
checkpoint_path=None,
batch_size_per_gpu=32,
batch_size_type: str = "sample",
max_samples=32,
grad_accumulation_steps=1,
max_grad_norm=1.0,
noise_scheduler: str | None = None,
duration_predictor: torch.nn.Module | None = None,
logger: str | None = "wandb", # "wandb" | "tensorboard" | None
wandb_project="test_f5-tts",
wandb_run_name="test_run",
wandb_resume_id: str = None,
log_samples: bool = False,
last_per_updates=None,
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict(),
bnb_optimizer: bool = False,
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
model_cfg_dict: dict = dict(), # training config
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
if logger == "wandb" and not wandb.api.api_key:
logger = None
self.log_samples = log_samples
self.accelerator = Accelerator(
log_with=logger if logger == "wandb" else None,
kwargs_handlers=[ddp_kwargs],
gradient_accumulation_steps=grad_accumulation_steps,
**accelerate_kwargs,
)
self.logger = logger
if self.logger == "wandb":
if exists(wandb_resume_id):
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
else:
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
if not model_cfg_dict:
model_cfg_dict = {
"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
"batch_size_per_gpu": batch_size_per_gpu,
"batch_size_type": batch_size_type,
"max_samples": max_samples,
"grad_accumulation_steps": grad_accumulation_steps,
"max_grad_norm": max_grad_norm,
"noise_scheduler": noise_scheduler,
"bnb_optimizer": bnb_optimizer,
}
model_cfg_dict["gpus"] = self.accelerator.num_processes
self.accelerator.init_trackers(
project_name=wandb_project,
init_kwargs=init_kwargs,
config=model_cfg_dict,
)
elif self.logger == "tensorboard":
from torch.utils.tensorboard import SummaryWriter
self.writer = None
if self.accelerator.is_main_process:
self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
self.model = model
if self.is_main:
self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
self.ema_model.to(self.accelerator.device)
print(f"Using logger: {logger}")
if grad_accumulation_steps > 1:
print(
"Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
)
self.epochs = epochs
self.num_warmup_updates = num_warmup_updates
self.save_per_updates = save_per_updates
self.keep_last_n_checkpoints = keep_last_n_checkpoints
self.last_per_updates = default(last_per_updates, save_per_updates)
self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")
self.batch_size_per_gpu = batch_size_per_gpu
self.batch_size_type = batch_size_type
self.max_samples = max_samples
self.grad_accumulation_steps = grad_accumulation_steps
self.max_grad_norm = max_grad_norm
# mel vocoder config
self.vocoder_name = mel_spec_type
self.is_local_vocoder = is_local_vocoder
self.local_vocoder_path = local_vocoder_path
self.noise_scheduler = noise_scheduler
self.duration_predictor = duration_predictor
if bnb_optimizer:
import bitsandbytes as bnb
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
else:
self.optimizer = AdamW(model.parameters(), lr=learning_rate, fused=True)
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
@property
def is_main(self):
    """Whether the current process is the accelerator's main process."""
    accelerator = self.accelerator
    return accelerator.is_main_process
def save_checkpoint(self, update, last=False):
    """Write the current training state to ``self.checkpoint_path`` (main process only).

    All processes first synchronise via ``accelerator.wait_for_everyone()``; only the main
    process then writes anything.

    Args:
        update: number of optimizer updates done so far; stored in the checkpoint and used in
            the filename of intermediate checkpoints (``model_{update}.pt``).
        last: if True, write/overwrite ``model_last.pt`` instead of a numbered checkpoint.

    Intermediate checkpoints follow ``self.keep_last_n_checkpoints``:
    ``-1`` keeps all, ``0`` writes none, ``N > 0`` keeps only the N highest-numbered ones.
    """
    self.accelerator.wait_for_everyone()
    if not self.is_main:
        return

    os.makedirs(self.checkpoint_path, exist_ok=True)

    # Intermediate checkpoints are disabled: return before gathering the state dicts.
    if not last and self.keep_last_n_checkpoints == 0:
        return

    checkpoint = dict(
        model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
        optimizer_state_dict=self.optimizer.state_dict(),
        ema_model_state_dict=self.ema_model.state_dict(),
        scheduler_state_dict=self.scheduler.state_dict(),
        update=update,
    )

    if last:
        self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
        print(f"Saved last checkpoint at update {update}")
        return

    self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
    if self.keep_last_n_checkpoints <= 0:
        return

    # Rotate only files named exactly ``model_<number>.pt``. Anything else in the directory
    # (``model_last.pt``, ``pretrained_*``, hand-named files such as ``model_custom.pt``) is
    # left alone instead of crashing the numeric sort.
    prefix = "model_"
    numbered = []
    for filename in os.listdir(self.checkpoint_path):
        stem, ext = os.path.splitext(filename)
        if ext != ".pt" or not stem.startswith(prefix):
            continue
        suffix = stem[len(prefix) :]
        if suffix.isdecimal():
            numbered.append((int(suffix), filename))
    numbered.sort()

    for _, oldest_checkpoint in numbered[: -self.keep_last_n_checkpoints]:
        os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
        print(f"Removed old checkpoint: {oldest_checkpoint}")
def load_checkpoint(self):
    """Restore state from the most relevant checkpoint in ``self.checkpoint_path``.

    Selection order: ``model_last.pt``; otherwise the ``model_*`` file with the largest number
    in its name; otherwise the first ``pretrained_*`` file. ``.safetensors`` files are treated
    as EMA-only pretrained weights.

    Returns:
        The update count to resume from. ``0`` when no usable checkpoint exists or when only
        weights (no ``update``/``step`` entry) were loaded.
    """
    if (
        self.checkpoint_path is None
        or not os.path.exists(self.checkpoint_path)
        or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
    ):
        return 0

    self.accelerator.wait_for_everyone()
    filenames = os.listdir(self.checkpoint_path)
    if "model_last.pt" in filenames:
        latest_checkpoint = "model_last.pt"
    else:
        # Consider pretrained models for loading but prioritize training checkpoints
        all_checkpoints = [
            f
            for f in filenames
            if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
        ]

        # First try to find regular training checkpoints
        training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
        if training_checkpoints:

            def _update_number(filename):
                # Names without any digit sort first instead of raising on int("").
                digits = "".join(filter(str.isdigit, filename))
                return int(digits) if digits else -1

            latest_checkpoint = sorted(training_checkpoints, key=_update_number)[-1]
        else:
            # If no training checkpoints, use pretrained model
            latest_checkpoint = next((f for f in all_checkpoints if f.startswith("pretrained_")), None)
            if latest_checkpoint is None:
                # The directory only holds unrelated .pt/.safetensors files: treat it like an
                # empty directory rather than failing with a bare StopIteration.
                return 0

    if latest_checkpoint.endswith(".safetensors"):  # always a pretrained checkpoint
        from safetensors.torch import load_file

        checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
        checkpoint = {"ema_model_state_dict": checkpoint}
    elif latest_checkpoint.endswith(".pt"):
        # Load on CPU first; load_state_dict copies the tensors into the live modules.
        checkpoint = torch.load(
            f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
        )

    # patch for backward compatibility, 305e3ea
    for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
        if key in checkpoint["ema_model_state_dict"]:
            del checkpoint["ema_model_state_dict"][key]

    if self.is_main:
        self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])

    if "update" in checkpoint or "step" in checkpoint:
        # patch for backward compatibility, with before f992c4e
        if "step" in checkpoint:
            checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
            if self.grad_accumulation_steps > 1 and self.is_main:
                print(
                    "F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
                )
        # patch for backward compatibility, 305e3ea
        for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
            if key in checkpoint["model_state_dict"]:
                del checkpoint["model_state_dict"][key]

        self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if self.scheduler:
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        update = checkpoint["update"]
    else:
        # Weights-only checkpoint: derive the online model weights from the EMA weights,
        # dropping the EMA bookkeeping entries and the "ema_model." key prefix.
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "update", "step"]
        }
        self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
        update = 0

    del checkpoint
    gc.collect()
    return update
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
    """Run the training loop over ``train_dataset`` for ``self.epochs`` epochs.

    Builds the dataloader according to ``self.batch_size_type``, creates a linear-warmup then
    linear-decay learning-rate schedule, resumes from ``self.load_checkpoint()``, and trains with
    gradient accumulation, EMA updates on the main process, logging and periodic checkpointing.

    Args:
        train_dataset: dataset whose collated batches provide "text", "mel" and "mel_lengths".
        num_workers: number of DataLoader worker processes.
        resumable_with_seed: when not None, seeds the shuffling and makes the first resumed
            epoch skip the batches already consumed before the loaded checkpoint.

    Raises:
        ValueError: if ``self.batch_size_type`` is neither "sample" nor "frame".
    """
    # Sample logging needs a vocoder and an output folder; both are only set up on demand.
    if self.log_samples:
        from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef

        vocoder = load_vocoder(
            vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
        )
        target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
        log_samples_path = f"{self.checkpoint_path}/samples"
        os.makedirs(log_samples_path, exist_ok=True)

    # A seeded generator makes the "sample" dataloader's shuffling reproducible.
    if exists(resumable_with_seed):
        generator = torch.Generator()
        generator.manual_seed(resumable_with_seed)
    else:
        generator = None

    if self.batch_size_type == "sample":
        # Fixed number of samples per batch.
        train_dataloader = DataLoader(
            train_dataset,
            collate_fn=collate_fn,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=True,
            batch_size=self.batch_size_per_gpu,
            shuffle=True,
            generator=generator,
        )
    elif self.batch_size_type == "frame":
        # Batches are formed by DynamicBatchSampler from batch_size_per_gpu and max_samples.
        self.accelerator.even_batches = False
        sampler = SequentialSampler(train_dataset)
        batch_sampler = DynamicBatchSampler(
            sampler,
            self.batch_size_per_gpu,
            max_samples=self.max_samples,
            random_seed=resumable_with_seed,  # This enables reproducible shuffling
            drop_residual=False,
        )
        train_dataloader = DataLoader(
            train_dataset,
            collate_fn=collate_fn,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=True,
            batch_sampler=batch_sampler,
        )
    else:
        raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")

    #  accelerator.prepare() dispatches batches to devices;
    #  which means the length of dataloader calculated before, should consider the number of devices
    warmup_updates = (
        self.num_warmup_updates * self.accelerator.num_processes
    )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
    # otherwise by default with split_batches=False, warmup steps change with num_processes
    total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
    decay_updates = total_updates - warmup_updates
    warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
    decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
    self.scheduler = SequentialLR(
        self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
    )
    train_dataloader, self.scheduler = self.accelerator.prepare(
        train_dataloader, self.scheduler
    )  # actual multi_gpu updates = single_gpu updates / gpu nums
    # load_checkpoint() reads self.scheduler, so it must run after the scheduler is created.
    start_update = self.load_checkpoint()
    global_update = start_update

    # Convert the resumed update count into whole epochs and leftover batches to skip.
    if exists(resumable_with_seed):
        orig_epoch_step = len(train_dataloader)
        start_step = start_update * self.grad_accumulation_steps
        skipped_epoch = int(start_step // orig_epoch_step)
        skipped_batch = start_step % orig_epoch_step
        skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
    else:
        skipped_epoch = 0

    for epoch in range(skipped_epoch, self.epochs):
        self.model.train()
        # Only the first resumed epoch uses the dataloader that skips already-seen batches.
        if exists(resumable_with_seed) and epoch == skipped_epoch:
            progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
            current_dataloader = skipped_dataloader
        else:
            progress_bar_initial = 0
            current_dataloader = train_dataloader

        # Set epoch for the batch sampler if it exists
        if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
            train_dataloader.batch_sampler.set_epoch(epoch)

        progress_bar = tqdm(
            range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
            desc=f"Epoch {epoch + 1}/{self.epochs}",
            unit="update",
            disable=not self.accelerator.is_local_main_process,
            initial=progress_bar_initial,
        )

        for batch in current_dataloader:
            with self.accelerator.accumulate(self.model):
                text_inputs = batch["text"]
                mel_spec = batch["mel"].permute(0, 2, 1)
                mel_lengths = batch["mel_lengths"]

                # TODO. add duration predictor training
                if self.duration_predictor is not None and self.accelerator.is_local_main_process:
                    dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
                    self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)

                loss, cond, pred = self.model(
                    mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
                )
                self.accelerator.backward(loss)

                # Clip only when gradients are synchronised (i.e. on a real optimizer update).
                if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            # sync_gradients marks the end of an accumulation window, i.e. one update.
            if self.accelerator.sync_gradients:
                if self.is_main:
                    self.ema_model.update()

                global_update += 1
                progress_bar.update(1)
                progress_bar.set_postfix(update=str(global_update), loss=loss.item())

            if self.accelerator.is_local_main_process:
                self.accelerator.log(
                    {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
                )
                if self.logger == "tensorboard" and self.accelerator.is_main_process:
                    self.writer.add_scalar("loss", loss.item(), global_update)
                    self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)

            # Refresh the rolling "last" checkpoint.
            if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
                self.save_checkpoint(global_update, last=True)

            # Numbered checkpoint, optionally followed by an audio sample for inspection.
            if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
                self.save_checkpoint(global_update)

                if self.log_samples and self.accelerator.is_local_main_process:
                    # Use the first item of the batch as reference audio and repeat its text,
                    # generating a sequence twice the reference length.
                    ref_audio_len = mel_lengths[0]
                    infer_text = [
                        text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
                    ]
                    with torch.inference_mode(), self.accelerator.autocast():
                        generated, _ = self.accelerator.unwrap_model(self.model).sample(
                            cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
                            text=infer_text,
                            duration=ref_audio_len * 2,
                            steps=nfe_step,
                            cfg_strength=cfg_strength,
                            sway_sampling_coef=sway_sampling_coef,
                        )
                        generated = generated.to(torch.float32)
                        # Keep only the frames after the reference prefix.
                        gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
                        ref_mel_spec = batch["mel"][0, :, :ref_audio_len].unsqueeze(0)
                        if self.vocoder_name == "vocos":
                            gen_audio = vocoder.decode(gen_mel_spec).cpu()
                            ref_audio = vocoder.decode(ref_mel_spec).cpu()
                        elif self.vocoder_name == "bigvgan":
                            gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
                            ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()

                    torchaudio.save(
                        f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
                    )
                    torchaudio.save(
                        f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
                    )
                    self.model.train()

    # Always leave a final "last" checkpoint behind when training ends.
    self.save_checkpoint(global_update, last=True)

    self.accelerator.end_training()
================================================
FILE: src/f5_tts/model/utils.py
================================================
# ruff: noqa: F722 F821
from __future__ import annotations
import os
import random
from collections import defaultdict
from importlib.resources import files
import rjieba
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
# seed everything
def seed_everything(seed=0):
    """Seed Python's ``random``, ``PYTHONHASHSEED`` and torch (CPU and CUDA), and make cuDNN deterministic."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# helpers
def exists(v):
    """Return True for any value other than ``None`` (falsy values such as 0 or "" count as existing)."""
    return not (v is None)
def default(v, d):
    """Return ``v`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(v):
        return v
    return d
def is_package_available(package_name: str) -> bool:
    """Return True if ``package_name`` can be located by the import system, without importing it.

    Any error raised while looking the package up (e.g. an empty name or a missing parent
    package of a dotted name) is reported as "not available".
    """
    try:
        # ``import importlib`` alone does not guarantee the ``importlib.util`` submodule is
        # loaded; import it explicitly so a missing attribute is never mistaken for a
        # missing package.
        import importlib.util

        return importlib.util.find_spec(package_name) is not None
    except Exception:
        return False
# tensor helpers
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:
    """Build a boolean mask whose row ``i`` is True at positions ``< t[i]``.

    The mask width is ``length`` when given, otherwise the largest value in ``t``.
    """
    width = length if exists(length) else t.amax()
    positions = torch.arange(width, device=t.device)
    return positions.unsqueeze(0) < t.unsqueeze(1)
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):
    """Return a boolean mask, as wide as the largest ``seq_len``, that is True on ``[start[i], end[i])`` per row."""
    widest = seq_len.max().item()
    positions = torch.arange(widest, device=start.device).long().unsqueeze(0)
    at_or_after_start = positions >= start.unsqueeze(1)
    before_end = positions < end.unsqueeze(1)
    return at_or_after_start & before_end
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):
    """Mask a randomly placed span covering ``frac_lengths[i]`` of each sequence of length ``seq_len[i]``."""
    span = (frac_lengths * seq_len).long()
    # Latest start index that still keeps the whole span inside the sequence.
    latest_start = seq_len - span
    start = (latest_start * torch.rand_like(frac_lengths)).long().clamp(min=0)
    return mask_from_start_end_indices(seq_len, start, start + span)
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:
    """Mean of ``t`` over dim 1, counting only positions where ``mask`` is True.

    Args:
        t: tensor indexed as (batch, position, feature).
        mask: optional boolean tensor indexed as (batch, position); without it a plain mean
            over dim 1 is returned.

    Returns:
        Tensor indexed as (batch, feature). Rows whose mask is entirely False yield zeros
        (the divisor is clamped to 1).
    """
    if mask is None:
        return t.mean(dim=1)

    t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
    num = t.sum(dim=1)
    den = mask.float().sum(dim=1)
    # ``den`` has one entry per batch row; add a trailing axis so it divides every feature of
    # that row instead of being broadcast against the feature axis.
    return num / den.clamp(min=1.0).unsqueeze(-1)
# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:
    """Encode each string as its UTF-8 byte values and pad them into one integer batch tensor.

    Shorter rows are right-padded with ``padding_value``.
    """
    # Force int64: an empty string would otherwise produce a float32 tensor and could turn the
    # whole padded batch into floats.
    list_tensors = [torch.tensor(list(t.encode("UTF-8")), dtype=torch.long) for t in text]  # ByT5 style
    return pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
# char tokenizer, based on custom dataset's extracted .txt file
def list_str_to_idx(
    text: list[str] | list[list[str]],
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value=-1,
) -> int["b nt"]:
    """Map every token of every item in ``text`` to its vocab index and pad into one batch tensor.

    Tokens missing from ``vocab_char_map`` map to index 0. Shorter rows are right-padded with
    ``padding_value``.
    """
    # Force int64: an empty item would otherwise produce a float32 tensor and could turn the
    # whole padded batch into floats.
    list_idx_tensors = [
        torch.tensor([vocab_char_map.get(c, 0) for c in t], dtype=torch.long) for t in text
    ]  # pinyin or char style
    return pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
# Get tokenizer
def _read_vocab_char_map(vocab_path):
    """Map each line of ``vocab_path`` (without its line terminator) to its 0-based line index."""
    vocab_char_map = {}
    with open(vocab_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            # Strip only the newline: an entry such as " " must survive, and a final line that
            # has no trailing newline must not lose its last character.
            token = line[:-1] if line.endswith("\n") else line
            vocab_char_map[token] = i
    return vocab_char_map


def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
    """
    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                - "char" for char-wise tokenizer, need .txt vocab_file
                - "byte" for utf-8 tokenizer
                - "custom" if you're directly passing in a path to the vocab.txt you want to use
    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                - if use "char", derived from unfiltered character & symbol counts of custom dataset
                - if use "byte", set to 256 (unicode byte range)

    Returns ``(vocab_char_map, vocab_size)``; ``vocab_char_map`` is None for "byte".
    Raises ValueError for an unrecognised ``tokenizer``.
    """
    if tokenizer in ["pinyin", "char"]:
        tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
        vocab_char_map = _read_vocab_char_map(tokenizer_path)
        vocab_size = len(vocab_char_map)
        assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"

    elif tokenizer == "byte":
        vocab_char_map = None
        vocab_size = 256

    elif tokenizer == "custom":
        # Here ``dataset_name`` is the path to the vocab file itself.
        vocab_char_map = _read_vocab_char_map(dataset_name)
        vocab_size = len(vocab_char_map)

    else:
        raise ValueError(f"tokenizer must be one of 'pinyin', 'char', 'byte' or 'custom', but received {tokenizer}")

    return vocab_char_map, vocab_size
# convert char to pinyin
def convert_char_to_pinyin(text_list, polyphone=True):
    """Convert each text into a list of tokens, replacing Chinese characters with tone-numbered pinyin.

    Args:
        text_list: iterable of strings to convert.
        polyphone: when True, segments made only of 3-byte UTF-8 characters are converted as a
            whole segment by ``lazy_pinyin``; otherwise every segment that is not pure
            single-byte text goes through the per-character branch.

    Returns:
        A list with one token list per input text. Single-byte characters are emitted one per
        token; each Chinese character becomes a " " token followed by its pinyin token.
    """
    final_text_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return (
            "\u3100" <= c <= "\u9fff"  # common chinese characters
        )

    for text in text_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in rjieba.cut(text):
            # Byte length vs. character length tells single-byte text from 3-byte characters.
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                # Separate multi-character segments from the previous token with a space, unless
                # the previous token is already a space, colon or quote.
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                # NOTE(review): indexing seg_[i] assumes lazy_pinyin returns one item per
                # character of seg - confirm for non-Chinese 3-byte characters.
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:  # if mixed characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                    else:
                        char_list.append(c)
        final_text_list.append(char_list)

    return final_text_list
# filter func for dirty data with many repetitions
def repetition_found(text, length=2, tolerance=10):
    """Return True if any substring of ``length`` characters occurs more than ``tolerance`` times in ``text``."""
    occurrences = defaultdict(int)
    for start in range(len(text) - length + 1):
        occurrences[text[start : start + length]] += 1
    return any(count > tolerance for count in occurrences.values())
# get the empirically pruned step for sampling
def get_epss_timesteps(n, device, dtype):
    """Return the sampling timesteps in [0, 1] for ``n`` steps.

    For the step counts listed below a hand-picked schedule on a 1/32 grid is used; any other
    ``n`` falls back to ``n + 1`` evenly spaced points.
    """
    grid_step = 1 / 32
    pruned_schedules = {
        5: [0, 2, 4, 8, 16, 32],
        6: [0, 2, 4, 6, 8, 16, 32],
        7: [0, 2, 4, 6, 8, 16, 24, 32],
        10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
        12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
    }
    if n in pruned_schedules:
        grid_indices = torch.tensor(pruned_schedules[n], device=device, dtype=dtype)
        return grid_step * grid_indices
    return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
================================================
FILE: src/f5_tts/runtime/triton_trtllm/.gitignore
================================================
# runtime/triton_trtllm related
model.cache
model_repo/
================================================
FILE: src/f5_tts/runtime/triton_trtllm/Dockerfile.server
================================================
FROM nvcr.io/nvidia/tritonserver:24.12-py3
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 rjieba pypinyin librosa vocos
WORKDIR /workspace
================================================
FILE: src/f5_tts/runtime/triton_trtllm/README.md
================================================
## Triton Inference Serving Best Practice for F5-TTS
### Setup
#### Option 1: Quick Start
```sh
# Directly launch the service using docker compose
MODEL=F5TTS_v1_Base docker compose up
```
#### Option 2: Build from scratch
```sh
# Build the docker image
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
# Create Docker Container
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```
### Build TensorRT-LLM Engines and Launch Server
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
```sh
# F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
bash run.sh 0 4 F5TTS_v1_Base
```
> [!NOTE]
> If using a custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`.
> Remember to use the matching model version (`F5TTS_v1_*` for v1, `F5TTS_*` for v0).
>
> If using a checkpoint with a different structure, see `scripts/convert_checkpoint.py` and modify it if necessary.
> [!IMPORTANT]
> If the model was trained or finetuned with fp32, add the `--dtype float32` flag when converting the checkpoint in `run.sh` phase 1.
### HTTP Client
```sh
python3 client_http.py
```
### Benchmarking
#### Using Client-Server Mode
```sh
# bash run.sh 5 5 F5TTS_v1_Base
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```
#### Using Offline TRT-LLM Mode
```sh
# bash run.sh 7 7 F5TTS_v1_Base
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
```
### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
### Credits
1. [Yuekai Zhang](https://github.com/yuekaizhang)
2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
================================================
FILE: src/f5_tts/runtime/triton_trtllm/benchmark.py
================================================
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $CKPT_DIR/$model/model_1200000.pt \
--vocab-file $CKPT_DIR/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
"""
import argparse
import importlib
import json
import os
import sys
import time
import datasets
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
# Make the repository's `src/` directory importable relative to this file's location.
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.eval.utils_eval import padded_mel_batch
from f5_tts.model.modules import get_vocos_mel_spectrogram
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
# The module path contains a directory named "1", which cannot appear in a regular `import`
# statement, so the module is loaded by its dotted name instead.
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
# Fix the torch RNG seed.
torch.manual_seed(0)
def get_args():
    """Parse the benchmark's command-line arguments from ``sys.argv``.

    Returns:
        argparse.Namespace with the dataset split, output directory, vocab/model/engine paths,
        batch size, dataloader settings, vocoder options and backend type.
    """
    parser = argparse.ArgumentParser(description="extract speech code")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="huggingface dataset split name",
    )
    parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
    parser.add_argument(
        "--vocab-file",
        required=True,
        type=str,
        help="vocab file",
    )
    parser.add_argument(
        "--model-path",
        required=True,
        type=str,
        help="model path, to load text embedding",
    )
    parser.add_argument(
        "--tllm-model-dir",
        required=True,
        type=str,
        help="tllm model dir",
    )
    parser.add_argument(
        "--batch-size",
        required=True,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
    parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
    parser.add_argument(
        "--vocoder",
        default="vocos",
        type=str,
        help="vocoder name",
    )
    parser.add_argument(
        "--vocoder-trt-engine-path",
        default=None,
        type=str,
        help="vocoder trt engine path",
    )
    parser.add_argument("--enable-warmup", action="store_true")
    parser.add_argument("--remove-input-padding", action="store_true")
    parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
    # argparse does not validate defaults against `choices`, so the default must itself be one
    # of the supported backends; the previous default "triton" matched neither.
    parser.add_argument("--backend-type", type=str, default="trt", choices=["trt", "pytorch"], help="backend type")
    args = parser.parse_args()
    return args
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
    """Collate dataset items into the tensors needed for one inference batch.

    Args:
        batch: list of items providing "id", "prompt_text", "target_text" and "prompt_audio"
            (a dict with "array" and "sampling_rate").
        vocab_char_map: token-to-index mapping used to encode the texts.
        device: device on which the mel spectrograms are computed.
        use_perf: when True, wrap the work in NVTX ranges for profiling.

    Returns:
        dict with "ids", "ref_rms_list", "ref_mel_batch", "ref_mel_len_batch",
        "text_pad_sequence" and "estimated_reference_target_mel_len".
    """
    if use_perf:
        torch.cuda.nvtx.range_push("data_collator")
    target_sample_rate = 24000
    target_rms = 0.1
    ids = []
    ref_rms_list = []
    ref_mel_list = []
    ref_mel_len_list = []
    estimated_reference_target_mel_len = []
    reference_target_texts_list = []
    for i, item in enumerate(batch):
        item_id, prompt_text, target_text = (
            item["id"],
            item["prompt_text"],
            item["target_text"],
        )
        ids.append(item_id)
        reference_target_texts_list.append(prompt_text + target_text)

        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
        ref_rms_list.append(ref_rms)
        # Only quiet prompts are amplified up to the target RMS; louder ones are left as is.
        if ref_rms < target_rms:
            ref_audio_org = ref_audio_org * target_rms / ref_rms

        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org

        if use_perf:
            torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
        # Honour the `device` argument instead of a hard-coded "cuda".
        ref_audio = ref_audio.to(device)
        ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
        if use_perf:
            torch.cuda.nvtx.range_pop()
        ref_mel_len = ref_mel.shape[-1]
        assert ref_mel.shape[0] == 100

        ref_mel_list.append(ref_mel)
        ref_mel_len_list.append(ref_mel_len)

        # Estimate the total (prompt + target) mel length from the UTF-8 byte-length ratio of
        # target text to prompt text.
        estimated_reference_target_mel_len.append(
            int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
        )

    ref_mel_batch = padded_mel_batch(ref_mel_list)
    ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)

    pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
    text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)

    if use_perf:
        torch.cuda.nvtx.range_pop()
    return {
        "ids": ids,
        "ref_rms_list": ref_rms_list,
        "ref_mel_batch": ref_mel_batch,
        "ref_mel_len_batch": ref_mel_len_batch,
        "text_pad_sequence": text_pad_sequence,
        "estimated_reference_target_mel_len": estimated_reference_target_mel_len,
    }
def init_distributed():
    """Read the launcher's environment variables, bind this process to its GPU and start NCCL.

    Returns:
        ``(world_size, local_rank, rank)`` taken from ``WORLD_SIZE``, ``LOCAL_RANK`` and
        ``RANK`` (defaulting to 1, 0 and 0).
    """
    env = os.environ
    world_size = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    rank = int(env.get("RANK", 0))
    print(f"Inference on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}")
    torch.cuda.set_device(local_rank)
    # Initialize the default process group with the NCCL backend.
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
def load_vocoder(
    vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
    """Create the vocoder used to turn mel spectrograms into waveforms.

    Args:
        vocoder_name: "vocos" is supported; "bigvgan" is recognised but not implemented.
        is_local: load the Vocos config and weights from ``local_path`` instead of downloading.
        local_path: directory containing ``config.yaml`` and ``pytorch_model.bin``.
        device: device the PyTorch vocoder is moved to.
        hf_cache_dir: cache directory passed to ``hf_hub_download``.
        vocoder_trt_engine_path: when given, a ``VocosTensorRT`` wrapper around that engine is
            returned instead of the PyTorch model.

    Raises:
        NotImplementedError: for "bigvgan".
        ValueError: for any other unknown ``vocoder_name``.
    """
    if vocoder_name == "vocos":
        if vocoder_trt_engine_path is not None:
            return VocosTensorRT(engine_path=vocoder_trt_engine_path)

        # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
        if is_local:
            print(f"Load vocos from local path {local_path}")
            config_path = f"{local_path}/config.yaml"
            model_path = f"{local_path}/pytorch_model.bin"
        else:
            print("Download Vocos from huggingface charactr/vocos-mel-24khz")
            repo_id = "charactr/vocos-mel-24khz"
            config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
            model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
        vocoder = Vocos.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        from vocos.feature_extractors import EncodecFeatures

        # For an Encodec feature extractor, add its own parameters to the state dict before
        # loading it into the vocoder.
        if isinstance(vocoder.feature_extractor, EncodecFeatures):
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in vocoder.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        vocoder.load_state_dict(state_dict)
        return vocoder.eval().to(device)

    if vocoder_name == "bigvgan":
        raise NotImplementedError("BigVGAN is not implemented yet")

    # Previously an unknown name fell through to `return vocoder` and raised UnboundLocalError.
    raise ValueError(f"Unsupported vocoder_name: {vocoder_name!r} (expected 'vocos' or 'bigvgan')")
class VocosTensorRT:
    """Vocoder backed by a serialized TensorRT engine, exposing a ``decode(mels)`` method."""

    def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
        """Load the engine at ``engine_path``; use ``stream`` or, if None, the current CUDA stream."""
        trt_logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(trt_logger, namespace="")
        logger.info(f"Loading vocoder engine from {engine_path}")
        self.engine_path = engine_path
        with open(engine_path, "rb") as engine_file:
            serialized_engine = engine_file.read()
        self.session = Session.from_serialized_engine(serialized_engine)
        if stream is None:
            stream = torch.cuda.current_stream().cuda_stream
        self.stream = stream

    def decode(self, mels):
        """Run the engine on ``mels`` and return its "waveform" output tensor."""
        mels = mels.contiguous()
        mel_info = TensorInfo("mel", trt.DataType.FLOAT, mels.shape)
        # Allocate one output buffer per tensor the engine reports for this input shape.
        outputs = {}
        for info in self.session.infer_shapes([mel_info]):
            outputs[info.name] = torch.empty(tuple(info.shape), dtype=trt_dtype_to_torch(info.dtype), device="cuda")
        ok = self.session.run({"mel": mels}, outputs, self.stream)
        assert ok, "Runtime execution failed for vae session"
        return outputs["waveform"]
def main():
    """Benchmark F5-TTS synthesis over a Hugging Face split and report the real-time factor.

    Loads either the TensorRT-LLM engine wrapper (``trt``) or the PyTorch DiT
    (``pytorch``), synthesizes every utterance of the selected split, writes
    one wav per utterance into ``args.output_dir`` and, on rank 0, prints and
    saves timing statistics to ``rtf.txt``.
    """
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)
    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")
    vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
    tllm_model_dir = args.tllm_model_dir
    # config.json of the TRT-LLM model directory is read for both backends.
    with open(os.path.join(tllm_model_dir, "config.json")) as f:
        tllm_model_config = json.load(f)
    if args.backend_type == "trt":
        model = F5TTS(
            tllm_model_config,
            debug_mode=False,
            tllm_model_dir=tllm_model_dir,
            model_path=args.model_path,
            vocab_size=vocab_size,
        )
    elif args.backend_type == "pytorch":
        from f5_tts.infer.utils_infer import load_model
        from f5_tts.model import DiT

        # Rebuild the DiT constructor arguments from the TRT-LLM pretrained config.
        pretrained_config = tllm_model_config["pretrained_config"]
        pt_model_config = dict(
            dim=pretrained_config["hidden_size"],
            depth=pretrained_config["num_hidden_layers"],
            heads=pretrained_config["num_attention_heads"],
            ff_mult=pretrained_config["ff_mult"],
            text_dim=pretrained_config["text_dim"],
            text_mask_padding=pretrained_config["text_mask_padding"],
            conv_layers=pretrained_config["conv_layers"],
            pe_attn_head=pretrained_config["pe_attn_head"],
            # attn_backend="flash_attn",
            # attn_mask_enabled=True,
        )
        model = load_model(DiT, pt_model_config, args.model_path)
    vocoder = load_vocoder(
        vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
    )
    dataset = load_dataset(
        "yuekai/seed_tts",
        split=args.split_name,
        trust_remote_code=True,
    )

    def add_estimated_duration(example):
        # Estimated total length in seconds: prompt samples scaled by
        # (1 + len(target_text) / len(prompt_text)), divided by the sampling rate.
        prompt_audio_len = example["prompt_audio"]["array"].shape[0]
        scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
        estimated_duration = prompt_audio_len * scale_factor
        example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
        return example

    dataset = dataset.map(add_estimated_duration)
    # Longest estimated utterances first.
    dataset = dataset.sort("estimated_duration", reverse=True)
    if args.use_perf:
        # Profiling mode: replace the dataset with 8 copies of the sample at index 24.
        # dataset_list = [dataset.select(range(1)) for i in range(16)]  # seq_len 1000
        dataset_list_short = [dataset.select([24]) for i in range(8)]  # seq_len 719
        # dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002
        # dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
        dataset = datasets.concatenate_datasets(dataset_list_short)
    if world_size > 1:
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    else:
        # This would disable shuffling
        sampler = None
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
    )
    total_steps = len(dataset)
    if args.enable_warmup:
        # Warm-up: a full pass over the dataloader whose outputs are discarded.
        for batch in dataloader:
            ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
            text_pad_seq = batch["text_pad_sequence"].to(device)
            total_mel_lens = batch["estimated_reference_target_mel_len"]
            # Pad the reference mels along dim 1 up to the longest estimated total length.
            cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
            if args.backend_type == "trt":
                _ = model.sample(
                    text_pad_seq,
                    cond_pad_seq,
                    ref_mel_lens,
                    total_mel_lens,
                    remove_input_padding=args.remove_input_padding,
                )
            elif args.backend_type == "pytorch":
                total_mel_lens = torch.tensor(total_mel_lens, device=device)
                with torch.inference_mode():
                    # NOTE(review): unlike the timed call below, this warm-up call does
                    # not pass lens=ref_mel_lens - confirm whether that is intended.
                    generated, _ = model.sample(
                        cond=ref_mels,
                        text=text_pad_seq,
                        duration=total_mel_lens,
                        steps=32,
                        cfg_strength=2.0,
                        sway_sampling_coef=-1,
                    )
    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
    # Accumulators: model time, vocoder + file-writing time, and seconds of generated audio.
    decoding_time = 0
    vocoder_time = 0
    total_duration = 0
    if args.use_perf:
        torch.cuda.cudart().cudaProfilerStart()
    # Holds the start timestamp here; overwritten with the elapsed time after the loop.
    total_decoding_time = time.time()
    for batch in dataloader:
        if args.use_perf:
            torch.cuda.nvtx.range_push("data sample")
        ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
        text_pad_seq = batch["text_pad_sequence"].to(device)
        total_mel_lens = batch["estimated_reference_target_mel_len"]
        cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
        if args.use_perf:
            torch.cuda.nvtx.range_pop()
        if args.backend_type == "trt":
            # The TRT wrapper returns its own measured cost alongside the generated mels.
            generated, cost_time = model.sample(
                text_pad_seq,
                cond_pad_seq,
                ref_mel_lens,
                total_mel_lens,
                remove_input_padding=args.remove_input_padding,
                use_perf=args.use_perf,
            )
        elif args.backend_type == "pytorch":
            total_mel_lens = torch.tensor(total_mel_lens, device=device)
            with torch.inference_mode():
                # Host-side wall clock around the sampling call.
                start_time = time.time()
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=text_pad_seq,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=32,
                    cfg_strength=2.0,
                    sway_sampling_coef=-1,
                )
                cost_time = time.time() - start_time
        decoding_time += cost_time
        vocoder_start_time = time.time()
        target_rms = 0.1
        target_sample_rate = 24000
        for i, gen in enumerate(generated):
            # Keep only the frames after the reference prompt, up to the estimated total length.
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
            if args.vocoder == "vocos":
                if args.use_perf:
                    torch.cuda.nvtx.range_push("vocoder decode")
                generated_wave = vocoder.decode(gen_mel_spec).cpu()
                if args.use_perf:
                    torch.cuda.nvtx.range_pop()
            else:
                generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
            # Scale the output down by ref_rms / target_rms when the reference was quieter
            # than target_rms.
            if batch["ref_rms_list"][i] < target_rms:
                generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
            utt = batch["ids"][i]
            torchaudio.save(
                f"{args.output_dir}/{utt}.wav",
                generated_wave,
                target_sample_rate,
            )
            total_duration += generated_wave.shape[1] / target_sample_rate
        # Covers vocoding, RMS scaling and wav writing for the whole batch.
        vocoder_time += time.time() - vocoder_start_time
        if rank == 0:
            # Advance by world_size * local batch size; presumably approximates the
            # progress of all ranks together - confirm.
            progress_bar.update(world_size * len(batch["ids"]))
    total_decoding_time = time.time() - total_decoding_time
    if rank == 0:
        progress_bar.close()
        # RTF = wall-clock time of the whole loop / seconds of audio generated on this rank.
        rtf = total_decoding_time / total_duration
        s = f"RTF: {rtf:.4f}\n"
        s += f"total_duration: {total_duration:.3f} seconds\n"
        s += f"({total_duration / 3600:.2f} hours)\n"
        s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
        s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
        s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
        s += f"batch size: {args.batch_size}\n"
        print(s)
        with open(f"{args.output_dir}/rtf.txt", "w") as f:
            f.write(s)
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
================================================
FILE: src/f5_tts/runtime/triton_trtllm/client_grpc.py
================================================
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a dataset from Hugging Face and sends it to the server
for decoding, in parallel.
Usage:
num_task=2
# For offline F5-TTS
python3 client_grpc.py \
--server-addr localhost \
--model-name f5_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
import asyncio
import json
import os
import time
import types
from pathlib import Path
import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient
from tritonclient.utils import np_to_triton_dtype
def write_triton_stats(stats, summary_file):
    """Render Triton inference statistics as a plain-text report.

    Args:
        stats: dict with a ``model_stats`` list, as returned by
            ``triton_client.get_inference_statistics(..., as_json=True)``.
        summary_file: path of the text file to (over)write.

    Models without a ``last_inference`` entry are skipped. For every other
    model the report lists the cumulative queue/compute times in seconds and
    one group of lines per executed batch size.
    """
    preamble = (
        "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n",
        "To learn more about the log, please refer to: \n",
        "1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n",
        "2. https://github.com/triton-inference-server/server/issues/5374 \n\n",
        "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n",
        "However, there is a trade-off between the increased queue time and the increased batch size. \n",
        "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n",
        "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n",
    )

    with open(summary_file, "w") as report:
        all_models = stats["model_stats"]
        for note in preamble:
            report.write(note)

        for entry in all_models:
            # A model that never served a request has no "last_inference" field.
            if "last_inference" not in entry:
                continue
            report.write(f"model name is {entry['name']} \n")

            # Cumulative counters are reported in nanoseconds; convert to seconds.
            totals = entry["inference_stats"]
            queue_s = int(totals["queue"]["ns"]) / 1e9
            infer_s = int(totals["compute_infer"]["ns"]) / 1e9
            input_s = int(totals["compute_input"]["ns"]) / 1e9
            output_s = int(totals["compute_output"]["ns"]) / 1e9
            report.write(
                f"queue time {queue_s:<5.2f} s, compute infer time {infer_s:<5.2f} s, compute input time {input_s:<5.2f} s, compute output time {output_s:<5.2f} s \n"  # noqa
            )

            for bucket in entry["batch_stats"]:
                size = int(bucket["batch_size"])
                input_stat = bucket["compute_input"]
                output_stat = bucket["compute_output"]
                infer_stat = bucket["compute_infer"]
                runs = int(infer_stat["count"])
                # All three phases must have run the same number of times.
                assert infer_stat["count"] == output_stat["count"] == input_stat["count"]

                # Per-batch-size counters: nanoseconds -> milliseconds.
                infer_ms = int(infer_stat["ns"]) / 1e6
                input_ms = int(input_stat["ns"]) / 1e6
                output_ms = int(output_stat["ns"]) / 1e6
                report.write(
                    f"execuate inference with batch_size {size:<2} total {runs:<5} times, total_infer_time {infer_ms:<9.2f} ms, avg_infer_time {infer_ms:<9.2f}/{runs:<5}={infer_ms / runs:.2f} ms, avg_infer_time_per_sample {infer_ms:<9.2f}/{runs:<5}/{size}={infer_ms / runs / size:.2f} ms \n"  # noqa
                )
                report.write(f"input {input_ms:<9.2f} ms, avg {input_ms / runs:.2f} ms, ")
                report.write(f"output {output_ms:<9.2f} ms, avg {output_ms / runs:.2f} ms \n")
def get_args():
    """Parse the command-line options of the gRPC benchmark client.

    Returns:
        argparse.Namespace with the server address/port, the input source
        (single reference audio, Hugging Face dataset, or manifest file),
        the Triton model name, and concurrency/logging settings.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=8001,
        help="Grpc port of the triton server, default is 8001",
    )
    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Path to a single reference audio file. When given, it is used instead of "
        "--huggingface-dataset and --manifest-path",
    )
    parser.add_argument(
        "--reference-text",
        type=str,
        default="",
        help="Transcript of the reference audio (used together with --reference-audio)",
    )
    parser.add_argument(
        "--target-text",
        type=str,
        default="",
        help="Text to synthesize (used together with --reference-audio)",
    )
    parser.add_argument(
        "--huggingface-dataset",
        type=str,
        default="yuekai/seed_tts",
        help="dataset name in huggingface dataset hub",
    )
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="dataset split name",
    )
    parser.add_argument(
        "--manifest-path",
        type=str,
        default=None,
        help="Path to a manifest file with one 'utt|prompt_text|prompt_wav|target_text' entry per line",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        help="triton model_repo module name to request",
    )
    parser.add_argument(
        "--num-tasks",
        type=int,
        default=1,
        help="Number of concurrent tasks for sending",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=5,
        help="Controls how frequently we print the log.",
    )
    parser.add_argument(
        "--compute-wer",
        action="store_true",
        default=False,
        help="True to compute WER.",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        required=False,
        default="./tests/client_grpc",
        help="log directory",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Inference batch_size per request for offline mode.",
    )
    return parser.parse_args()
def load_audio(wav_path, target_sample_rate=24000):
    """Return ``(waveform, target_sample_rate)`` for a file path or a decoded-audio dict.

    ``wav_path`` is either a path readable by ``soundfile`` or a dict with
    ``"array"`` and ``"sampling_rate"`` keys. Audio at a different rate is
    resampled to ``target_sample_rate``, which must be 24000.
    """
    assert target_sample_rate == 24000, "hard coding in server"

    if not isinstance(wav_path, dict):
        audio, native_rate = sf.read(wav_path)
    else:
        audio, native_rate = wav_path["array"], wav_path["sampling_rate"]

    if native_rate == target_sample_rate:
        return audio, target_sample_rate

    # scipy is only needed when the rates differ, so it is imported lazily.
    from scipy.signal import resample

    resampled_len = int(len(audio) * (target_sample_rate / native_rate))
    return resample(audio, resampled_len), target_sample_rate
async def send(
    manifest_item_list: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    padding_duration: int = None,
    audio_save_dir: str = "./",
    save_sample_rate: int = 24000,
):
    """Send every manifest item to the Triton model, one request at a time.

    Args:
        manifest_item_list: dicts with ``audio_filepath``, ``reference_text``,
            ``target_text`` and ``target_audio_path`` keys.
        name: task label of the form ``task-<n>``; the numeric suffix is used
            when building request ids.
        triton_client: async gRPC client used to issue the requests.
        protocol_client: module providing ``InferInput`` / ``InferRequestedOutput``.
        log_interval: print a progress line every this many items.
        model_name: Triton model to call.
        padding_duration: when truthy, the reference audio is zero-padded to a
            multiple of this many seconds (see below).
        audio_save_dir: directory receiving ``<target_audio_path>.wav`` files.
        save_sample_rate: sample rate used when writing the returned audio.

    Returns:
        ``(total_duration, latency_data)`` where ``total_duration`` is the sum
        of the returned audio lengths in seconds and ``latency_data`` is a list
        of ``(request_latency_seconds, audio_seconds)`` tuples.
    """
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])  # strip the "task-" prefix

    print(f"manifest_item_list: {manifest_item_list}")
    for idx, entry in enumerate(manifest_item_list):
        if idx % log_interval == 0:
            print(f"{name}: {idx}/{len(manifest_item_list)}")

        waveform, sample_rate = load_audio(entry["audio_filepath"], target_sample_rate=24000)
        num_samples = len(waveform)
        ref_seconds = num_samples / sample_rate
        lengths = np.array([[num_samples]], dtype=np.int32)

        reference_text = entry["reference_text"]
        target_text = entry["target_text"]
        # Scale the reference duration by the character-count ratio of the two texts.
        target_seconds = ref_seconds / len(reference_text) * len(target_text)

        if not padding_duration:
            samples = waveform
        else:
            # Zero-pad to the next multiple of `padding_duration` seconds above the
            # estimated reference + target length.
            num_chunks = (int(target_seconds + ref_seconds) // padding_duration) + 1
            samples = np.zeros((1, padding_duration * sample_rate * num_chunks), dtype=np.float32)
            samples[0, :num_samples] = waveform
        samples = samples.reshape(1, -1).astype(np.float32)

        wav_input = protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype))
        len_input = protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype))
        ref_text_input = protocol_client.InferInput("reference_text", [1, 1], "BYTES")
        tgt_text_input = protocol_client.InferInput("target_text", [1, 1], "BYTES")

        wav_input.set_data_from_numpy(samples)
        len_input.set_data_from_numpy(lengths)
        # Text tensors travel as (1, 1) object arrays.
        ref_text_input.set_data_from_numpy(np.array([reference_text], dtype=object).reshape((1, 1)))
        tgt_text_input.set_data_from_numpy(np.array([target_text], dtype=object).reshape((1, 1)))

        inputs = [wav_input, len_input, ref_text_input, tgt_text_input]
        outputs = [protocol_client.InferRequestedOutput("waveform")]

        sequence_id = 100000000 + idx + task_id * 10
        sent_at = time.time()
        response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
        audio = response.as_numpy("waveform").reshape(-1)
        latency = time.time() - sent_at

        audio_save_path = os.path.join(audio_save_dir, f"{entry['target_audio_path']}.wav")
        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")

        audio_seconds = len(audio) / save_sample_rate
        latency_data.append((latency, audio_seconds))
        total_duration += audio_seconds

    return total_duration, latency_data
def load_manifests(manifest_path):
    """Read a pipe-separated manifest file into a list of request dicts.

    Each non-empty line must have exactly four ``|``-separated fields:
    ``utt|prompt_text|prompt_wav|target_text``. A relative ``prompt_wav`` is
    resolved against the manifest's directory, and ``utt`` is reduced to its
    file stem.

    Args:
        manifest_path: path to the manifest file (read as UTF-8).

    Returns:
        A list of dicts with keys ``audio_filepath``, ``reference_text``,
        ``target_text`` and ``target_audio_path``.

    Raises:
        AssertionError: if a non-empty line does not have four fields.
    """
    manifest_dir = os.path.dirname(manifest_path)
    manifest_list = []
    # Manifests carry transcripts (possibly non-ASCII), so do not depend on the
    # platform's default encoding.
    with open(manifest_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines, e.g. a trailing empty line at end of file.
                continue
            fields = line.split("|")
            assert len(fields) == 4, f"expected 4 '|'-separated fields, got {len(fields)}: {line!r}"
            utt, prompt_text, prompt_wav, gt_text = fields
            utt = Path(utt).stem
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(manifest_dir, prompt_wav)
            manifest_list.append(
                {
                    "audio_filepath": prompt_wav,
                    "reference_text": prompt_text,
                    "target_text": gt_text,
                    "target_audio_path": utt,
                }
            )
    return manifest_list
def split_data(data, k):
    """Split ``data`` into ``k`` contiguous chunks whose sizes differ by at most one.

    The first ``len(data) % k`` chunks receive one extra element. When ``k``
    exceeds ``len(data)`` a warning is printed and ``k`` is reduced to
    ``len(data)``.

    Args:
        data: a sliceable sequence.
        k: requested number of chunks.

    Returns:
        A list of ``k`` slices of ``data``; an empty list when there is
        nothing to split (empty ``data`` or ``k <= 0``).
    """
    n = len(data)
    if n < k:
        print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
        k = n
    if k <= 0:
        # Empty input (or a non-positive chunk count) would otherwise divide by zero.
        return []

    quotient, remainder = divmod(n, k)
    result = []
    start = 0
    for i in range(k):
        end = start + quotient + (1 if i < remainder else 0)
        result.append(data[start:end])
        start = end
    return result
async def main():
    """Run the concurrent gRPC benchmark and write latency/RTF reports.

    Input source, in order of precedence: ``--reference-audio`` (single
    request), ``--manifest-path`` (pipe-separated manifest file), then
    ``--huggingface-dataset``. The items are split across ``--num-tasks``
    concurrent ``send`` tasks; afterwards the RTF/latency summary, the Triton
    statistics digest and the model config are written into ``--log-dir``.
    """
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"

    triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
    protocol_client = grpcclient

    if args.reference_audio:
        args.num_tasks = 1
        args.log_interval = 1
        manifest_item_list = [
            {
                "reference_text": args.reference_text,
                "target_text": args.target_text,
                "audio_filepath": args.reference_audio,
                "target_audio_path": "test",
            }
        ]
    elif args.manifest_path:
        # An explicit manifest must win over --huggingface-dataset, whose default is
        # non-empty and would otherwise make this branch unreachable.
        manifest_item_list = load_manifests(args.manifest_path)
    elif args.huggingface_dataset:
        import datasets

        dataset = datasets.load_dataset(
            args.huggingface_dataset,
            split=args.split_name,
            trust_remote_code=True,
        )
        # Iterate once per example so each row is materialized a single time.
        manifest_item_list = [
            {
                "audio_filepath": example["prompt_audio"],
                "reference_text": example["prompt_text"],
                "target_audio_path": example["id"],
                "target_text": example["target_text"],
            }
            for example in dataset
        ]
    else:
        raise ValueError("One of --reference-audio, --manifest-path or --huggingface-dataset is required")

    args.num_tasks = min(args.num_tasks, len(manifest_item_list))
    manifest_item_list = split_data(manifest_item_list, args.num_tasks)

    os.makedirs(args.log_dir, exist_ok=True)
    tasks = []
    start_time = time.time()
    for i in range(args.num_tasks):
        task = asyncio.create_task(
            send(
                manifest_item_list[i],
                name=f"task-{i}",
                triton_client=triton_client,
                protocol_client=protocol_client,
                log_interval=args.log_interval,
                model_name=args.model_name,
                audio_save_dir=args.log_dir,
                padding_duration=1,
                save_sample_rate=24000,
            )
        )
        tasks.append(task)

    ans_list = await asyncio.gather(*tasks)

    end_time = time.time()
    elapsed = end_time - start_time

    # Merge the per-task (total_duration, latency_data) results.
    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        total_duration += ans[0]
        latency_data += ans[1]

    # RTF = wall-clock time of all tasks / seconds of audio returned by the server.
    rtf = elapsed / total_duration

    s = f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration / 3600:.2f} hours)\n"
    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

    latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
    latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
    # NOTE(review): this is the variance in s^2 scaled by 1000, not ms^2 - confirm the intended unit.
    latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
    s += f"latency_variance: {latency_variance:.2f}\n"
    s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
    s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
    s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
    s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
    s += f"average_latency_ms: {latency_ms:.2f}\n"

    print(s)
    # Report files are suffixed with the manifest stem, else the split name.
    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
        f.write(s)

    stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
    write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")

    metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
    with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
        json.dump(metadata, f, indent=4)


if __name__ == "__main__":
    asyncio.run(main())
================================================
FILE: src/f5_tts/runtime/triton_trtllm/client_http.py
================================================
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import os
import numpy as np
import requests
import soundfile as sf
def get_args():
    """Parse the command-line options of the single-request HTTP client.

    Returns:
        argparse.Namespace with the server URL, the reference audio and its
        transcript, the text to synthesize, the Triton model name and the
        output wav path.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--server-url",
        type=str,
        default="localhost:8000",
        help="Address of the server",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default="../../infer/examples/basic/basic_ref_en.wav",
        help="Path to the reference audio file",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="Some call me nature, others call me mother nature.",
        help="Transcript of the reference audio",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
        help="Text to synthesize",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        help="triton model_repo module name to request",
    )

    parser.add_argument(
        "--output-audio",
        type=str,
        default="tests/client_http.wav",
        help="Path to save the output audio",
    )
    return parser.parse_args()
def prepare_request(
    waveform,
    reference_text,
    target_text,
    sample_rate=24000,
    audio_save_dir: str = "./",
):
    """Build the JSON body of an inference request for the TTS model.

    Args:
        waveform: 1-D numpy array holding the reference audio samples.
        reference_text: transcript of the reference audio.
        target_text: text to synthesize.
        sample_rate: accepted but not used by this function.
        audio_save_dir: accepted but not used by this function.

    Returns:
        A dict with an ``inputs`` list describing the four tensors
        ``reference_wav``, ``reference_wav_len``, ``reference_text`` and
        ``target_text``.
    """
    assert len(waveform.shape) == 1, "waveform should be 1D"

    def tensor(name, shape, datatype, data):
        # One entry of the request's "inputs" list.
        return {"name": name, "shape": shape, "datatype": datatype, "data": data}

    sample_count = np.array([[len(waveform)]], dtype=np.int32)
    batched_wav = waveform.reshape(1, -1).astype(np.float32)

    return {
        "inputs": [
            tensor("reference_wav", batched_wav.shape, "FP32", batched_wav.tolist()),
            tensor("reference_wav_len", sample_count.shape, "INT32", sample_count.tolist()),
            tensor("reference_text", [1, 1], "BYTES", [reference_text]),
            tensor("target_text", [1, 1], "BYTES", [target_text]),
        ]
    }
def load_audio(wav_path, target_sample_rate=24000):
    """Load audio and return it at ``target_sample_rate`` (which must be 24000).

    ``wav_path`` may be a file path (read with ``soundfile``) or a dict that
    already holds the decoded ``"array"`` and its ``"sampling_rate"``.
    Returns ``(waveform, target_sample_rate)``.
    """
    assert target_sample_rate == 24000, "hard coding in server"

    already_decoded = isinstance(wav_path, dict)
    if already_decoded:
        signal = wav_path["array"]
        source_rate = wav_path["sampling_rate"]
    else:
        signal, source_rate = sf.read(wav_path)

    needs_resampling = source_rate != target_sample_rate
    if needs_resampling:
        from scipy.signal import resample

        # Scale the sample count by the ratio between the two rates.
        rate_ratio = target_sample_rate / source_rate
        signal = resample(signal, int(len(signal) * rate_ratio))
    return signal, target_sample_rate
if __name__ == "__main__":
    # Send a single synthesis request to the server's HTTP endpoint and save the result.
    args = get_args()
    server_url = args.server_url
    if not server_url.startswith(("http://", "https://")):
        server_url = f"http://{server_url}"

    url = f"{server_url}/v2/models/{args.model_name}/infer"
    waveform, sr = load_audio(args.reference_audio)
    assert sr == 24000, "sample rate hardcoded in server"

    waveform = np.array(waveform, dtype=np.float32)
    data = prepare_request(waveform, args.reference_text, args.target_text)

    # NOTE(review): verify=False disables TLS certificate checks for https:// URLs;
    # kept as-is so self-signed servers keep working - confirm this is acceptable.
    rsp = requests.post(
        url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
    )
    # Fail with the HTTP error instead of an opaque KeyError on the response body.
    rsp.raise_for_status()
    result = rsp.json()
    audio = result["outputs"][0]["data"]
    audio = np.array(audio, dtype=np.float32)

    # os.makedirs("") raises, so only create a directory when the path has one.
    output_dir = os.path.dirname(args.output_audio)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    sf.write(args.output_audio, audio, 24000, "PCM_16")
================================================
FILE: src/f5_tts/runtime/triton_trtllm/docker-compose.yml
================================================
# Compose file that starts one GPU-backed container serving F5-TTS.
services:
  tts:
    image: soar97/triton-f5-tts:24.12
    shm_size: '1gb'
    # 8000 and 8001 match the HTTP and gRPC defaults used by client_http.py and
    # client_grpc.py; 8002 is presumably the metrics endpoint - confirm.
    ports:
      - "8000:8000"
      - "8001:8001"
      - "8002:8002"
    environment:
      - PYTHONIOENCODING=utf-8
      # NOTE(review): MODEL_ID is passed into the container, but the command below
      # references $MODEL - confirm which variable name is intended.
      - MODEL_ID=${MODEL_ID}
    deploy:
      resources:
        reservations:
          devices:
            # Reserve the NVIDIA GPU with device id 0 for this service.
            - driver: nvidia
              device_ids: ['0']
              capabilities: [gpu]
    # Installs vocos, re-clones the repository, then launches run.sh with the
    # arguments "0 4 $MODEL" (their meaning is defined by run.sh - presumably a
    # stage range and the model name; confirm against run.sh).
    command: >
      /bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL"
================================================
FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py
================================================
import math
import os
import time
from functools import wraps
from typing import List, Optional
import tensorrt as trt
import tensorrt_llm
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
from torch.nn.utils.rnn import pad_sequence
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
    """Drop per-sample padding and pack the valid rows of a batch into one tensor.

    For each batch entry ``i`` the first ``input_tensor_lengths[i]`` positions
    along dim 1 are kept; the kept slices are concatenated along dim 0.
    Works for ``(batch, seq_len, feature_len)`` and ``(batch, seq_len)`` inputs.

    Raises:
        AssertionError: if ``input_tensor_lengths`` is not provided.
    """
    assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"

    kept = [input_tensor[row, : input_tensor_lengths[row]] for row in range(input_tensor.shape[0])]
    return torch.cat(kept, dim=0).contiguous()
class TextEmbedding(nn.Module):
    """Token embedding with additive position features and optional ConvNeXt-V2 refinement.

    Token ids are shifted by one so that index 0 can serve as the filler
    token used for padding.
    """

    def __init__(
        self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2, precompute_max_pos=4096
    ):
        super().__init__()
        # One extra row: index 0 is the filler token.
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)
        self.mask_padding = mask_padding
        # Position table; registered as a non-persistent buffer so it is not part of the state dict.
        self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
        refinement = [ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
        self.text_blocks = nn.Sequential(*refinement)

    def forward(self, text, seq_len, drop_text=False):
        """Embed ``text`` (b, n token ids) into a (b, seq_len, text_dim) tensor.

        ``text`` is truncated or zero-padded to ``seq_len``. With
        ``drop_text`` all tokens are replaced by the filler token (used for
        the text-free branch of classifier-free guidance).
        """
        tokens = (text + 1)[:, :seq_len]  # shift ids; cut off tokens beyond the mel length
        tokens = F.pad(tokens, (0, seq_len - tokens.shape[1]), value=0)

        if self.mask_padding:
            # Computed before `drop_text` so the mask reflects the real padding.
            filler_mask = tokens == 0

        if drop_text:
            tokens = torch.zeros_like(tokens)

        hidden = self.text_embed(tokens)  # b n -> b n d
        hidden = hidden + self.freqs_cis[:seq_len, :]

        if not self.mask_padding:
            return self.text_blocks(hidden)

        def zero_filler(states):
            # Reset filler positions to zero across the feature dimension.
            return states.masked_fill(filler_mask.unsqueeze(-1).expand(-1, -1, states.size(-1)), 0.0)

        hidden = zero_filler(hidden)
        for block in self.text_blocks:
            hidden = zero_filler(block(hidden))
        return hidden
class GRN(nn.Module):
    """Response-normalization layer with learnable scale/shift and a residual path.

    ``gamma`` and ``beta`` have shape (1, 1, dim) and start at zero, so the
    layer initially returns its input unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # L2 norm over dim 1, then divided by its mean over the last dim.
        feature_norm = x.norm(p=2, dim=1, keepdim=True)
        relative_norm = feature_norm / (feature_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative_norm) + self.beta + x
class ConvNeXtV2Block(nn.Module):
    """Residual block: depthwise conv -> LayerNorm -> Linear -> GELU -> GRN -> Linear.

    Input and output are both (batch, length, dim).
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        kernel_size = 7
        # Padding that keeps the sequence length unchanged for this kernel/dilation.
        same_padding = (dilation * (kernel_size - 1)) // 2
        # Depthwise convolution: one filter per channel (groups == channels).
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=kernel_size, padding=same_padding, groups=dim, dilation=dilation
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convolutions expressed as linear layers.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv1d expects channels first: b n d -> b d n, and back afterwards.
        mixed = self.dwconv(x.transpose(1, 2)).transpose(1, 2)
        mixed = self.pwconv1(self.norm(mixed))
        mixed = self.grn(self.act(mixed))
        mixed = self.pwconv2(mixed)
        return x + mixed
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    """Build a (end, dim) table of rotary-style position features.

    The first ``dim // 2`` columns hold the cosines and the remaining columns
    the sines of ``position * inverse_frequency``.

    ``theta_rescale_factor`` rescales the base (proposed by reddit user bloc97)
    so rotary embeddings extend to longer sequences without fine-tuning; it has
    some connection to NTK literature:
    https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    """
    theta *= theta_rescale_factor ** (dim / (dim - 2))

    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta**exponents)

    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore

    real_part = torch.cos(angles)
    imag_part = torch.sin(angles)
    return torch.cat([real_part, imag_part], dim=-1)
def get_text_embed_dict(ckpt_path, use_ema=True):
    """Extract the text-embedding weights from a checkpoint.

    Args:
        ckpt_path: path to a ``.safetensors`` file or a ``torch.save`` checkpoint.
        use_ema: read the EMA weights (``ema_model_state_dict`` for torch
            checkpoints; a safetensors file is treated as the EMA state dict
            itself) instead of ``model_state_dict``.

    Returns:
        A dict of every parameter whose name contains ``text_embed``, with the
        ``transformer.text_embed.`` prefix removed from the key.
    """
    is_safetensors = ckpt_path.rsplit(".", 1)[-1] == "safetensors"
    if is_safetensors:
        from safetensors.torch import load_file

        loaded = load_file(ckpt_path)
    else:
        loaded = torch.load(ckpt_path, map_location="cpu", weights_only=True)

    if not use_ema:
        model_params = loaded if is_safetensors else loaded["model_state_dict"]
    else:
        ema_state = loaded if is_safetensors else loaded["ema_model_state_dict"]
        # Strip the EMA wrapper prefix and drop its bookkeeping entries.
        model_params = {
            key.replace("ema_model.", ""): value
            for key, value in ema_state.items()
            if key not in ["initted", "step"]
        }

    # transformer.text_embed.text_embed.weight -> text_embed.weight
    return {
        key.replace("transformer.text_embed.", ""): model_params[key]
        for key in model_params.keys()
        if "text_embed" in key
    }
class F5TTS(object):
def __init__(
    self,
    config,
    debug_mode=True,
    stream: Optional[torch.cuda.Stream] = None,
    tllm_model_dir: Optional[str] = None,
    model_path: Optional[str] = None,
    vocab_size: Optional[int] = None,
):
    """Load the TensorRT engine and precompute the tables used during sampling.

    Args:
        config: parsed engine ``config.json``; its ``pretrained_config`` entry is read.
        debug_mode: when False, the engine's I/O tensor names must match the
            expected set exactly, otherwise a RuntimeError is raised.
        stream: CUDA stream to bind to; a new one is created when None.
        tllm_model_dir: directory containing ``rank<N>.engine``.
        model_path: checkpoint from which the text-embedding weights are loaded.
        vocab_size: number of text tokens for the text-embedding table.

    Raises:
        RuntimeError: engine tensor names differ from the expected ones (non-debug mode).
    """
    self.dtype = config["pretrained_config"]["dtype"]
    rank = tensorrt_llm.mpi_rank()
    world_size = config["pretrained_config"]["mapping"]["world_size"]
    cp_size = config["pretrained_config"]["mapping"]["cp_size"]
    tp_size = config["pretrained_config"]["mapping"]["tp_size"]
    pp_size = config["pretrained_config"]["mapping"]["pp_size"]
    # Pipeline parallelism is not supported.
    assert pp_size == 1
    self.mapping = tensorrt_llm.Mapping(
        world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1
    )
    # NOTE(review): with gpus_per_node=1 this presumably always yields 0 - confirm.
    local_rank = rank % self.mapping.gpus_per_node
    self.device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(self.device)
    # Bind a dedicated CUDA stream (or the caller's) and make it current.
    self.stream = stream
    if self.stream is None:
        self.stream = torch.cuda.Stream(self.device)
    torch.cuda.set_stream(self.stream)
    engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine")
    logger.info(f"Loading engine from {engine_file}")
    with open(engine_file, "rb") as f:
        engine_buffer = f.read()
    assert engine_buffer is not None
    self.session = Session.from_serialized_engine(engine_buffer)
    self.debug_mode = debug_mode
    self.inputs = {}
    self.outputs = {}
    self.buffer_allocated = False
    # Validate the engine's I/O tensor names against what this wrapper feeds and reads.
    expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"]
    found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)]
    if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names):
        logger.error(
            f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
        )
        logger.error(
            f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
        )
        logger.error(f"Expected tensor names: {expected_tensor_names}")
        logger.error(f"Found tensor names: {found_tensor_names}")
        raise RuntimeError("Tensor names in engine are not the same as expected.")
    if self.debug_mode:
        # Any extra engine tensors are treated as debug outputs.
        self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names))
    self.max_mel_len = 4096
    # The text embedding runs in PyTorch; its weights come from `model_path`.
    self.text_embedding = TextEmbedding(
        text_num_embeds=vocab_size,
        text_dim=config["pretrained_config"]["text_dim"],
        mask_padding=config["pretrained_config"]["text_mask_padding"],
        conv_layers=config["pretrained_config"]["conv_layers"],
        precompute_max_pos=self.max_mel_len,
    ).to(self.device)
    self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
    self.n_mel_channels = config["pretrained_config"]["mel_dim"]
    self.head_dim = config["pretrained_config"]["dim_head"]
    # Rotary tables for positions 0..max_mel_len-1; each frequency is repeated twice
    # along the last dim and the cos/sin tables are stored in half precision.
    self.base_rescale_factor = 1.0
    self.interpolation_factor = 1.0
    base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
    freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor
    self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
    self.rope_cos = self.freqs.cos().half()
    self.rope_sin = self.freqs.sin().half()
    self.nfe_steps = 32
    # Hand-picked step schedules (in units of 1/32) for reduced step counts; step
    # counts without an entry (including the default 32) use the uniform grid
    # 0..nfe_steps.
    epss = {
        5: [0, 2, 4, 8, 16, 32],
        6: [0, 2, 4, 6, 8, 16, 32],
        7: [0, 2, 4, 6, 8, 16, 24, 32],
        10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
        12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
    }
    t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
    # Warp the grid with 1 - cos(pi * t / 2); delta_t holds the gaps between warped steps.
    time_step = 1 - torch.cos(torch.pi * t / 2)
    delta_t = torch.diff(time_step)
    freq_embed_dim = 256  # Warning: hard coding 256 here
    # Precompute a sinusoidal embedding ([sin | cos]) of every time step.
    time_expand = torch.zeros((1, self.nfe_steps, freq_embed_dim), dtype=torch.float32)
    half_dim = freq_embed_dim // 2
    emb_factor = math.log(10000) / (half_dim - 1)
    emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
    for i in range(self.nfe_steps):
        emb = time_step[i] * emb_factor
        time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
    self.time_expand = time_expand.to(self.device)
    # delta_t is stored twice back to back; presumably one copy per half of the
    # doubled (guidance) batch - confirm against the caller.
    self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device)
def _tensor_dtype(self, name):
    """Look up the engine dtype of tensor ``name`` and convert it with ``trt_dtype_to_torch``."""
    engine = self.session.engine
    return trt_dtype_to_torch(engine.get_tensor_dtype(name))
def _setup(self, batch_size, seq_len):
    """Allocate a fresh output buffer for every output tensor of the engine.

    The shape reported by the engine is used as a template; its first two
    entries are overwritten with ``batch_size`` and ``seq_len``. Buffers are
    stored in ``self.outputs`` keyed by tensor name, and
    ``self.buffer_allocated`` is set once all outputs have been handled.
    """
    engine = self.session.engine
    for tensor_idx in range(engine.num_io_tensors):
        tensor_name = engine.get_tensor_name(tensor_idx)
        # Inputs are bound per call in forward(); only outputs need storage here.
        if engine.get_tensor_mode(tensor_name) != trt.TensorIOMode.OUTPUT:
            continue
        out_shape = list(engine.get_tensor_shape(tensor_name))
        out_shape[0] = batch_size
        out_shape[1] = seq_len
        self.outputs[tensor_name] = torch.empty(
            out_shape,
            dtype=self._tensor_dtype(tensor_name),
            device=self.device,
        )
    self.buffer_allocated = True
def cuda_stream_guard(func):
    """Sync external stream and set current stream to the one bound to the session. Reset on exit.

    The wrapped method runs with ``self.stream`` as the current CUDA stream.
    When the caller's stream differs from ``self.stream``, the caller's stream
    is synchronized before the switch and ``self.stream`` is synchronized
    after a successful call. The caller's stream is restored in a ``finally``
    clause, so an exception raised by ``func`` no longer leaves the session
    stream installed as the current stream.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        external_stream = torch.cuda.current_stream()
        switched = external_stream != self.stream
        if switched:
            external_stream.synchronize()
            torch.cuda.set_stream(self.stream)
        try:
            ret = func(self, *args, **kwargs)
            if switched:
                # Make the results visible to the caller's stream before handing back.
                self.stream.synchronize()
        finally:
            if switched:
                torch.cuda.set_stream(external_stream)
        return ret

    return wrapper
@cuda_stream_guard
def forward(
    self,
    noise: torch.Tensor,
    cond: torch.Tensor,
    time_expand: torch.Tensor,
    rope_cos: torch.Tensor,
    rope_sin: torch.Tensor,
    input_lengths: torch.Tensor,
    delta_t: torch.Tensor,
    use_perf: bool = False,
):
    """Run ``self.nfe_steps`` engine calls and integrate the guided prediction.

    Only the first half of ``noise`` along dim 0 is used as the running
    state; on every step it is duplicated to rebuild the full batch fed to
    the engine. The first half of the engine output is treated as the
    conditional prediction and the second half as the unconditional one;
    they are combined with a fixed guidance strength of 2.0 and scaled by
    ``delta_t[step]`` before being added to the state.

    Args:
        noise: Initial state; dim 0 is split in half as described above.
        cond, rope_cos, rope_sin: Engine inputs, cast to the engine dtype.
        time_expand: Per-step time embeddings, indexed as ``time_expand[:, step]``.
        input_lengths: Engine input, cast to int32.
        delta_t: Per-step scale factors, indexed as ``delta_t[step]``.
        use_perf: When True, wrap the loop and each engine call in NVTX ranges.

    Returns:
        The updated state (same leading size as half of ``noise``).

    Raises:
        RuntimeError: If output buffers are not flagged as allocated.
        AssertionError: If the session reports a failed execution.
    """
    if use_perf:
        torch.cuda.nvtx.range_push("flow matching")

    cfg_strength = 2.0
    total_batch = noise.shape[0]
    cond_count = total_batch // 2
    state = noise[:cond_count]  # running estimate, updated every step

    engine_dtype = str_dtype_to_torch(self.dtype)
    cond = cond.to(engine_dtype)
    rope_cos = rope_cos.to(engine_dtype)
    rope_sin = rope_sin.to(engine_dtype)
    input_lengths = input_lengths.to(str_dtype_to_torch("int32"))

    for step in range(self.nfe_steps):
        # Output buffers are re-created on every step so each engine call writes into fresh storage.
        self._setup(total_batch, noise.shape[1])
        if not self.buffer_allocated:
            raise RuntimeError("Buffer not allocated, please call setup first!")

        step_inputs = dict(
            noise=torch.cat([state, state], dim=0).to(engine_dtype),
            cond=cond,
            time=time_expand[:, step].to(engine_dtype),
            rope_cos=rope_cos,
            rope_sin=rope_sin,
            input_lengths=input_lengths,
        )
        self.inputs.clear()
        self.inputs.update(**step_inputs)
        self.session.set_shapes(self.inputs)

        if use_perf:
            torch.cuda.nvtx.range_push(f"execute {step}")
        ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream)
        assert ok, "Failed to execute model"
        if use_perf:
            torch.cuda.nvtx.range_pop()

        step_scale = delta_t[step].unsqueeze(0).to(engine_dtype)
        denoised = self.outputs["denoised"]
        pred_cond = denoised[:cond_count]
        pred_uncond = denoised[cond_count:]
        # Classifier-free guidance followed by one integration step.
        guided = pred_cond + (pred_cond - pred_uncond) * cfg_strength
        state = state + guided * step_scale

    if use_perf:
        torch.cuda.nvtx.range_pop()
    return state
def sample(
    self,
    text_pad_sequence: torch.Tensor,
    cond_pad_sequence: torch.Tensor,
    ref_mel_len_batch: torch.Tensor,
    estimated_reference_target_mel_len: List[int],
    remove_input_padding: bool = False,
    use_perf: bool = False,
):
    """Build the engine inputs for a batch and run :meth:`forward` on them.

    For every item the text is embedded twice (with and without
    ``drop_text``). The two variants are stacked along the batch dimension
    together with duplicated noise, rotary tables and lengths, so the
    engine sees ``2 * batch`` rows.

    Args:
        text_pad_sequence: Token indices, one row per item.
        cond_pad_sequence: Conditioning features; dim 1 is taken as the
            maximum sequence length.
        ref_mel_len_batch: Accepted for interface compatibility; not read here.
        estimated_reference_target_mel_len: Per-item length passed to the
            text embedding and used as ``input_lengths``.
        remove_input_padding: When True, strip padding from the inputs with
            ``remove_tensor_padding`` and return a list of per-item tensors.
        use_perf: When True, emit NVTX ranges around the stages.

    Returns:
        ``(denoised, cost_time)`` where ``cost_time`` is the wall-clock time
        of the :meth:`forward` call; ``denoised`` is a list when
        ``remove_input_padding`` is True, otherwise the tensor returned by
        :meth:`forward`.
    """
    if use_perf:
        torch.cuda.nvtx.range_push("text embedding")

    batch = text_pad_sequence.shape[0]
    max_seq_len = cond_pad_sequence.shape[1]

    # Embed item by item to avoid misalignment; the list is laid out as
    # [kept_0, dropped_0, kept_1, dropped_1, ...].
    interleaved_embeddings = []
    for item in range(batch):
        for drop in (False, True):
            embedded = self.text_embedding(
                text_pad_sequence[item].unsqueeze(0).to(self.device),
                estimated_reference_target_mel_len[item],
                drop_text=drop,
            )
            interleaved_embeddings.append(embedded[0])

    # Pad the separately computed embeddings into one batch, then de-interleave.
    padded_embeddings = pad_sequence(interleaved_embeddings, batch_first=True, padding_value=0)
    text_embedding = padded_embeddings[0::2]
    text_embedding_drop = padded_embeddings[1::2]

    noise = torch.randn_like(cond_pad_sequence).to(self.device)
    rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
    rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)

    cat_mel_text = torch.cat((cond_pad_sequence, text_embedding), dim=-1)
    zero_mel = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
    cat_mel_text_drop = torch.cat((zero_mel, text_embedding_drop), dim=-1)

    time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()
    input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32)

    def stack_twice(tensor):
        # Same tensor for both halves of the doubled batch.
        return torch.cat((tensor, tensor), dim=0).contiguous()

    inputs = {
        "noise": stack_twice(noise),
        "cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(),
        "time_expand": time_expand,
        "rope_cos": stack_twice(rope_cos),
        "rope_sin": stack_twice(rope_sin),
        "input_lengths": stack_twice(input_lengths),
        "delta_t": self.delta_t,
    }

    if remove_input_padding:
        if use_perf:
            torch.cuda.nvtx.range_push("remove input padding")
        max_seq_len = inputs["cond"].shape[1]
        inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"])
        inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"])
        # for time_expand, convert from B,D to B,T,D by repeat
        inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
        inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"])
        inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"])
        inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"])
        if use_perf:
            torch.cuda.nvtx.range_pop()

    inputs = {key: value.to(self.device) for key, value in inputs.items()}

    if use_perf:
        torch.cuda.nvtx.range_pop()

    start_time = time.time()
    denoised = self.forward(**inputs, use_perf=use_perf)
    cost_time = time.time() - start_time

    if not remove_input_padding:
        return denoised, cost_time

    if use_perf:
        torch.cuda.nvtx.range_push("remove input padding output")
    # Cut the packed output back into per-item pieces.
    denoised_list = []
    start_idx = 0
    for item in range(batch):
        item_len = inputs["input_lengths"][item]
        denoised_list.append(denoised[start_idx : start_idx + item_len])
        start_idx += item_len
    if use_perf:
        torch.cuda.nvtx.range_pop()
    return denoised_list, cost_time
================================================
FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py
================================================
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import rjieba
import torch
import torchaudio
import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
from torch.utils.dlpack import from_dlpack, to_dlpack
def get_tokenizer(vocab_file_path: str):
    """Load a one-token-per-line vocabulary file.

    Each line of the file is one token and the zero-based line number is its
    index. Only the line terminator is removed, so tokens made of whitespace
    (for example a single space) are preserved, and the final token is kept
    intact even when the file does not end with a newline.

    Args:
        vocab_file_path: Path to the UTF-8 encoded vocabulary ``.txt`` file.

    Returns:
        ``(vocab_char_map, vocab_size)`` where ``vocab_char_map`` maps token
        to index and ``vocab_size`` is the number of distinct tokens.
    """
    vocab_char_map = {}
    with open(vocab_file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            # Drop only the trailing "\n"; slicing off the last character
            # unconditionally would truncate a final line without a newline.
            token = line[:-1] if line.endswith("\n") else line
            vocab_char_map[token] = i
    vocab_size = len(vocab_char_map)
    return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
    """Turn each text into a flat token list, replacing Chinese characters with pinyin.

    Every text is first normalised through a small punctuation translation
    table, segmented with ``rjieba.cut`` and then handled per segment:

    * pure single-byte segments are emitted character by character, with a
      separating space inserted before multi-character segments;
    * segments made only of 3-byte characters are converted as a whole with
      ``lazy_pinyin`` (only when ``polyphone`` is true);
    * anything else is processed character by character.

    Args:
        reference_target_texts_list: Iterable of input strings.
        polyphone: Enables whole-segment pinyin conversion.

    Returns:
        One token list per input text, in input order.
    """
    # add custom trans here, to address oov
    custom_trans = str.maketrans({";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"})

    def is_chinese(c):
        # common chinese characters
        return "\u3100" <= c <= "\u9fff"

    def tokenize(text):
        tokens = []
        for seg in rjieba.cut(text.translate(custom_trans)):
            byte_len = len(bytes(seg, "UTF-8"))
            if byte_len == len(seg):
                # Pure alphabets and symbols.
                if tokens and byte_len > 1 and tokens[-1] not in " :'\"":
                    tokens.append(" ")
                tokens.extend(seg)
            elif polyphone and byte_len == 3 * len(seg):
                # Pure east asian characters: convert the whole segment at once.
                syllables = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for pos, ch in enumerate(seg):
                    if is_chinese(ch):
                        tokens.append(" ")
                    tokens.append(syllables[pos])
            else:
                # Mixed characters, alphabets and symbols.
                for ch in seg:
                    if ord(ch) < 256:
                        tokens.extend(ch)
                    elif is_chinese(ch):
                        tokens.append(" ")
                        tokens.extend(lazy_pinyin(ch, style=Style.TONE3, tone_sandhi=True))
                    else:
                        tokens.append(ch)
        return tokens

    return [tokenize(text) for text in reference_target_texts_list]
def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value=-1,
): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
class TritonPythonModel:
    """Triton Python-backend model that runs F5-TTS synthesis.

    Each request supplies a reference waveform with its length, the
    reference transcript and the target text. The model computes a mel
    spectrogram of the reference, runs the TensorRT-LLM F5-TTS sampler and
    sends the generated mel frames to the separate ``vocoder`` Triton model.
    """

    def initialize(self, args):
        """Read the model parameters and build tokenizer, sampler and mel front-end.

        Args:
            args: Dict provided by Triton; ``args["model_config"]`` is the
                JSON model configuration. Its ``parameters`` section must
                provide ``vocab_file``, ``reference_audio_sample_rate``,
                ``tllm_model_dir``, ``model_path`` and ``vocoder``.

        Raises:
            AssertionError: If ``vocoder`` is neither ``"vocos"`` nor ``"bigvgan"``.
            NotImplementedError: If ``vocoder`` is ``"bigvgan"`` but no
                ``get_bigvgan_mel_spectrogram`` method is available.
        """
        self.use_perf = True
        self.device = torch.device("cuda")
        self.target_audio_sample_rate = 24000
        self.target_rms = 0.1  # least rms when inference, normalize to if lower
        self.n_fft = 1024
        self.win_length = 1024
        self.hop_length = 256
        self.n_mel_channels = 100
        self.max_mel_len = 4096

        # Triton wraps every parameter as {"string_value": ...}; flatten to plain strings.
        parameters = json.loads(args["model_config"])["parameters"]
        for key, value in parameters.items():
            parameters[key] = value["string_value"]

        self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
        self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
        self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate)

        self.tllm_model_dir = parameters["tllm_model_dir"]
        config_file = os.path.join(self.tllm_model_dir, "config.json")
        with open(config_file) as f:
            config = json.load(f)
        self.model = F5TTS(
            config,
            debug_mode=False,
            tllm_model_dir=self.tllm_model_dir,
            model_path=parameters["model_path"],
            vocab_size=self.vocab_size,
        )

        self.vocoder = parameters["vocoder"]
        assert self.vocoder in ["vocos", "bigvgan"]
        if self.vocoder == "vocos":
            self.mel_stft = torchaudio.transforms.MelSpectrogram(
                sample_rate=self.target_audio_sample_rate,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                n_mels=self.n_mel_channels,
                power=1,
                center=True,
                normalized=False,
                norm=None,
            ).to(self.device)
            self.compute_mel_fn = self.get_vocos_mel_spectrogram
        elif self.vocoder == "bigvgan":
            # This class defines no bigvgan front-end; fail with a clear
            # message instead of an AttributeError unless a subclass adds one.
            bigvgan_mel_fn = getattr(self, "get_bigvgan_mel_spectrogram", None)
            if bigvgan_mel_fn is None:
                raise NotImplementedError(
                    "vocoder 'bigvgan' was requested but get_bigvgan_mel_spectrogram is not implemented"
                )
            self.compute_mel_fn = bigvgan_mel_fn

    def get_vocos_mel_spectrogram(self, waveform):
        """Return the log-magnitude mel spectrogram of ``waveform`` with dims 1 and 2 swapped."""
        mel = self.mel_stft(waveform)
        mel = mel.clamp(min=1e-5).log()  # floor before log to avoid -inf
        return mel.transpose(1, 2)

    def forward_vocoder(self, mel):
        """Send ``mel`` to the ``vocoder`` Triton model and return its ``waveform`` output on CPU.

        Raises:
            pb_utils.TritonModelException: If the vocoder request reports an error.
        """
        mel = mel.to(torch.float32).contiguous().cpu()
        input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))

        inference_request = pb_utils.InferenceRequest(
            model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
        return from_dlpack(waveform.to_dlpack()).cpu()

    def execute(self, requests):
        """Synthesize one waveform per request.

        Args:
            requests: Triton inference requests, each carrying the inputs
                ``reference_wav``, ``reference_wav_len``, ``reference_text``
                and ``target_text``.

        Returns:
            A list with one ``pb_utils.InferenceResponse`` per request, each
            holding a ``waveform`` output tensor.
        """
        reference_target_texts_list = []
        estimated_reference_target_mel_len = []
        reference_mel_len = []
        reference_rms_list = []
        mel_features_list = []

        if self.use_perf:
            torch.cuda.nvtx.range_push("preprocess")
        for request in requests:
            wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
            wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")

            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode("utf-8")
            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
            target_text = target_text[0][0].decode("utf-8")
            reference_target_texts_list.append(reference_text + target_text)

            wav = from_dlpack(wav_tensor.to_dlpack())
            wav_len = from_dlpack(wav_lens.to_dlpack())
            wav_len = wav_len.squeeze()
            assert wav.shape[0] == 1, "Only support batch size 1 for now."
            wav = wav[:, :wav_len]

            ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
            # Boost quiet references up to target_rms. A silent reference
            # (rms == 0) is left untouched: scaling it would divide by zero
            # and fill the waveform with NaNs.
            if 0 < ref_rms < self.target_rms:
                wav = wav * self.target_rms / ref_rms
            reference_rms_list.append(ref_rms)
            if self.reference_sample_rate != self.target_audio_sample_rate:
                wav = self.resampler(wav)
            wav = wav.to(self.device)

            if self.use_perf:
                torch.cuda.nvtx.range_push("compute_mel")
            mel_features = self.compute_mel_fn(wav)
            if self.use_perf:
                torch.cuda.nvtx.range_pop()
            mel_features_list.append(mel_features)

            reference_mel_len.append(mel_features.shape[1])
            # Total length = reference frames scaled by the target/reference
            # text byte ratio. The per-item value is capped at max_mel_len
            # so it can never exceed the padded batch width derived from it.
            estimated_len = int(
                mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8")))
            )
            estimated_reference_target_mel_len.append(min(estimated_len, self.max_mel_len))

        max_seq_len = max(estimated_reference_target_mel_len)

        batch = len(requests)
        mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
        for i, mel in enumerate(mel_features_list):
            mel_features[i, : mel.shape[1], :] = mel

        reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)

        pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
        text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)

        if self.use_perf:
            torch.cuda.nvtx.range_pop()

        denoised, cost_time = self.model.sample(
            text_pad_sequence,
            mel_features,
            reference_mel_len_tensor,
            estimated_reference_target_mel_len,
            remove_input_padding=False,
            use_perf=self.use_perf,
        )

        if self.use_perf:
            torch.cuda.nvtx.range_push("vocoder")

        responses = []
        for i in range(batch):
            ref_mel_len = reference_mel_len[i]
            estimated_mel_len = estimated_reference_target_mel_len[i]
            # Keep only the frames generated after the reference prompt.
            denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
            audio = self.forward_vocoder(denoised_one_item)
            # Undo the loudness boost, using the same condition as during preprocessing.
            if 0 < reference_rms_list[i] < self.target_rms:
                audio = audio * reference_rms_list[i] / self.target_rms

            audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
            inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
            responses.append(inference_response)

        if self.use_perf:
            torch.cuda.nvtx.range_pop()

        return responses
================================================
FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt
================================================
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Triton model configuration for the F5-TTS Python-backend model.
name: "f5_tts"
backend: "python"
# Requests carry a leading batch dimension that is not listed in `dims` below.
max_batch_size: 4
dynamic_batching {
max_queue_delay_microseconds: 1000
}
# String parameters read by TritonPythonModel.initialize() in 1/model.py.
# The ${...} values are placeholders; presumably they are substituted by a
# template-filling step before the server is started (confirm in the launch script).
parameters [
{
# Path to the one-token-per-line vocabulary file.
key: "vocab_file"
value: { string_value: "${vocab}"}
},
{
# Passed to the F5TTS runtime wrapper as `model_path`.
key: "model_path",
value: {string_value:"${model}"}
},
{
# Directory that must contain the engine's config.json.
key: "tllm_model_dir",
value: {string_value:"${trtllm}"}
},
{
# Sample rate of incoming reference audio; audio is resampled when it differs from 24000.
key: "reference_audio_sample_rate",
value: {string_value:"24000"}
},
{
# Must be "vocos" or "bigvgan" (asserted in model.py).
key: "vocoder",
value: {string_value:"${vocoder}"}
}
]
# NOTE(review): reference_wav and reference_wav_len are declared optional, but
# execute() in model.py reads both unconditionally; confirm whether requests
# without them are meant to be supported.
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: True
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: True
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_GPU
}
]
================================================
FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep
================================================
================================================
FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt
================================================
# Triton model configuration for the TensorRT vocoder engine.
# The f5_tts Python model calls this model by name ("vocoder"), sending the
# "mel" input and requesting the "waveform" output.
name: "vocoder"
backend: "tensorrt"
default_model_filename: "vocoder.plan"
max_batch_size: 4
input [
{
name: "mel"
data_type: TYPE_FP32
# 100 mel channels by a variable number of frames; the batch dimension is implicit.
dims: [ 100, -1 ]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
dynamic_batching {
preferred_batch_size: [1, 2, 4]
max_queue_delay_microseconds: 1
}
instance_group [
{
count: 1
kind: KIND_GPU
}
]
================================================
FILE: src/f5_tts/runtime/triton_trtllm/patch/__init__.py
================================================
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .baichuan.model import BaichuanForCausalLM
from .bert.model import (
BertForQuestionAnswering,
BertForSequenceClassification,
BertModel,
RobertaForQuestionAnswering,
RobertaForSequenceClassification,
RobertaModel,
)
from .bloom.model import BloomForCausalLM, BloomModel
from .chatglm.config import ChatGLMConfig
from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
from .cogvlm.config import CogVLMConfig
from .cogvlm.model import CogVLMForCausalLM
from .commandr.model import CohereForCausalLM
from .dbrx.config import DbrxConfig
from .dbrx.model import DbrxForCausalLM
from .deepseek_v1.model import DeepseekForCausalLM
from .deepseek_v2.model import DeepseekV2ForCausalLM
from .dit.model import DiT
from .eagle.model import EagleForCausalLM
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .f5tts.model import F5TTS
from .falcon.config import FalconConfig
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
from .gemma.model import GemmaForCausalLM
from .gpt.config import GPTConfig
from .gpt.model import GPTForCausalLM, GPTModel
from .gptj.config import GPTJConfig
from .gptj.model import GPTJForCausalLM, GPTJModel
from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel
from .grok.model import GrokForCausalLM
from .llama.config import LLaMAConfig
from .llama.model import LLaMAForCausalLM, LLaMAModel
from .mamba.model import MambaForCausalLM
from .medusa.config import MedusaConfig
from .medusa.model import MedusaForCausalLm
from .mllama.model import MLLaMAModel
from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode
from .mpt.model import MPTForCausalLM, MPTModel
from .nemotron_nas.model import DeciLMForCausalLM
from .opt.model import OPTForCausalLM, OPTModel
from .phi.model import PhiForCausalLM, PhiModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .qwen.model import QWenForCausalLM
from .recurrentgemma.model import RecurrentGemmaForCausalLM
from .redrafter.model import ReDrafterForCausalLM
# Public names exported by `from <package> import *`.
# Every entry must be bound in this module: a name listed here but never
# imported makes the star-import raise AttributeError. Entries that had no
# matching import in this file (PhiConfig, Phi3Config, QWenModel, MambaConfig,
# SkyworkForCausalLM) were dropped, and the fused "QWenConfigQWenForCausalLM"
# string was replaced by the imported QWenForCausalLM.
__all__ = [
    "BertModel",
    "BertForQuestionAnswering",
    "BertForSequenceClassification",
    "RobertaModel",
    "RobertaForQuestionAnswering",
    "RobertaForSequenceClassification",
    "BloomModel",
    "BloomForCausalLM",
    "DiT",
    "DeepseekForCausalLM",
    "FalconConfig",
    "DeepseekV2ForCausalLM",
    "FalconForCausalLM",
    "FalconModel",
    "GPTConfig",
    "GPTModel",
    "GPTForCausalLM",
    "OPTForCausalLM",
    "OPTModel",
    "LLaMAConfig",
    "LLaMAForCausalLM",
    "LLaMAModel",
    "MedusaConfig",
    "MedusaForCausalLm",
    "ReDrafterForCausalLM",
    "GPTJConfig",
    "GPTJModel",
    "GPTJForCausalLM",
    "GPTNeoXModel",
    "GPTNeoXForCausalLM",
    "PhiModel",
    "Phi3Model",
    "PhiForCausalLM",
    "Phi3ForCausalLM",
    "ChatGLMConfig",
    "ChatGLMForCausalLM",
    "ChatGLMModel",
    "BaichuanForCausalLM",
    "QWenForCausalLM",
    "EncoderModel",
    "DecoderModel",
    "PretrainedConfig",
    "PretrainedModel",
    "WhisperEncoder",
    "MambaForCausalLM",
    "MPTForCausalLM",
    "MPTModel",
    "GemmaConfig",
    "GemmaForCausalLM",
    "DbrxConfig",
    "DbrxForCausalLM",
    "RecurrentGemmaForCausalLM",
    "CogVLMConfig",
    "CogVLMForCausalLM",
    "EagleForCausalLM",
    "SpeculativeDecodingMode",
    "CohereForCausalLM",
    "MLLaMAModel",
    "F5TTS",
    # Imported above and used in MODEL_MAP, but previously not exported.
    "GrokForCausalLM",
    "DeciLMForCausalLM",
]
# Lookup table from an architecture-name string to the model class that
# implements it. Several architecture names may share one implementation
# (for example many GPT-style and LLaMA-style names).
# NOTE(review): presumably keyed by the `architecture` field of a checkpoint
# config; confirm against the code that reads MODEL_MAP.
MODEL_MAP = {
# GPT-style architectures
"GPT2LMHeadModel": GPTForCausalLM,
"GPT2LMHeadCustomModel": GPTForCausalLM,
"GPTBigCodeForCausalLM": GPTForCausalLM,
"Starcoder2ForCausalLM": GPTForCausalLM,
"FuyuForCausalLM": GPTForCausalLM,
"Kosmos2ForConditionalGeneration": GPTForCausalLM,
"JAISLMHeadModel": GPTForCausalLM,
"GPTForCausalLM": GPTForCausalLM,
"NemotronForCausalLM": GPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"BloomForCausalLM": BloomForCausalLM,
"RWForCausalLM": FalconForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
"PhiForCausalLM": PhiForCausalLM,
"Phi3ForCausalLM": Phi3ForCausalLM,
"Phi3VForCausalLM": Phi3ForCausalLM,
"Phi3SmallForCausalLM": Phi3ForCausalLM,
"PhiMoEForCausalLM": Phi3ForCausalLM,
"MambaForCausalLM": MambaForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"GPTJForCausalLM": GPTJForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"GLMModel": ChatGLMForCausalLM,
"ChatGLMModel": ChatGLMForCausalLM,
"ChatGLMForCausalLM": ChatGLMForCausalLM,
# LLaMA-style architectures (Grok is interleaved here with its own class)
"LlamaForCausalLM": LLaMAForCausalLM,
"ExaoneForCausalLM": LLaMAForCausalLM,
"MistralForCausalLM": LLaMAForCausalLM,
"MixtralForCausalLM": LLaMAForCausalLM,
"ArcticForCausalLM": LLaMAForCausalLM,
"Grok1ModelForCausalLM": GrokForCausalLM,
"InternLMForCausalLM": LLaMAForCausalLM,
"InternLM2ForCausalLM": LLaMAForCausalLM,
"MedusaForCausalLM": MedusaForCausalLm,
"ReDrafterForCausalLM": ReDrafterForCausalLM,
"BaichuanForCausalLM": BaichuanForCausalLM,
"BaiChuanForCausalLM": BaichuanForCausalLM,
"SkyworkForCausalLM": LLaMAForCausalLM,
# Gemma keys are the architecture constants imported from .gemma.config
GEMMA_ARCHITECTURE: GemmaForCausalLM,
GEMMA2_ARCHITECTURE: GemmaForCausalLM,
"QWenLMHeadModel": QWenForCausalLM,
"QWenForCausalLM": QWenForCausalLM,
"Qwen2ForCausalLM": QWenForCausalLM,
"Qwen2MoeForCausalLM": QWenForCausalLM,
"Qwen2ForSequenceClassification": QWenForCausalLM,
"Qwen2VLForConditionalGeneration": QWenForCausalLM,
"WhisperEncoder": WhisperEncoder,
"EncoderModel": EncoderModel,
"DecoderModel": DecoderModel,
"DbrxForCausalLM": DbrxForCausalLM,
"RecurrentGemmaForCausalLM": RecurrentGemmaForCausalLM,
"CogVLMForCausalLM": CogVLMForCausalLM,
"DiT": DiT,
"DeepseekForCausalLM": DeepseekForCausalLM,
"DeciLMForCausalLM": DeciLMForCausalLM,
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"EagleForCausalLM": EagleForCausalLM,
"CohereForCausalLM": CohereForCausalLM,
"MllamaForConditionalGeneration": MLLaMAModel,
"BertForQuestionAnswering": BertForQuestionAnswering,
"BertForSequenceClassification": BertForSequenceClassification,
"BertModel": BertModel,
"RobertaModel": RobertaModel,
"RobertaForQuestionAnswering": RobertaForQuestionAnswering,
"RobertaForSequenceClassification": RobertaForSequenceClassification,
# Entry for the F5-TTS model imported from .f5tts.model
"F5TTS": F5TTS,
}
================================================
FILE: src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
================================================
from __future__ import annotations
import os
import sys
from collections import OrderedDict
import numpy as np
import tensorrt as trt
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt
from ...functional import (
Tensor,
concat,
constant,
expand,
shape,
slice,
unsqueeze,
)
from ...layers import Linear
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
# Append the directory containing this file to sys.path (presumably so that
# modules next to it can be imported by bare name).
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.split(current_file_path)[0]
sys.path.extend([parent_dir])
class InputEmbedding(Module):
    """Project the concatenated inputs to ``out_dim`` and add a convolutional position embedding."""

    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        in_features = mel_dim * 2 + text_dim
        self.proj = Linear(in_features, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x, cond, mask=None):
        """Concatenate ``x`` and ``cond`` on the last dim, project, and add the position embedding residually."""
        projected = self.proj(concat([x, cond], dim=-1))
        positional = self.conv_pos_embed(projected, mask=mask)
        return positional + projected
class F5TTS(PretrainedModel):
    """TensorRT-LLM network definition of the F5-TTS transformer backbone.

    The network consists of a timestep embedding, an input embedding, a stack
    of ``DiTBlock`` layers, a final modulated norm and an output projection
    to ``mel_dim``.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.dtype = str_dtype_to_trt(config.dtype)

        self.time_embed = TimestepEmbedding(config.hidden_size)
        self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)

        self.dim = config.hidden_size
        self.depth = config.num_hidden_layers
        blocks = []
        for _ in range(self.depth):
            blocks.append(
                DiTBlock(
                    dim=self.dim,
                    heads=config.num_attention_heads,
                    dim_head=config.dim_head,
                    ff_mult=config.ff_mult,
                    dropout=config.dropout,
                    pe_attn_head=config.pe_attn_head,
                )
            )
        self.transformer_blocks = ModuleList(blocks)

        self.norm_out = AdaLayerNormZero_Final(config.hidden_size)  # final modulation
        self.proj_out = Linear(config.hidden_size, config.mel_dim)

    def forward(
        self,
        noise,  # noised input audio
        cond,  # masked cond audio
        time,  # time step
        rope_cos,
        rope_sin,
        input_lengths,
        scale=1.0,
    ):
        """Build the network graph and mark its result as the ``denoised`` output.

        When input padding is kept, an int32 mask of shape [B, N] is built
        that is 1 where the position index is below the row's
        ``input_lengths`` entry; with padding removed no mask is used.
        """
        if default_net().plugin_config.remove_input_padding:
            mask = None
        else:
            N = shape(noise, 1)
            B = shape(noise, 0)
            batch_by_len = concat([B, N])
            max_position_embeddings = 4096
            # Position ids 0..N-1, sliced from a constant buffer and broadcast to [B, N].
            all_positions = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
            positions = slice(all_positions, starts=[0, 0], sizes=concat([1, N]))
            positions = expand(positions, batch_by_len)  # [B, N]
            lengths = unsqueeze(input_lengths, 1)  # [B, 1]
            lengths = expand(lengths, concat([B, N]))  # [B, N]
            mask = positions < lengths  # [B, N]
            mask = mask.cast("int32")

        t = self.time_embed(time)
        x = self.input_embed(noise, cond, mask=mask)
        for block in self.transformer_blocks:
            x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask)
        denoise = self.proj_out(self.norm_out(x, t))
        denoise.mark_output("denoised", self.dtype)
        return denoise

    def prepare_inputs(self, **kwargs):
        """Declare the network input tensors and their dynamic-shape ranges.

        Args:
            **kwargs: Must contain ``max_batch_size``.

        Returns:
            Dict with the tensors ``noise``, ``cond``, ``time``, ``rope_cos``,
            ``rope_sin`` and ``input_lengths``, matching :meth:`forward`.
        """
        max_batch_size = kwargs["max_batch_size"]
        mel_size = self.config.mel_dim
        max_seq_len = 3000  # 4096
        batch_size_range = [2, 2, max_batch_size]
        num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size]
        concat_feature_dim = mel_size + self.config.text_dim
        freq_embed_dim = 256  # Warning: hard coding 256 here
        head_dim = self.config.dim_head

        mapping = self.config.mapping
        if mapping.tp_size > 1:
            current_all_reduce_helper().set_workspace_tensor(mapping, 1)

        packed = default_net().plugin_config.remove_input_padding

        def declare(name, feature_name, feature_size, per_frame=True):
            # Packed layout: one flattened frame axis plus the feature axis.
            # Padded layout: batch axis, optional duration axis, feature axis.
            if packed:
                leading_dims = [("num_frames", [num_frames_range])]
            else:
                leading_dims = [("batch_size", [batch_size_range])]
                if per_frame:
                    leading_dims.append(("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]))
            return Tensor(
                name=name,
                dtype=self.dtype,
                shape=[-1] * len(leading_dims) + [feature_size],
                dim_range=OrderedDict(leading_dims + [(feature_name, [feature_size])]),
            )

        noise = declare("noise", "n_mels", mel_size)
        cond = declare("cond", "embeded_length", concat_feature_dim)
        time = declare("time", "freq_dim", freq_embed_dim, per_frame=False)
        rope_cos = declare("rope_cos", "head_dim", head_dim)
        rope_sin = declare("rope_sin", "head_dim", head_dim)
        input_lengths = Tensor(
            name="input_lengths",
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([("batch_size", [batch_size_range])]),
        )

        return {
            "noise": noise,
            "cond": cond,
            "time": time,
            "rope_cos": rope_cos,
            "rope_sin": rope_sin,
            "input_lengths": input_lengths,
        }
================================================
FILE: src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
================================================
from __future__ import annotations
import math
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt, trt_dtype_to_np
from ...functional import (
Tensor,
bert_attention,
cast,
chunk,
concat,
constant,
expand_dims,
expand_dims_like,
expand_mask,
gelu,
matmul,
permute,
shape,
silu,
slice,
softmax,
squeeze,
unsqueeze,
view,
)
from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
from ...module import Module
class FeedForward(Module):
    """Position-wise MLP: ``Linear -> GELU -> Linear``.

    Args:
        dim: Input feature size.
        dim_out: Output feature size; ``dim`` is used when this is ``None``.
        mult: The hidden width is ``int(dim * mult)``.
        dropout: Accepted for signature compatibility; this block never reads it.
    """

    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
        super().__init__()
        hidden_width = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        self.project_in = Linear(dim, hidden_width)
        self.ff = Linear(hidden_width, dim_out)

    def forward(self, x):
        """Project up, apply GELU, project back down."""
        hidden = self.project_in(x)
        hidden = gelu(hidden)
        return self.ff(hidden)
class AdaLayerNormZero(Module):
    """Adaptive LayerNorm that derives six modulation tensors from an embedding.

    The embedding is passed through SiLU and a linear layer of width ``6 * dim``,
    then split along dim 1 into shift/scale/gate for attention and shift/scale/gate
    for the MLP. Only the attention shift/scale are applied here; the rest are
    returned to the caller.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = Linear(dim, dim * 6)
        self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        """Return ``(modulated_x, gate_msa, shift_mlp, scale_mlp, gate_mlp)``."""
        modulation = self.linear(silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(modulation, 6, dim=1)

        normed = self.norm(x)
        one = constant(np.ones(1, dtype=np.float32)).cast(normed.dtype)

        if not default_net().plugin_config.remove_input_padding:
            # Padded layout: insert an axis at dim 1 so the modulation broadcasts over it.
            scale_msa = unsqueeze(scale_msa, 1)
            shift_msa = unsqueeze(shift_msa, 1)

        modulated = normed * (one + scale_msa) + shift_msa
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZero_Final(Module):
    """Adaptive LayerNorm for the last layer: only a scale and a shift are produced.

    The embedding goes through SiLU and a linear layer of width ``2 * dim`` and is
    split along dim 1 into ``scale`` and ``shift``.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = Linear(dim, dim * 2)
        self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        """Return ``norm(x) * (1 + scale) + shift``."""
        scale, shift = chunk(self.linear(silu(emb)), 2, dim=1)
        one = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        gain = one + scale
        normed = self.norm(x)

        if default_net().plugin_config.remove_input_padding:
            return normed * gain + shift

        # Padded layout: insert an axis at dim 1 so scale/shift broadcast over it.
        modulated = normed * unsqueeze(gain, 1)
        return modulated + unsqueeze(shift, 1)
class ConvPositionEmbedding(Module):
    """Two stacked grouped ``Conv1d`` + ``Mish`` layers used as a positional embedding.

    Args:
        dim: Channel count, used for both the input and output of each convolution.
        kernel_size: Odd kernel width; padding is ``kernel_size // 2``.
        groups: Number of convolution groups.
    """

    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        half_kernel = kernel_size // 2
        self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=half_kernel)
        self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=half_kernel)
        self.mish = Mish()

    def forward(self, x, mask=None):
        """Run both conv layers, multiplying by ``mask`` around each one when it is given."""
        packed = default_net().plugin_config.remove_input_padding
        if packed:
            # Packed input has no leading axis; add one for the convolutions.
            x = unsqueeze(x, 0)

        use_mask = mask is not None
        if use_mask:
            mask = mask.view(concat([shape(mask, 0), 1, shape(mask, 1)]))  # [B 1 N]
            mask = expand_dims_like(mask, x)  # [B D N]
            mask = cast(mask, x.dtype)

        x = permute(x, [0, 2, 1])  # [B D N]
        if use_mask:
            hidden = self.mish(self.conv1d1(x * mask) * mask)
            x = self.mish(self.conv1d2(hidden) * mask)
        else:
            hidden = self.mish(self.conv1d1(x))
            x = self.mish(self.conv1d2(hidden))
        x = permute(x, [0, 2, 1])  # [B N D]

        if packed:
            x = squeeze(x, 0)
        return x
class Attention(Module):
    """Multi-head attention layer whose computation is delegated to ``processor``.

    Args:
        processor: Callable invoked as ``processor(attn, x, ...)`` from ``forward``.
        dim: Hidden size of the input and of the output projection.
        heads: Number of attention heads.
        dim_head: Size of each head.
        dropout: Stored on the instance; not used by this class itself.
        context_dim: When not ``None``, extra context projections are created
            (joint attention).
        context_pre_only: Controls creation of ``to_q_c`` / ``to_out_c``.

    Raises:
        ImportError: If ``torch.nn.functional`` lacks ``scaled_dot_product_attention``.
    """

    def __init__(
        self,
        processor: AttnProcessor,
        dim: int,
        heads: int = 16,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor
        self.dim = dim  # hidden_size
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout
        self.attention_head_size = dim_head
        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        # Tensor parallelism is fixed to a single shard here.
        self.tp_size = 1
        self.num_attention_heads = heads // self.tp_size
        self.num_attention_kv_heads = heads // self.tp_size
        self.dtype = str_dtype_to_trt("float32")
        self.attention_hidden_size = self.attention_head_size * self.num_attention_heads

        projected_width = self.tp_size * self.num_attention_heads * self.attention_head_size

        def _input_projection():
            # q, k and v share one configuration.
            return ColumnLinear(
                dim,
                projected_width,
                bias=True,
                dtype=self.dtype,
                tp_group=None,
                tp_size=self.tp_size,
            )

        self.to_q = _input_projection()
        self.to_k = _input_projection()
        self.to_v = _input_projection()

        if self.context_dim is not None:
            self.to_k_c = Linear(context_dim, self.inner_dim)
            self.to_v_c = Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = Linear(context_dim, self.inner_dim)

        self.to_out = RowLinear(
            projected_width,
            dim,
            bias=True,
            dtype=self.dtype,
            tp_group=None,
            tp_size=self.tp_size,
        )

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = Linear(self.inner_dim, dim)

    def forward(
        self,
        x,  # noised input x
        rope_cos,
        rope_sin,
        input_lengths,
        mask=None,
        c=None,  # context c
        scale=1.0,
        rope=None,
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        """Dispatch to the processor.

        With a context ``c`` the joint-attention call form is used; otherwise the
        self-attention form is used. ``mask`` is forwarded in the self-attention
        form; it used to be accepted here and then dropped.
        """
        if c is not None:
            # NOTE(review): the AttnProcessor in this file takes no `c` / `c_rope`
            # arguments, so this branch presumably needs a joint-attention processor.
            return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)

        return self.processor(
            self,
            x,
            rope_cos=rope_cos,
            rope_sin=rope_sin,
            input_lengths=input_lengths,
            scale=scale,
            mask=mask,
        )
def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
    """Map each adjacent pair ``(a, b)`` on the last axis to ``(-b, a)``.

    The tensor must be rank 2 when ``remove_input_padding`` is enabled and rank 3
    otherwise. The result has the same shape as the input.
    """
    rank = tensor.ndim()
    last_axis = rank - 1

    # Same shape as the input except that the last axis is halved.
    half_shape = concat(
        [shape(tensor, axis) / 2 if axis == last_axis else shape(tensor, axis) for axis in range(rank)]
    )

    if default_net().plugin_config.remove_input_padding:
        assert rank == 2
    else:
        assert rank == 3

    leading_starts = [0] * last_axis
    strides = [1] * last_axis + [2]

    # Elements at even and at odd positions of the last axis.
    even = slice(tensor, leading_starts + [0], half_shape, strides)
    odd = slice(tensor, leading_starts + [1], half_shape, strides)
    even = expand_dims(even, rank)
    odd = expand_dims(odd, rank)

    zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
    odd = zero - odd

    # Interleave (-odd, even) on a new trailing axis, then fold it into the last axis.
    pairs = concat([odd, even], rank)
    folded_shape = [shape(pairs, axis) for axis in range(last_axis)] + [shape(pairs, last_axis) * 2]
    out = view(pairs, concat(folded_shape))
    return out
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
    """Apply rotary position embedding to the first ``pe_attn_head`` heads of ``x``.

    Args:
        x: Projected query or key; rank 2 when ``remove_input_padding`` is enabled,
            rank 3 otherwise. Heads are laid out contiguously on the last axis.
        rope_cos: Cosine table; its last-axis size is taken as the head size.
        rope_sin: Sine table matching ``rope_cos``.
        pe_attn_head: Number of heads to rotate; ``None`` rotates every head.

    Returns:
        A tensor shaped like ``x``: the rotated heads followed by the untouched
        remaining channels.
    """
    full_dim = x.size(-1)
    head_dim = rope_cos.size(-1)  # attn head dim, e.g. 64
    if pe_attn_head is None:
        pe_attn_head = full_dim // head_dim
    rotated_dim = head_dim * pe_attn_head

    rotated_and_unrotated_list = []

    if default_net().plugin_config.remove_input_padding:  # for [N, D] input
        new_t_shape = concat([shape(x, 0), head_dim])
        for i in range(pe_attn_head):
            # The offset follows the real head size; it used to be hard-coded to 64.
            x_slice_i = slice(x, [0, i * head_dim], new_t_shape, [1, 1])
            x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
            rotated_and_unrotated_list.append(x_rotated_i)

        # Channels past the rotated heads (possibly none) pass through unchanged.
        new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim])
        x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1])
        rotated_and_unrotated_list.append(x_unrotated)

    else:  # for [B, N, D] input
        new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim])
        for i in range(pe_attn_head):
            x_slice_i = slice(x, [0, 0, i * head_dim], new_t_shape, [1, 1, 1])
            x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
            rotated_and_unrotated_list.append(x_rotated_i)

        new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), full_dim - rotated_dim])
        x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 1])
        rotated_and_unrotated_list.append(x_unrotated)

    out = concat(rotated_and_unrotated_list, dim=-1)
    return out
class AttnProcessor:
    """Self-attention computation for ``Attention``, with or without the BERT attention plugin."""

    def __init__(
        self,
        pe_attn_head: Optional[int] = None,  # number of attention head to apply rope, None for all
    ):
        self.pe_attn_head = pe_attn_head

    def __call__(
        self,
        attn,
        x,  # noised input x
        rope_cos,
        rope_sin,
        input_lengths,
        scale=1.0,
        rope=None,
        mask=None,
    ) -> torch.FloatTensor:
        """Project ``x`` to q/k/v, rotate q and k, attend, then apply the output projection.

        ``scale`` and ``rope`` are accepted but not read. ``mask`` is discarded when
        ``remove_input_padding`` is enabled.
        """
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head)
        key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head)

        inner_dim = key.shape[-1]
        q_scaling = 1.0 / math.sqrt(attn.attention_head_size)

        plugin_config = default_net().plugin_config
        packed = plugin_config.remove_input_padding
        if packed:
            mask = None

        if plugin_config.bert_attention_plugin:
            # TRT plugin mode: the plugin takes q, k, v fused on the last axis.
            qkv = concat([query, key, value], dim=-1)
            assert input_lengths is not None
            max_input_length = None
            if packed:
                qkv = qkv.view(concat([-1, 3 * inner_dim]))
                max_input_length = constant(np.zeros([2048], dtype=np.int32))
            context = bert_attention(
                qkv,
                input_lengths,
                attn.num_attention_heads,
                attn.attention_head_size,
                q_scaling=q_scaling,
                max_input_length=max_input_length,
            )
        else:
            assert not packed

            def split_heads(t):
                # View the last axis as (heads, head_size).
                return t.view(concat([shape(t, 0), shape(t, 1), attn.num_attention_heads, attn.attention_head_size]))

            def heads_first(t):
                return split_heads(t).transpose(1, 2)

            def heads_first_transposed(t):
                # Key layout that lets `matmul(query, key)` produce the score matrix.
                return split_heads(t).permute([0, 2, 3, 1])

            query = heads_first(query)
            key = heads_first_transposed(key)
            value = heads_first(value)

            attention_scores = matmul(query, key, use_fp32_acc=False)
            if mask is not None:
                attention_mask = expand_mask(mask, shape(query, 2))
                attention_mask = cast(attention_mask, attention_scores.dtype)
                attention_scores = attention_scores + attention_mask

            attention_probs = softmax(attention_scores, dim=-1)
            context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
            context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))

        context = attn.to_out(context)

        if mask is not None:
            mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
            mask = expand_dims_like(mask, context)
            mask = cast(mask, context.dtype)
            context = context * mask
        return context
# DiT Block
class DiTBlock(Module):
    """One DiT transformer block: adaptive-norm attention followed by a modulated feed-forward.

    Args:
        dim: Hidden size.
        heads: Number of attention heads.
        dim_head: Size of each attention head.
        ff_mult: Hidden-width multiplier of the feed-forward layer.
        dropout: Passed through to ``Attention`` and ``FeedForward``.
        pe_attn_head: Number of heads that receive rotary embedding (``None`` = all).
    """

    def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(pe_attn_head=pe_attn_head),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)

    def forward(
        self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=None, mask=None
    ):  # x: noised input, t: time embedding
        """Apply the block to ``x`` conditioned on the time embedding ``t``.

        ``rope`` is accepted but not read; its default is now ``None`` instead of
        the ``ModuleNotFoundError`` class that was left there by accident.
        """
        packed = default_net().plugin_config.remove_input_padding

        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # attention
        attn_output = self.attn(
            x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask
        )

        # process attention output for input x
        if packed:
            x = x + gate_msa * attn_output
        else:
            x = x + unsqueeze(gate_msa, 1) * attn_output

        # modulated norm for the feed-forward input
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if packed:
            norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
        else:
            norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)

        ff_output = self.ff(norm)
        if packed:
            x = x + gate_mlp * ff_output
        else:
            x = x + unsqueeze(gate_mlp, 1) * ff_output

        return x
class TimestepEmbedding(Module):
    """Two-layer MLP (``Linear -> SiLU -> Linear``) applied to the ``timestep`` input.

    Args:
        dim: Output width of both linear layers.
        freq_embed_dim: Input width of the first linear layer.
        dtype: dtype handed to both linear layers.
    """

    def __init__(self, dim, freq_embed_dim=256, dtype=None):
        super().__init__()
        # The sinusoidal embedding is not built here (the original keeps it commented out).
        self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
        self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)

    def forward(self, timestep):
        """Return ``mlp2(silu(mlp1(timestep)))``."""
        hidden = silu(self.mlp1(timestep))
        return self.mlp2(hidden)
================================================
FILE: src/f5_tts/runtime/triton_trtllm/run.sh
================================================
# Staged pipeline driver. Every stage N below runs when stage <= N <= stop_stage.
# Positional arguments: <stage> <stop_stage> [model]
stage=$1
stop_stage=$2
model=$3 # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
# Fall back to the v1 base model when the third argument is missing or empty.
if [ -z "$model" ]; then
model=F5TTS_v1_Base
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
# Expose only GPU 0 to the commands launched by this script.
export CUDA_VISIBLE_DEVICES=0
# Artifact locations; relative paths are resolved against the current working directory.
CKPT_DIR=../../../../ckpts
TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
# Working copy of the model repository (recreated in stage 3, served in stage 4).
MODEL_REPO=./model_repo
# Stage 0: download the checkpoint(s) and vocab of the selected model.
if [ "$stage" -le 0 ] && [ "$stop_stage" -ge 0 ]; then
    echo "Downloading F5-TTS from huggingface"
    # The patterns are quoted so the local shell never expands them against files
    # in the working directory; the downloader receives them verbatim.
    huggingface-cli download SWivid/F5-TTS "$model/model_*.*" "$model/vocab.txt" --local-dir "$CKPT_DIR"
fi

# Resolved on every run because later stages read them even when stage 0 is skipped.
ckpt_file=$(ls "$CKPT_DIR/$model"/model_*.* 2>/dev/null | sort -V | tail -1)  # default select latest update
vocab_file=$CKPT_DIR/$model/vocab.txt
# Stage 1: convert the PyTorch checkpoint, install the model patch, build the engine.
if [ "$stage" -le 1 ] && [ "$stop_stage" -ge 1 ]; then
    echo "Converting checkpoint"
    # Without a checkpoint the converter would be called with a missing argument.
    if [ -z "$ckpt_file" ]; then
        echo "No checkpoint found under $CKPT_DIR/$model (run stage 0 first)" >&2
        exit 1
    fi
    # Stop here on failure: building an engine from a failed conversion is pointless.
    python3 scripts/convert_checkpoint.py \
        --pytorch_ckpt "$ckpt_file" \
        --output_dir "$TRTLLM_CKPT_DIR" --model_name "$model" || exit 1
    # Copy the patch files into the installed tensorrt_llm package.
    python_package_path=/usr/local/lib/python3.12/dist-packages
    cp -r patch/* "$python_package_path/tensorrt_llm/models" || exit 1
    trtllm-build --checkpoint_dir "$TRTLLM_CKPT_DIR" \
        --max_batch_size 8 \
        --output_dir "$TRTLLM_ENGINE_DIR" --remove_input_padding disable
fi
# Stage 2: export the vocos vocoder to ONNX, then pass the ONNX path and the
# engine path to scripts/export_vocos_trt.sh.
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder"
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
fi
# Stage 3: recreate the model repository from the template and fill in its config.
if [ "$stage" -le 3 ] && [ "$stop_stage" -ge 3 ]; then
    echo "Building triton server"
    # -f: a missing repository (first run) is not an error.
    rm -rf "$MODEL_REPO"
    cp -r ./model_repo_f5_tts "$MODEL_REPO"
    python3 scripts/fill_template.py -i "$MODEL_REPO/f5_tts/config.pbtxt" "vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos"
    cp "$VOCODER_TRT_ENGINE_PATH" "$MODEL_REPO/vocoder/1/vocoder.plan"
fi
# Stage 4: launch tritonserver on the model repository.
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server"
tritonserver --model-repository=$MODEL_REPO
fi
# Stage 5: run the gRPC client on a dataset split.
if [ "$stage" -le 5 ] && [ "$stop_stage" -ge 5 ]; then
    echo "Testing triton server"
    num_task=1
    split_name=wenetspeech4tts
    log_dir=./tests/client_grpc_${model}_concurrent_${num_task}_${split_name}
    # -f: the log directory does not exist before the first run.
    rm -rf "$log_dir"
    python3 client_grpc.py --num-tasks "$num_task" --huggingface-dataset yuekai/seed_tts --split-name "$split_name" --log-dir "$log_dir"
fi
# Stage 6: run the HTTP client once with a fixed reference audio, reference text
# and target text; the result is written under ./tests.
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Testing http client"
audio=../../infer/examples/basic/basic_ref_en.wav
reference_text="Some call me nature, others call me mother nature."
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" --output-audio "./tests/client_http_$model.wav"
fi
# Stage 7: offline decoding benchmark with the TRT backend.
if [ "$stage" -le 7 ] && [ "$stop_stage" -ge 7 ]; then
    echo "TRT-LLM: offline decoding benchmark test"
    batch_size=2
    split_name=wenetspeech4tts
    backend_type=trt
    log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
    # -f: the log directory does not exist before the first run.
    rm -rf "$log_dir"
    torchrun --nproc_per_node=1 \
        benchmark.py --output-dir "$log_dir" \
        --batch-size "$batch_size" \
        --enable-warmup \
        --split-name "$split_name" \
        --model-path "$ckpt_file" \
        --vocab-file "$vocab_file" \
        --vocoder-trt-engine-path "$VOCODER_TRT_ENGINE_PATH" \
        --backend-type "$backend_type" \
        --tllm-model-dir "$TRTLLM_ENGINE_DIR" || exit 1
fi
# Stage 8: offline decoding benchmark with the native PyTorch backend.
if [ "$stage" -le 8 ] && [ "$stop_stage" -ge 8 ]; then
    echo "Native Pytorch: offline decoding benchmark test"
    # Install the package in editable mode when it cannot be imported yet.
    if ! python3 -c "import f5_tts" &> /dev/null; then
        pip install -e ../../../../
    fi
    batch_size=1  # set attn_mask_enabled=True if batching in actual use case
    split_name=wenetspeech4tts
    backend_type=pytorch
    log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
    # -f: the log directory does not exist before the first run.
    rm -rf "$log_dir"
    torchrun --nproc_per_node=1 \
        benchmark.py --output-dir "$log_dir" \
        --batch-size "$batch_size" \
        --split-name "$split_name" \
        --enable-warmup \
        --model-path "$ckpt_file" \
        --vocab-file "$vocab_file" \
        --backend-type "$backend_type" \
        --tllm-model-dir "$TRTLLM_ENGINE_DIR" || exit 1
fi
================================================
FILE: src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py
================================================
# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
# Copyright (c) 2020 Shimin Zhang
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window
support_clp_op = None
if th.__version__ >= "1.7.0":
from torch.fft import rfft as fft
support_clp_op = True
else:
from torch import rfft as fft
class STFT(th.nn.Module):
    """STFT / iSTFT implemented with 1D convolutions and transposed convolutions."""

    def __init__(
        self,
        win_len=1024,
        win_hop=512,
        fft_len=1024,
        enframe_mode="continue",
        win_type="hann",
        win_sqrt=False,
        pad_center=True,
    ):
        """
        Implement of STFT using 1D convolution and 1D transpose convolutions.
        Implement of framing the signal in 2 ways, `break` and `continue`.
        `break` method is a kaldi-like framing.
        `continue` method is a librosa-like framing.

        More information about `perfect reconstruction`:
        1. https://ww2.mathworks.cn/help/signal/ref/stft.html
        2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html

        Args:
            win_len (int): Number of points in one frame. Defaults to 1024.
            win_hop (int): Number of framing stride. Defaults to 512.
            fft_len (int): Number of DFT points. Defaults to 1024.
            enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'.
            win_type (str, optional): The type of window to create. Defaults to 'hann'.
            win_sqrt (bool, optional): using square root window. Defaults to False.
            pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True.
        """
        super().__init__()
        assert enframe_mode in ["break", "continue"]
        assert fft_len >= win_len
        self.win_len = win_len
        self.win_hop = win_hop
        self.fft_len = fft_len
        self.mode = enframe_mode
        self.win_type = win_type
        self.win_sqrt = win_sqrt
        self.pad_center = pad_center
        self.pad_amount = self.fft_len // 2

        en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
        self.register_buffer("en_k", en_k)
        self.register_buffer("fft_k", fft_k)
        self.register_buffer("ifft_k", ifft_k)
        self.register_buffer("ola_k", ola_k)

    def __init_kernel__(self):
        """
        Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
        ** enframe_kernel: Using conv1d layer and identity matrix.
        ** fft_kernel: Using linear layer for matrix multiplication. In fact,
        enframe_kernel and fft_kernel can be combined, But for the sake of
        readability, I took the two apart.
        ** ifft_kernel, pinv of fft_kernel.
        ** overlap-add kernel, just like enframe_kernel, but transposed.

        Also sets ``self.perfect_reconstruct`` and ``self.padded_window``.

        Returns:
            tuple: four kernels.
        """
        enframed_kernel = th.eye(self.fft_len)[:, None, :]
        if support_clp_op:
            # Complex rfft output: stack real/imag on a trailing axis.
            tmp = fft(th.eye(self.fft_len))
            fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
        else:
            fft_kernel = fft(th.eye(self.fft_len), 1)
        if self.mode == "break":
            enframed_kernel = th.eye(self.win_len)[:, None, :]
            fft_kernel = fft_kernel[: self.win_len]
        # Real parts followed by imaginary parts along dim 1.
        fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
        ifft_kernel = th.pinverse(fft_kernel)[:, None, :]

        window = get_window(self.win_type, self.win_len)
        self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
        window = th.FloatTensor(window)
        if self.mode == "continue":
            # Centre the window inside an fft_len-long frame.
            left_pad = (self.fft_len - self.win_len) // 2
            right_pad = left_pad + (self.fft_len - self.win_len) % 2
            window = F.pad(window, (left_pad, right_pad))
        if self.win_sqrt:
            self.padded_window = window
            window = th.sqrt(window)
        else:
            self.padded_window = window**2

        # The window is applied on analysis and again on synthesis.
        fft_kernel = fft_kernel.T * window
        ifft_kernel = ifft_kernel * window
        ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
        if self.mode == "continue":
            ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
        return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel

    def is_perfect(self):
        """
        Whether the parameters win_len, win_hop and win_sqrt
        obey constants overlap-add(COLA)

        Returns:
            bool: Return true if parameters obey COLA.
        """
        return self.perfect_reconstruct and self.pad_center

    def transform(self, inputs, return_type="complex"):
        """Take input data (audio) to STFT domain.

        Args:
            inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
            return_type (str, optional): return (mag, phase) when `magphase`,
            return (real, imag) when `realimag` and complex(real, imag) when `complex`.
            Defaults to 'complex'.

        Returns:
            tuple: (mag, phase) when `magphase`, return (real, imag) when
            `realimag`. Defaults to 'complex', each elements with shape
            [num_batch, num_frequencies, num_frames]
        """
        assert return_type in ["magphase", "realimag", "complex"]
        if inputs.dim() == 2:
            inputs = th.unsqueeze(inputs, 1)
        self.num_samples = inputs.size(-1)
        if self.pad_center:
            inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
        enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
        outputs = th.transpose(enframe_inputs, 1, 2)
        outputs = F.linear(outputs, self.fft_k)
        outputs = th.transpose(outputs, 1, 2)

        # The first `dim` channels are the real part, the rest the imaginary part.
        dim = self.fft_len // 2 + 1
        real = outputs[:, :dim, :]
        imag = outputs[:, dim:, :]
        if return_type == "realimag":
            return real, imag
        elif return_type == "complex":
            assert support_clp_op
            return th.complex(real, imag)
        else:
            mags = th.sqrt(real**2 + imag**2)
            phase = th.atan2(imag, real)
            return mags, phase

    def inverse(self, input1, input2=None, input_type="magphase"):
        """Call the inverse STFT (iSTFT), given tensors produced
        by the `transform` function.

        Args:
            input1 (tensors): Magnitude/Real-part of STFT with shape
            [num_batch, num_frequencies, num_frames]
            input2 (tensors): Phase/Imag-part of STFT with shape
            [num_batch, num_frequencies, num_frames]
            input_type (str, optional): Mathematical meaning of input tensor's.
            Defaults to 'magphase'.

        Returns:
            tensors: Reconstructed audio given magnitude and phase. Of
            shape [num_batch, num_samples]
        """
        assert input_type in ["magphase", "realimag"]
        if input_type == "realimag":
            if support_clp_op and th.is_complex(input1):
                real, imag = input1.real, input1.imag
            else:
                real, imag = input1, input2
        else:
            real = input1 * th.cos(input2)
            imag = input1 * th.sin(input2)
        inputs = th.cat([real, imag], dim=1)
        outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)

        # Overlap-added window energy, used to normalise the overlap-added frames.
        t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
        t = t.to(inputs.device)
        coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)

        num_frames = input1.size(-1)
        num_samples = num_frames * self.win_hop
        rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples
        outputs = outputs[..., rm_start:rm_end]
        coff = coff[..., rm_start:rm_end]

        # `coff` has a batch size of 1. Dividing with broadcasting normalises every
        # batch item; indexing `outputs` with indices taken from `coff` only
        # touched batch item 0. Positions with negligible window energy stay as is.
        safe_coff = th.where(coff > 1e-8, coff, th.ones_like(coff))
        outputs = outputs / safe_coff
        return outputs.squeeze(dim=1)

    def forward(self, inputs):
        """Take input data (audio) to STFT domain and then back to audio.

        Args:
            inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]

        Returns:
            tensor: Reconstructed audio given magnitude and phase.
            Of shape [num_batch, num_samples]
        """
        # Ask for (mag, phase) explicitly: the default return type of `transform`
        # is a single complex tensor, which cannot be unpacked into two values.
        mag, phase = self.transform(inputs, return_type="magphase")
        rec_wav = self.inverse(mag, phase)
        return rec_wav
================================================
FILE: src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py
================================================
import argparse
import json
import os
import re
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
import safetensors.torch
import torch
from tensorrt_llm import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split, split_matrix_tp
def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """Return this rank's contiguous shard of ``v``, split along dim 1.

    ``n_head`` and ``n_hidden`` are accepted but not used.
    """
    shard = split(v, tensor_parallel, rank, dim=1)
    return shard.contiguous()
def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """Return this rank's contiguous shard of ``v``, split along dim 0.

    ``n_head`` and ``n_hidden`` are accepted but not used.
    """
    shard = split(v, tensor_parallel, rank, dim=0)
    return shard.contiguous()
def parse_arguments():
    """Parse the command line and apply the preset of the selected model.

    When ``--model_name`` names one of the built-in models, its preset overwrites
    the architecture hyperparameters regardless of what was passed on the command
    line. ``F5TTS_Custom`` (the default) has no preset, so the individual flags are
    used as given.

    Returns:
        argparse.Namespace: The parsed (and possibly overwritten) arguments.
    """
    # Architecture hyperparameters of the built-in models.
    presets = {
        "F5TTS_v1_Base": dict(
            hidden_size=1024, depth=22, num_heads=16, dim_head=64, ff_mult=2,
            text_dim=512, text_mask_padding=True, conv_layers=4, pe_attn_head=None,
        ),
        "F5TTS_Base": dict(
            hidden_size=1024, depth=22, num_heads=16, dim_head=64, ff_mult=2,
            text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1,
        ),
        "F5TTS_v1_Small": dict(
            hidden_size=768, depth=18, num_heads=12, dim_head=64, ff_mult=2,
            text_dim=512, text_mask_padding=True, conv_layers=4, pe_attn_head=None,
        ),
        "F5TTS_Small": dict(
            hidden_size=768, depth=18, num_heads=12, dim_head=64, ff_mult=2,
            text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1,
        ),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_ckpt", type=str, default="./ckpts/model_last.pt")
    parser.add_argument(
        "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
    )
    parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
    parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
    parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
    parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
    parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers")
    parser.add_argument(
        "--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="F5TTS_Custom",
        # The default is listed too, so it can also be passed explicitly.
        choices=["F5TTS_Custom", *presets],  # if a preset is chosen, overwrite the below hyperparams
    )
    parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
    parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
    parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
    parser.add_argument("--dim_head", type=int, default=64, help="The dimension of attention head")
    parser.add_argument("--ff_mult", type=int, default=2, help="The FFN intermediate dimension multiplier")
    parser.add_argument("--text_dim", type=int, default=512, help="The output dimension of text encoder")
    parser.add_argument(
        "--text_mask_padding",
        type=lambda x: x.lower() == "true",
        choices=[True, False],
        default=True,
        help="Whether apply padding mask for conv layers in text encoder",
    )
    parser.add_argument("--conv_layers", type=int, default=4, help="The number of conv layers of text encoder")
    parser.add_argument("--pe_attn_head", type=int, default=None, help="The number of attn head that apply pos emb")
    args = parser.parse_args()

    # overwrite if --model_name selects a built-in model
    for field, value in presets.get(args.model_name, {}).items():
        setattr(args, field, value)

    return args
def convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32", use_ema=True):
    """Load a PyTorch / safetensors checkpoint and return TensorRT-LLM style weights.

    Steps:
        1. Load the state dict (the EMA one when ``use_ema`` is set, for ``.pt`` files).
        2. Keep only the transformer entries and strip their prefix, if present.
        3. Rename parameters to the names used by the TensorRT-LLM modules.
        4. Cast to ``dtype``; the two conv position-embedding weights get an extra
           trailing axis.
        5. Scale the q and k projections by ``dim_head ** -0.25`` and shard the
           attention projections for ``mapping.tp_rank``.

    Args:
        args: Parsed arguments; reads ``pytorch_ckpt``, ``num_heads``,
            ``hidden_size`` and ``dim_head`` (64 is assumed when it is absent).
        mapping: Object providing ``tp_size`` and ``tp_rank``.
        dtype: Target dtype name understood by ``str_dtype_to_torch``.
        use_ema: Select the EMA weights and the EMA key prefix.

    Returns:
        dict: Parameter name -> tensor.
    """
    import math  # kept local: the function does not rely on a file-level import of math

    tik = time.time()
    torch_dtype = str_dtype_to_torch(dtype)
    tensor_parallel = mapping.tp_size
    tp_rank = mapping.tp_rank

    ckpt_path = args.pytorch_ckpt
    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file

        model_params = load_file(ckpt_path)
    else:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]

    # Keep only the transformer weights and drop their prefix.
    prefix = "ema_model.transformer." if use_ema else "transformer."
    if any(k.startswith(prefix) for k in model_params.keys()):
        model_params = {key[len(prefix) :]: value for key, value in model_params.items() if key.startswith(prefix)}

    pytorch_to_trtllm_name = {
        r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1",
        r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1",
        r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1",
        r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1",
        r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2",
        r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
        r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
    }

    def get_trtllm_name(pytorch_name):
        # The first pattern that changes the name wins; unmatched names pass through.
        for pytorch_name_pattern, trtllm_name_replacement in pytorch_to_trtllm_name.items():
            trtllm_name_if_matched = re.sub(pytorch_name_pattern, trtllm_name_replacement, pytorch_name)
            if trtllm_name_if_matched != pytorch_name:
                return trtllm_name_if_matched
        return pytorch_name

    # These two weights receive an extra trailing axis.
    conv_pos_weights = (
        "input_embed.conv_pos_embed.conv1d.0.weight",
        "input_embed.conv_pos_embed.conv1d.2.weight",
    )

    weights = {}
    for name, param in model_params.items():
        tensor = param.contiguous().to(torch_dtype)
        if name in conv_pos_weights:
            tensor = tensor.unsqueeze(-1)
        weights[get_trtllm_name(name)] = tensor

    # A shrinking dict would mean two source names collapsed onto one target name.
    assert len(weights) == len(model_params)

    # q and k are each scaled by dim_head ** -0.25, so their product carries the
    # usual dim_head ** -0.5 factor. The head size was hard-coded to 64 before.
    scale_factor = math.pow(getattr(args, "dim_head", 64), -0.25)

    for k, v in weights.items():
        if re.match("^transformer_blocks.*.attn.to_(q|k).weight$", k):
            weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, tp_rank) * scale_factor
        elif re.match("^transformer_blocks.*.attn.to_(q|k).bias$", k):
            weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, tp_rank) * scale_factor
        elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
            weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, tp_rank)
        elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
            weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, tp_rank)
        elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
            weights[k] = split_matrix_tp(v, tensor_parallel, tp_rank, dim=1)

    tok = time.time()
    t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
    print(f"Weights loaded. Total time: {t}")
    return weights
def save_config(args):
    """Write the TensorRT-LLM ``config.json`` for the converted model.

    The file is written to ``<args.output_dir>/config.json``; the output
    directory is created if it does not exist yet.

    Args:
        args: Parsed command-line namespace. Reads ``output_dir``, ``dtype``,
            ``hidden_size``, ``depth``, ``num_heads``, ``dim_head``,
            ``ff_mult``, ``text_dim``, ``text_mask_padding``, ``conv_layers``,
            ``pe_attn_head``, ``cp_size``, ``tp_size``, ``pp_size`` and
            ``fp8_linear``.
    """
    # exist_ok=True replaces the racy "check, then create" sequence: the
    # directory may appear between the exists() test and makedirs().
    os.makedirs(args.output_dir, exist_ok=True)

    config = {
        "architecture": "F5TTS",  # set the same as in ../patch/__init__.py
        "dtype": args.dtype,
        "hidden_size": args.hidden_size,
        "num_hidden_layers": args.depth,
        "num_attention_heads": args.num_heads,
        "dim_head": args.dim_head,
        "dropout": 0.0,  # inference-only
        "ff_mult": args.ff_mult,
        "mel_dim": 100,
        "text_dim": args.text_dim,
        "text_mask_padding": args.text_mask_padding,
        "conv_layers": args.conv_layers,
        "pe_attn_head": args.pe_attn_head,
        "mapping": {
            "world_size": args.cp_size * args.tp_size * args.pp_size,
            "cp_size": args.cp_size,
            "tp_size": args.tp_size,
            "pp_size": args.pp_size,
        },
    }
    if args.fp8_linear:
        config["quantization"] = {
            "quant_algo": "FP8",
            # TODO: add support for exclude modules.
            # "exclude_modules": "*final_layer*",
        }

    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)
def covert_and_save(args, rank):
    """Convert the checkpoint for one rank and save it as ``rank{rank}.safetensors``.

    Rank 0 also writes the shared ``config.json`` before converting.

    Args:
        args: Parsed command-line namespace (parallel sizes, dtype, output_dir, ...).
        rank: Index of the rank whose weight shard is produced.
    """
    if rank == 0:
        save_config(args)

    total_ranks = args.cp_size * args.tp_size * args.pp_size
    rank_mapping = Mapping(
        world_size=total_ranks,
        rank=rank,
        cp_size=args.cp_size,
        tp_size=args.tp_size,
        pp_size=args.pp_size,
    )

    converted = convert_pytorch_dit_to_trtllm_weight(args, rank_mapping, dtype=args.dtype)
    shard_path = os.path.join(args.output_dir, f"rank{rank}.safetensors")
    safetensors.torch.save_file(converted, shard_path)
def execute(workers, func, args):
    """Call every callable in ``func`` as ``f(args, rank)``, where rank is its index.

    With ``workers == 1`` the callables run one after another in the calling
    thread and any exception propagates immediately. Otherwise they are
    submitted to a thread pool; each failure has its traceback printed, and an
    ``AssertionError`` is raised at the end if any callable failed.

    Args:
        workers: Number of worker threads (1 means run sequentially).
        func: Iterable of callables accepting ``(args, rank)``.
        args: Object passed unchanged as the first argument of every callable.
    """
    if workers == 1:
        for rank, job in enumerate(func):
            job(args, rank)
        return

    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = [pool.submit(job, args, rank) for rank, job in enumerate(func)]
        failures = []
        for finished in as_completed(pending):
            try:
                finished.result()
            except Exception as err:
                traceback.print_exc()
                failures.append(err)
        assert len(failures) == 0, "Checkpoint conversion failed, please check error log."
def main():
    """Parse the CLI arguments and convert the PyTorch checkpoint for every rank.

    Returns without doing anything when no ``--pytorch_ckpt`` was supplied.
    Pipeline parallelism is rejected with an assertion.
    """
    args = parse_arguments()
    n_ranks = args.cp_size * args.tp_size * args.pp_size

    assert args.pp_size == 1, "PP is not supported yet."

    started_at = time.time()

    if args.pytorch_ckpt is not None:
        print("Start execute")
        # One conversion job per rank; execute() passes each job its rank index.
        execute(args.workers, [covert_and_save] * n_ranks, args)

        elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - started_at))
        print(f"Total time of converting checkpoints: {elapsed}")


if __name__ == "__main__":
    main()
================================================
FILE: src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torch.nn as nn
from conv_stft import STFT
from huggingface_hub import hf_hub_download
from vocos import Vocos
# ONNX opset version passed to torch.onnx.export when exporting the vocoder.
opset_version = 17
def get_args():
    """Parse the command-line options of the vocoder export script.

    Returns:
        argparse.Namespace with ``vocoder`` ("vocos" or "bigvgan") and
        ``output_path`` (destination of the exported ONNX file).
    """
    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    option_specs = (
        (
            "--vocoder",
            dict(type=str, default="vocos", choices=["vocos", "bigvgan"], help="Vocoder to export"),
        ),
        (
            "--output-path",
            dict(type=str, default="./vocos_vocoder.onnx", help="Output path"),
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
class ISTFTHead(nn.Module):
    """Replacement vocoder head that inverts the spectrum with the ``STFT`` module from ``conv_stft``.

    ``out`` is initialised to ``None``; the caller must assign the projection
    layer to it before ``forward`` is called.
    """

    def __init__(self, n_fft: int, hop_length: int):
        super().__init__()
        self.out = None
        self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft)

    def forward(self, x: torch.Tensor):
        """Project ``x``, split it into magnitude/phase halves and invert to audio."""
        projected = self.out(x).transpose(1, 2)
        log_magnitude, phase = projected.chunk(2, dim=1)

        # Exponentiate, then cap the magnitude at 1e2.
        magnitude = torch.clip(torch.exp(log_magnitude), max=1e2)

        real_part = magnitude * torch.cos(phase)
        imag_part = magnitude * torch.sin(phase)
        return self.stft.inverse(input1=real_part, input2=imag_part, input_type="realimag")
class VocosVocoder(nn.Module):
    """Wrapper that swaps the vocoder's head for an ``ISTFTHead`` before export.

    The wrapped vocoder is modified in place: its ``head`` attribute is
    replaced, while the original head's ``out`` layer is reused.
    """

    def __init__(self, vocos_vocoder):
        super().__init__()
        self.vocos_vocoder = vocos_vocoder

        original_head = self.vocos_vocoder.head
        export_head = ISTFTHead(original_head.istft.n_fft, original_head.istft.hop_length)
        export_head.out = original_head.out
        self.vocos_vocoder.head = export_head

    def forward(self, mel):
        """Decode ``mel`` with the wrapped vocoder and return the result."""
        return self.vocos_vocoder.decode(mel)
def export_VocosVocoder(vocos_vocoder, output_path, verbose):
    """Export the given Vocos vocoder to ONNX. Requires a CUDA device.

    Args:
        vocos_vocoder: Loaded vocoder; it is wrapped in ``VocosVocoder``.
        output_path: Destination path of the ONNX file.
        verbose: Forwarded to ``torch.onnx.export``.
    """
    export_model = VocosVocoder(vocos_vocoder).cuda()
    export_model.eval()

    # Example input used for tracing: (batch, 100, length); batch and length
    # are declared dynamic in the export call below.
    example_batch, example_length = 8, 500
    example_mel = torch.randn(example_batch, 100, example_length).cuda()

    with torch.no_grad():
        example_waveform = export_model(mel=example_mel)
        print(example_waveform.shape)

    dynamic_axes = {
        "mel": {0: "batch_size", 2: "input_length"},
        "waveform": {0: "batch_size", 1: "output_length"},
    }
    torch.onnx.export(
        export_model,
        example_mel,
        output_path,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["mel"],
        output_names=["waveform"],
        dynamic_axes=dynamic_axes,
        verbose=verbose,
    )

    print("Exported to {}".format(output_path))
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None):
    """Load the vocoder to be exported and return it in eval mode on ``device``.

    Args:
        vocoder_name: "vocos" (supported) or "bigvgan" (not implemented).
        is_local: When True, read ``config.yaml`` and ``pytorch_model.bin``
            from ``local_path`` instead of downloading them.
        local_path: Directory holding the local Vocos files.
        device: Device the vocoder is moved to.
        hf_cache_dir: Optional Hugging Face cache directory for the download.

    Returns:
        The loaded vocoder.

    Raises:
        NotImplementedError: If ``vocoder_name`` is "bigvgan".
        ValueError: If ``vocoder_name`` is neither "vocos" nor "bigvgan".
    """
    # Reject unsupported names up front. The original kept two statements
    # after the BigVGAN `raise` that could never run, and let any other name
    # reach `return vocoder` with the variable unbound.
    if vocoder_name == "bigvgan":
        raise NotImplementedError("BigVGAN is not supported yet")
    if vocoder_name != "vocos":
        raise ValueError(f"Unsupported vocoder: {vocoder_name!r} (expected 'vocos' or 'bigvgan')")

    # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
    if is_local:
        print(f"Load vocos from local path {local_path}")
        config_path = f"{local_path}/config.yaml"
        model_path = f"{local_path}/pytorch_model.bin"
    else:
        print("Download Vocos from huggingface charactr/vocos-mel-24khz")
        repo_id = "charactr/vocos-mel-24khz"
        config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
        model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")

    vocoder = Vocos.from_hparams(config_path)
    # weights_only=True restricts unpickling to tensor data.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    vocoder.load_state_dict(state_dict)
    return vocoder.eval().to(device)
if __name__ == "__main__":
    # Load the requested vocoder on CPU, then export it (only Vocos has an exporter).
    cli_args = get_args()
    loaded_vocoder = load_vocoder(vocoder_name=cli_args.vocoder, device="cpu", hf_cache_dir=None)
    if cli_args.vocoder == "vocos":
        export_VocosVocoder(loaded_vocoder, cli_args.output_path, verbose=False)
================================================
FILE: src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh
================================================
#!/bin/bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Manual installation of TensorRT, in case not using NVIDIA NGC:
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
# Path to trtexec; can be overridden from the environment for manual
# TensorRT installations, otherwise the previous fixed location is used.
TRTEXEC="${TRTEXEC:-/usr/src/tensorrt/bin/trtexec}"

# Both positional arguments are required; fail early with a usage message
# instead of invoking trtexec with empty --onnx/--saveEngine values.
if [ "$#" -lt 2 ]; then
    echo "Usage: $0 <onnx_path> <engine_path>" >&2
    exit 1
fi

ONNX_PATH=$1
ENGINE_PATH=$2
echo "ONNX_PATH: $ONNX_PATH"
echo "ENGINE_PATH: $ENGINE_PATH"

# NOTE: PRECISION is currently informational only; it is not passed to trtexec.
PRECISION="fp32"

# Dynamic-shape profile for the "mel" input: <batch>x100x<length>.
MIN_BATCH_SIZE=1
OPT_BATCH_SIZE=1
MAX_BATCH_SIZE=8

MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000  # 4096

MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}"

# Paths are quoted so that locations containing spaces are passed intact.
"${TRTEXEC}" \
    --minShapes="mel:${MEL_MIN_SHAPE}" \
    --optShapes="mel:${MEL_OPT_SHAPE}" \
    --maxShapes="mel:${MEL_MAX_SHAPE}" \
    --onnx="${ONNX_PATH}" \
    --saveEngine="${ENGINE_PATH}"
================================================
FILE: src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py
================================================
#! /usr/bin/env python3
from argparse import ArgumentParser
from string import Template
def main(file_path, substitutions, in_place, participant_ids):
    """Fill ``${name}`` placeholders in a template file.

    Placeholders without a supplied value are left untouched
    (``Template.safe_substitute``). ``max_queue_size`` defaults to 0 and
    ``participant_ids`` is always provided from the argument of that name.

    Args:
        file_path: Path of the template file (e.g. a ``config.pbtxt``).
        substitutions: Comma-separated ``name:value`` pairs. Only the first
            colon separates name from value, so values may contain colons
            (URLs, ``host:port``). Empty entries are ignored.
        in_place: If true, overwrite ``file_path``; otherwise print the result.
        participant_ids: Value substituted for ``${participant_ids}``.

    Raises:
        ValueError: If a non-empty entry contains no ``:`` separator.
    """
    with open(file_path) as f:
        pbtxt = Template(f.read())

    sub_dict = {"max_queue_size": 0}
    sub_dict["participant_ids"] = participant_ids
    for sub in substitutions.split(","):
        if not sub.strip():
            # Tolerate an empty substitutions string or a stray/trailing comma.
            continue
        # partition() splits on the first colon only; split(":") raised
        # "too many values to unpack" for values that contain a colon.
        key, separator, value = sub.partition(":")
        if not separator:
            raise ValueError(f"invalid substitution {sub!r}: expected the form name:value")
        sub_dict[key] = value

    pbtxt = pbtxt.safe_substitute(sub_dict)

    if in_place:
        with open(file_path, "w") as f:
            f.write(pbtxt)
    else:
        print(pbtxt)
if __name__ == "__main__":
    # Command-line entry point: the parsed options map one-to-one onto main()'s parameters.
    cli = ArgumentParser()
    cli.add_argument("file_path", help="path of the .pbtxt to modify")
    cli.add_argument(
        "substitutions",
        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
    )
    cli.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
    cli.add_argument("--participant_ids", help="Participant IDs for the model", default="")

    parsed_options = vars(cli.parse_args())
    main(**parsed_options)
================================================
FILE: src/f5_tts/scripts/count_max_epoch.py
================================================
"""ADAPTIVE BATCH SIZE"""

# Estimates how many epochs are needed to reach a target number of updates
# when batches are sized by a fixed number of mel frames per GPU.
for banner_line in (
    "Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in",
    " -> least padding, gather wavs with accumulated frames in a batch\n",
):
    print(banner_line)

# data
total_hours = 95282
mel_hop_length = 256
mel_sampling_rate = 24000

# target
wanted_max_updates = 1200000

# train params
gpus = 8
frames_per_gpu = 38400  # 8 * 38400 = 307200
grad_accum = 1

# intermediate: frames (and hours of audio) consumed by one optimizer update
mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
# steps_per_epoch = updates_per_epoch * grad_accum

# result
epochs = wanted_max_updates / updates_per_epoch

report_lines = [
    f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})",
    f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates",
    # f" or approx. 0/{steps_per_epoch:.0f} steps",
    # others
    f"total {total_hours:.0f} hours",
    f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch",
]
for report_line in report_lines:
    print(report_line)
================================================
FILE: src/f5_tts/scripts/count_max_epoch_precise.py
================================================
import math
from torch.utils.data import SequentialSampler
from f5_tts.model.dataset import DynamicBatchSampler, load_dataset
# Counts the real number of batches the dynamic batch sampler produces for the
# dataset, then derives how many epochs reach the requested number of updates.

# settings
gpus = 8
batch_size_per_gpu = 38400
max_samples_per_gpu = 64
max_updates = 1250000

train_dataset = load_dataset("Emilia_ZH_EN", "pinyin")
batch_sampler = DynamicBatchSampler(
    SequentialSampler(train_dataset),
    batch_size_per_gpu,
    max_samples=max_samples_per_gpu,
    random_seed=666,
    drop_residual=False,
)

# The batches are divided across the GPUs, so one update consumes `gpus` of them.
updates_per_epoch = int(len(batch_sampler) / gpus)
epochs_needed = math.ceil(max_updates / updates_per_epoch)

print(
    f"One epoch has {updates_per_epoch} updates if gpus={gpus}, with "
    f"batch_size_per_gpu={batch_size_per_gpu} (frames) & "
    f"max_samples_per_gpu={max_samples_per_gpu}."
)
print(f"If gpus={gpus}, for max_updates={max_updates} should set epoch={epochs_needed}.")
================================================
FILE: src/f5_tts/scripts/count_params_gflops.py
================================================
import os
import sys
sys.path.append(os.getcwd())
import thop
import torch
from f5_tts.model import CFM, DiT
# Profiles FLOPs and parameter count of one CFM forward pass with thop.
# Uncomment exactly one `transformer = ...` line to pick the backbone.

# ~155M
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)

# ~335M
# FLOPs: 622.1 G, Params: 333.2 M
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
# FLOPs: 363.4 G, Params: 335.8 M
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

model = CFM(transformer=transformer)

# Size of the dummy inputs.
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
duration = 20
text_length = 150
frame_length = int(duration * target_sample_rate / hop_length)

dummy_mel = torch.randn(1, frame_length, n_mel_channels)
dummy_text = torch.zeros(1, text_length, dtype=torch.long)

flops, params = thop.profile(model, inputs=(dummy_mel, dummy_text))
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")
================================================
FILE: src/f5_tts/socket_client.py
================================================
import asyncio
import logging
import socket
import time
import numpy as np
import pyaudio
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
    """Send ``text`` to the F5-TTS socket server and play the streamed audio.

    The server answers with raw float32 samples followed by the 3-byte marker
    ``b"END"``. TCP is a byte stream, so one ``recv`` may return the marker
    glued to the last samples, only part of the marker, or a buffer ending in
    the middle of a sample. Received bytes are therefore accumulated, only
    whole samples are played, and the marker is recognised at the end of the
    accumulated buffer.

    Args:
        text: Text to synthesize.
        server_ip: Host name or IP address of the server.
        server_port: TCP port of the server (converted with ``int``).

    Raises:
        OSError: If the connection cannot be established. Errors that occur
            after connecting are logged and not re-raised.
    """
    end_marker = b"END"
    sample_width = 4  # bytes per float32 sample

    loop = asyncio.get_running_loop()
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        await loop.run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
    except BaseException:
        # Do not leak the socket when the connection attempt fails.
        client_socket.close()
        raise

    start_time = time.time()
    first_chunk_time = None

    async def play_audio_stream():
        nonlocal first_chunk_time
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
        pending = b""  # received bytes not yet played; always starts on a sample boundary

        try:
            while True:
                data = await loop.run_in_executor(None, client_socket.recv, 8192)
                if not data:
                    break
                pending += data

                # Audio arrives in whole samples and `pending` is kept aligned,
                # so the genuine marker leaves exactly len(end_marker) bytes
                # beyond a sample boundary. The alignment test keeps sample
                # bytes that happen to spell "END" from ending the stream.
                finished = (
                    pending.endswith(end_marker) and (len(pending) - len(end_marker)) % sample_width == 0
                )
                if finished:
                    pending = pending[: -len(end_marker)]

                playable = len(pending) - len(pending) % sample_width
                if playable:
                    stream.write(pending[:playable])
                    if first_chunk_time is None:
                        first_chunk_time = time.time()
                        logger.info(f"Time to first audio chunk: {first_chunk_time - start_time:.4f} seconds")
                pending = pending[playable:]

                if finished:
                    logger.info("End of audio received.")
                    break
        finally:
            stream.stop_stream()
            stream.close()
            p.terminate()

        logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")

    try:
        data_to_send = f"{text}".encode("utf-8")
        await loop.run_in_executor(None, client_socket.sendall, data_to_send)
        await play_audio_stream()
    except Exception as e:
        logger.error(f"Error in listen_to_F5TTS: {e}")
    finally:
        client_socket.close()
if __name__ == "__main__":
    # Demo: send one sample paragraph to a server running with the default host/port.
    demo_text = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
    asyncio.run(listen_to_F5TTS(demo_text))
================================================
FILE: src/f5_tts/socket_server.py
================================================
import argparse
import gc
import logging
import queue
import socket
import struct
import threading
import traceback
import wave
from importlib.resources import files
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
chunk_text,
infer_batch_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AudioFileWriterThread(threading.Thread):
    """Threaded file writer to avoid blocking the TTS streaming process.

    Chunks queued with ``add_chunk`` are converted to 16-bit PCM and appended
    to a mono WAV file until ``stop`` is called and the queue is drained.
    """

    def __init__(self, output_file, sampling_rate):
        """
        Args:
            output_file: Path of the WAV file to write.
            sampling_rate: Frame rate written to the WAV header.
        """
        super().__init__()
        self.output_file = output_file
        self.sampling_rate = sampling_rate
        self.queue = queue.Queue()
        self.stop_event = threading.Event()
        # Every converted chunk is also kept here (grows for the thread's lifetime).
        self.audio_data = []

    def run(self):
        """Process queued audio data and write it to a file."""
        logger.info("AudioFileWriterThread started.")
        with wave.open(self.output_file, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(self.sampling_rate)

            # Keep draining after stop() was requested so queued chunks are not lost.
            while not self.stop_event.is_set() or not self.queue.empty():
                try:
                    chunk = self.queue.get(timeout=0.1)
                    if chunk is not None:
                        # Clip to [-1, 1] first: samples outside that range would
                        # overflow the int16 conversion and distort the audio.
                        chunk = np.int16(np.clip(chunk, -1.0, 1.0) * 32767)
                        self.audio_data.append(chunk)
                        wf.writeframes(chunk.tobytes())
                except queue.Empty:
                    continue

    def add_chunk(self, chunk):
        """Add a new chunk to the queue."""
        self.queue.put(chunk)

    def stop(self):
        """Stop writing and ensure all queued data is written."""
        self.stop_event.set()
        self.join()
        logger.info("Audio writing completed.")
class TTSStreamingProcessor:
    """Holds a loaded TTS model, vocoder and reference audio, and streams speech to a socket.

    The model named by ``model`` is configured from ``configs/<model>.yaml``
    inside the ``f5_tts`` package. Construction loads model and vocoder,
    prepares the reference audio and runs one warm-up inference.
    """

    def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
        """
        Args:
            model: Config name, e.g. ``F5TTS_v1_Base``.
            ckpt_file: Path to the model checkpoint.
            vocab_file: Path to a custom vocab file (passed to ``load_model``).
            ref_audio: Reference audio file.
            ref_text: Transcript of the reference audio.
            device: Target device; auto-selected (cuda, xpu, mps, cpu) when falsy.
            dtype: Torch dtype the model is cast to.
        """
        self.device = device or (
            "cuda"
            if torch.cuda.is_available()
            else "xpu"
            if torch.xpu.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
        self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        self.model_arc = model_cfg.model.arch
        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate

        self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
        self.vocoder = self.load_vocoder_model()

        self.update_reference(ref_audio, ref_text)
        self._warm_up()
        self.file_writer_thread = None
        # True until the first text of a connection has been processed.
        self.first_package = True

    def load_ema_model(self, ckpt_file, vocab_file, dtype):
        """Load the checkpoint (EMA weights) and move the model to ``self.device`` with ``dtype``."""
        return load_model(
            self.model_cls,
            self.model_arc,
            ckpt_path=ckpt_file,
            mel_spec_type=self.mel_spec_type,
            vocab_file=vocab_file,
            ode_method="euler",
            use_ema=True,
            device=self.device,
        ).to(self.device, dtype=dtype)

    def load_vocoder_model(self):
        """Load the vocoder matching the configured mel-spectrogram type."""
        return load_vocoder(vocoder_name=self.mel_spec_type, is_local=False, local_path=None, device=self.device)

    def update_reference(self, ref_audio, ref_text):
        """Set the reference audio/text and derive the text chunk-size limits from them."""
        self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text)
        self.audio, self.sr = torchaudio.load(self.ref_audio)

        ref_audio_duration = self.audio.shape[-1] / self.sr
        ref_text_byte_len = len(self.ref_text.encode("utf-8"))
        # Text budget = reference bytes-per-second times the seconds left out
        # of 25. NOTE(review): 25 looks like a total-length budget in seconds
        # (reference + generated) -- confirm.
        text_budget = ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration)
        self.max_chars = int(text_budget)
        self.few_chars = int(text_budget / 2)
        self.min_chars = int(text_budget / 4)

    def _warm_up(self):
        """Run one throw-away streaming inference and discard its output."""
        logger.info("Warming up the model...")
        gen_text = "Warm-up text for the model."
        for _ in infer_batch_process(
            (self.audio, self.sr),
            self.ref_text,
            [gen_text],
            self.model,
            self.vocoder,
            progress=None,
            device=self.device,
            streaming=True,
        ):
            pass
        logger.info("Warm-up completed.")

    def generate_stream(self, text, conn):
        """Synthesize ``text`` and stream it to ``conn``, followed by ``b"END"``.

        Each audio chunk is sent as packed native floats and also written to
        ``output.wav`` by a background writer thread. The writer thread is
        always stopped before returning, including when sending or inference
        raises.

        Args:
            text: Text to synthesize.
            conn: Connected socket-like object providing ``sendall``.
        """
        text_batches = chunk_text(text, max_chars=self.max_chars)
        # For the first package, split the leading batch into smaller pieces
        # (few_chars, then min_chars). Guarded so an empty batch list no longer
        # raises IndexError on text_batches[0].
        if self.first_package and text_batches:
            text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:]
            text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:]
            self.first_package = False

        audio_stream = infer_batch_process(
            (self.audio, self.sr),
            self.ref_text,
            text_batches,
            self.model,
            self.vocoder,
            progress=None,
            device=self.device,
            streaming=True,
            chunk_size=2048,
        )

        # Reset the file writer thread
        if self.file_writer_thread is not None:
            self.file_writer_thread.stop()
        self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate)
        self.file_writer_thread.start()

        try:
            for audio_chunk, _ in audio_stream:
                if len(audio_chunk) > 0:
                    logger.info(f"Generated audio chunk of size: {len(audio_chunk)}")

                    # Send audio chunk via socket
                    conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk))

                    # Write to file asynchronously
                    self.file_writer_thread.add_chunk(audio_chunk)

            logger.info("Finished sending audio stream.")
            conn.sendall(b"END")  # Send end signal
        finally:
            # Ensure all audio data is written and the WAV file is closed, even
            # if the client disconnected or inference failed mid-stream.
            self.file_writer_thread.stop()
def handle_client(conn, processor):
    """Serve one client connection until it closes or processing fails.

    Every received message (at most 1024 bytes per read) is decoded as UTF-8,
    stripped and passed to ``processor.generate_stream``. When the peer closes
    the connection, ``processor.first_package`` is reset to True. Errors are
    logged with their traceback and never propagate to the caller.

    Args:
        conn: Connected socket; closed when this function returns.
        processor: Object providing ``generate_stream(text, conn)`` and a
            ``first_package`` attribute.
    """
    try:
        with conn:
            conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

            # iter() with a sentinel stops as soon as recv() returns b"" (peer closed).
            for payload in iter(lambda: conn.recv(1024), b""):
                message = payload.decode("utf-8").strip()
                logger.info(f"Received text: {message}")

                try:
                    processor.generate_stream(message, conn)
                except Exception as processing_error:
                    logger.error(f"Error during processing: {processing_error}")
                    traceback.print_exc()
                    break
            else:
                # Reached only when the loop ended because the peer closed the connection.
                processor.first_package = True

    except Exception as connection_error:
        logger.error(f"Error handling client: {connection_error}")
        traceback.print_exc()
def start_server(host, port, processor):
    """Listen on ``host:port`` and serve clients one at a time, forever.

    Each accepted connection is handled to completion by ``handle_client``
    before the next one is accepted.

    Args:
        host: Address to bind to.
        port: TCP port to bind to.
        processor: Processor handed to ``handle_client`` for every connection.
    """
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with listener:
        listener.bind((host, port))
        listener.listen()
        logger.info(f"Server started on {host}:{port}")

        while True:
            client_conn, client_addr = listener.accept()
            logger.info(f"Connected by {client_addr}")
            handle_client(client_conn, processor)
if __name__ == "__main__":

    def _parse_dtype(value):
        """Turn a CLI string such as ``float16`` or ``torch.bfloat16`` into a ``torch.dtype``."""
        candidate = getattr(torch, str(value).rsplit(".", 1)[-1], None)
        if not isinstance(candidate, torch.dtype):
            raise argparse.ArgumentTypeError(f"unknown torch dtype: {value!r}")
        return candidate

    parser = argparse.ArgumentParser()

    parser.add_argument("--host", default="0.0.0.0")
    # type=int: without it a user-supplied port stays a string and bind() fails.
    parser.add_argument("--port", type=int, default=9998)

    parser.add_argument(
        "--model",
        default="F5TTS_v1_Base",
        help="The model name, e.g. F5TTS_v1_Base",
    )
    parser.add_argument(
        "--ckpt_file",
        default=None,
        help="Path to the model checkpoint file "
        "(default: F5TTS_v1_Base/model_1250000.safetensors downloaded from the SWivid/F5-TTS Hugging Face repo)",
    )
    parser.add_argument(
        "--vocab_file",
        default="",
        help="Path to the vocab file if customized",
    )

    parser.add_argument(
        "--ref_audio",
        default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
        help="Reference audio to provide model with speaker characteristics",
    )
    parser.add_argument(
        "--ref_text",
        default="",
        help="Reference audio subtitle, leave empty to auto-transcribe",
    )

    parser.add_argument("--device", default=None, help="Device to run the model on")
    # type=_parse_dtype: a user-supplied dtype name must become a torch.dtype
    # before it is handed to `.to(dtype=...)`.
    parser.add_argument(
        "--dtype",
        type=_parse_dtype,
        default=torch.float32,
        help="Data type to use for model inference, e.g. float32, float16, bfloat16",
    )

    args = parser.parse_args()

    # Download the default checkpoint only when none was given; previously the
    # download ran while building the parser, even if --ckpt_file was passed.
    if not args.ckpt_file:
        args.ckpt_file = str(
            hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")
        )

    try:
        # Initialize the processor with the model and vocoder
        processor = TTSStreamingProcessor(
            model=args.model,
            ckpt_file=args.ckpt_file,
            vocab_file=args.vocab_file,
            ref_audio=args.ref_audio,
            ref_text=args.ref_text,
            device=args.device,
            dtype=args.dtype,
        )

        # Start the server
        start_server(args.host, args.port, processor)

    except KeyboardInterrupt:
        gc.collect()
================================================
FILE: src/f5_tts/train/README.md
================================================
# Training
Check your FFmpeg installation:
```bash
ffmpeg -version
```
If it is not found, install it first (or skip this step if you know another audio backend is available).
## Prepare Dataset
The scripts below are example data-processing scripts; you can also write your own, together with a matching Dataset class in `src/f5_tts/model/dataset.py`.
### 1. Some specific Datasets preparing scripts
Download the corresponding dataset first, then fill in its path in the script.
```bash
# Prepare the Emilia dataset
python src/f5_tts/train/datasets/prepare_emilia.py
# Prepare the Wenetspeech4TTS dataset
python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
# Prepare the LibriTTS dataset
python src/f5_tts/train/datasets/prepare_libritts.py
# Prepare the LJSpeech dataset
python src/f5_tts/train/datasets/prepare_ljspeech.py
```
### 2. Create custom dataset with CSV
Prepare a CSV with two columns using a required header: `audio_file|text`. Audio paths must be absolute.
For usage guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
```bash
python src/f5_tts/train/datasets/prepare_csv_wavs.py /path/to/metadata.csv /path/to/output
```
## Training & Finetuning
Once your datasets are prepared, you can start the training process.
### 1. Training script used for pretrained model
```bash
# setup accelerate config, e.g. use multi-gpu ddp, fp16
# the config will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config
# .yaml files are under src/f5_tts/configs directory
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml
# possible to overwrite accelerate and hydra config
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
```
### 2. Finetuning practice
Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
If you want to finetune from a variant checkpoint, e.g. *F5TTS_v1_Base_no_zero_init*, manually download the pretrained checkpoint from the model weight repository and fill in its path in the web interface.
If you use TensorBoard as the logger, install it first with `pip install tensorboard`.
Using `use_ema = True` may hurt results for early-stage finetuned checkpoints (after only a few updates, the EMA weights are still dominated by the pretrained ones). Try turning it off with the finetune Gradio option or `load_model(..., use_ema=False)` and see whether that gives better results.
### 3. W&B Logging
The `wandb/` directory will be created under the path from which you run the training/finetuning scripts.
By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
To turn on wandb logging, you can either:
1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:
On Mac & Linux:
```
export WANDB_API_KEY=
```
On Windows:
```
set WANDB_API_KEY=
```
Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:
```
export WANDB_MODE=offline
```
================================================
FILE: src/f5_tts/train/datasets/prepare_csv_wavs.py
================================================
"""
Usage:
python prepare_csv_wavs.py /path/to/metadata.csv /output/dataset/path [--pretrain] [--workers N]
CSV format (header required, "|" delimiter):
audio_file|text
/path/to/wavs/audio_0001.wav|Yo! Hello? Hello?
/path/to/wavs/audio_0002.wav|Hi, how are you doing today? I want to go shopping and buy me some lemons.
Notes:
- audio_file must be an absolute path.
"""
import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess
import sys
from contextlib import contextmanager
sys.path.append(os.getcwd())
import argparse
import csv
import json
from importlib.resources import files
from pathlib import Path
import soundfile as sf
import torchaudio
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin
# Pretrained vocab.txt, located relative to the f5_tts package directory.
# It is copied into the output directory when preparing a fine-tune dataset.
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
# Configuration constants
BATCH_SIZE = 100  # Number of texts converted to pinyin per call
MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1)  # Leave one CPU free, but use at least one worker
THREAD_NAME_PREFIX = "AudioProcessor"  # Name prefix of the worker threads
CHUNK_SIZE = 100  # Number of files submitted to the thread pool per chunk
executor = None  # Currently active thread pool, kept global so signal handlers can shut it down
def is_csv_wavs_format(input_path):
    """Return True if ``input_path`` is an existing file with a ``.csv`` suffix (case-insensitive).

    A leading ``~`` in the path is expanded first.
    """
    candidate = Path(input_path).expanduser()
    if not candidate.is_file():
        return False
    return candidate.suffix.lower() == ".csv"
@contextmanager
def graceful_exit():
    """Context manager that shuts the global executor down on SIGINT/SIGTERM.

    While installed, either signal cancels pending futures, shuts the global
    ``executor`` down without waiting and exits the process with status 1.
    On leaving the context, a still-set ``executor`` is shut down without
    waiting. The signal handlers stay installed afterwards.
    """

    def _terminate(signum, frame):
        print("\nReceived signal to terminate. Cleaning up...")
        active_pool = executor
        if active_pool is not None:
            print("Shutting down executor...")
            active_pool.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)

    # Set up signal handlers
    for handled_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(handled_signal, _terminate)

    try:
        yield
    finally:
        remaining_pool = executor
        if remaining_pool is not None:
            remaining_pool.shutdown(wait=False)
def process_audio_file(audio_path, text, polyphone):
    """Validate one audio file and measure its duration.

    Args:
        audio_path: Path of the audio file.
        text: Transcript, returned unchanged.
        polyphone: Accepted for interface compatibility; not used here.

    Returns:
        ``(audio_path, text, duration)`` on success, or ``None`` when the file
        is missing, its duration cannot be determined, or it is not positive.
    """
    if not Path(audio_path).exists():
        print(f"audio {audio_path} not found, skipping")
        return None

    try:
        seconds = get_audio_duration(audio_path)
        if seconds <= 0:
            raise ValueError(f"Duration {seconds} is non-positive.")
    except Exception as err:
        print(f"Warning: Failed to process {audio_path} due to error: {err}. Skipping corrupt file.")
        return None

    return (audio_path, text, seconds)
def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
    """Convert ``texts`` with ``convert_char_to_pinyin``, ``batch_size`` items at a time.

    Args:
        texts: Sequence of raw text strings.
        polyphone: Forwarded to ``convert_char_to_pinyin``.
        batch_size: Number of texts handed to the converter per call.

    Returns:
        List with the converted texts of all batches, in input order.
    """
    n_batches = (len(texts) + batch_size - 1) // batch_size
    batch_starts = range(0, len(texts), batch_size)

    converted = []
    for start in tqdm(batch_starts, total=n_batches, desc="Converting texts to pinyin"):
        converted += convert_char_to_pinyin(texts[start : start + batch_size], polyphone=polyphone)
    return converted
def prepare_csv_wavs_dir(input_path, num_workers=None):
    """Read the metadata CSV, measure every audio file and convert the texts.

    Args:
        input_path: Path of the ``audio_file|text`` CSV.
        num_workers: Number of worker threads; defaults to
            ``min(MAX_WORKERS, number of rows)``.

    Returns:
        Tuple ``(sub_result, durations, vocab_set)``: a list of
        ``{"audio_path", "text", "duration"}`` dicts, the list of durations,
        and the set of elements found in the converted texts.

    Raises:
        ValueError: If ``input_path`` is not a ``.csv`` file.
        RuntimeError: If the CSV has no valid rows or no audio file could be processed.
    """
    global executor

    if not is_csv_wavs_format(input_path):
        raise ValueError(f"input must be a .csv file: {input_path}")

    pairs = read_audio_text_pairs(Path(input_path).expanduser().as_posix())
    polyphone = True
    total_files = len(pairs)
    if total_files == 0:
        raise RuntimeError("No valid rows found in CSV.")

    # Use provided worker count or calculate optimal number
    if num_workers is None:
        worker_count = min(MAX_WORKERS, total_files)
    else:
        worker_count = num_workers
    print(f"\nProcessing {total_files} audio files using {worker_count} workers...")

    n_chunks = (total_files + CHUNK_SIZE - 1) // CHUNK_SIZE

    with graceful_exit():
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
        ) as pool:
            # Publish the pool globally so the signal handlers can shut it down.
            executor = pool
            results = []

            for chunk_no, start in enumerate(range(0, total_files, CHUNK_SIZE), start=1):
                chunk = pairs[start : start + CHUNK_SIZE]
                chunk_futures = [
                    pool.submit(process_audio_file, audio_path, text, polyphone) for audio_path, text in chunk
                ]

                # Walk the futures in submission order so results keep the CSV order.
                progress = tqdm(
                    chunk_futures,
                    total=len(chunk),
                    desc=f"Processing chunk {chunk_no}/{n_chunks}",
                )
                for future in progress:
                    try:
                        outcome = future.result()
                    except Exception as err:
                        print(f"Error processing file: {err}")
                        continue
                    if outcome is not None:
                        results.append(outcome)

        executor = None

    # Filter out failed results
    processed = [entry for entry in results if entry is not None]
    if not processed:
        raise RuntimeError("No valid audio files were processed!")

    # Batch process text conversion
    converted_texts = batch_convert_texts([entry[1] for entry in processed], polyphone, batch_size=BATCH_SIZE)

    # Prepare final results
    sub_result = [
        {"audio_path": audio_path, "text": conv_text, "duration": duration}
        for (audio_path, _, duration), conv_text in zip(processed, converted_texts)
    ]
    durations = [row["duration"] for row in sub_result]
    vocab_set = set()
    for row in sub_result:
        vocab_set.update(row["text"])

    return sub_result, durations, vocab_set
def get_audio_duration(audio_path, timeout=5):
    """Get the duration of an audio file in seconds with fallbacks.

    Tries, in order: ``soundfile``, the ``ffprobe`` command-line tool, and
    ``torchaudio.info``. A warning is printed each time a method fails.

    Args:
        audio_path: Path of the audio file.
        timeout: Seconds to wait for ``ffprobe`` before giving up on it.

    Returns:
        Duration in seconds.

    Raises:
        RuntimeError: If all three methods fail.
    """
    try:
        return sf.info(audio_path).duration
    except Exception as e:
        print(f"Warning: soundfile failed for {audio_path} with error: {e}. Falling back to ffprobe.")

    try:
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            audio_path,
        ]
        result = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
        )
        duration_str = result.stdout.strip()
        if duration_str:
            return float(duration_str)
        raise ValueError("Empty duration string from ffprobe.")
    # OSError covers a missing or non-executable ffprobe binary
    # (FileNotFoundError / PermissionError), which is not a SubprocessError and
    # previously skipped the torchaudio fallback. SubprocessError already
    # includes TimeoutExpired and CalledProcessError.
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.info.")

    try:
        info = torchaudio.info(audio_path)
        if info.sample_rate > 0:
            return info.num_frames / info.sample_rate
        raise ValueError("Invalid sample_rate from torchaudio.info.")
    except Exception as e:
        raise RuntimeError(f"failed to get duration for {audio_path}: {e}") from e
def read_audio_text_pairs(csv_file_path):
    """Read ``(audio_path, text)`` pairs from a ``|``-delimited CSV file.

    The first row must be the header ``audio_file|text``. Rows with fewer than
    two fields or an empty audio path are skipped. The file is decoded as
    UTF-8 and a leading byte-order mark is ignored.

    Args:
        csv_file_path: Path of the CSV file.

    Returns:
        List of ``(audio_path_as_posix, text)`` tuples; empty for an empty file.

    Raises:
        ValueError: If the header is wrong or an audio path is not absolute.
    """
    pairs = []
    resolved_csv = Path(csv_file_path).expanduser().absolute()

    with open(resolved_csv.as_posix(), mode="r", newline="", encoding="utf-8-sig") as handle:
        rows = csv.reader(handle, delimiter="|")

        header = next(rows, None)
        if header is None:
            return pairs

        header_ok = len(header) >= 2 and header[0].strip() == "audio_file" and header[1].strip() == "text"
        if not header_ok:
            raise ValueError("CSV header must be: audio_file|text")

        # Data rows start on line 2 of the file; used only for error messages.
        for line_no, fields in enumerate(rows, start=2):
            if len(fields) < 2:
                continue
            raw_audio, raw_text = fields[0].strip(), fields[1].strip()
            if not raw_audio:
                continue

            audio_path = Path(raw_audio).expanduser()
            if not audio_path.is_absolute():
                raise ValueError(f"audio_file must be an absolute path (row {line_no}): {raw_audio}")

            pairs.append((audio_path.as_posix(), raw_text))

    return pairs
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
    """Write the prepared dataset files (raw.arrow, duration.json, vocab.txt) into ``out_dir``.

    Args:
        out_dir: Output directory; created (with parents) when missing.
        result: Sized collection of records, each passed to ``ArrowWriter.write``.
        duration_list: Durations stored under the ``"duration"`` key of ``duration.json``;
            their sum divided by 3600 is printed as hours.
        text_vocab_set: Vocabulary symbols; written sorted, one per line, when not fine-tuning.
        is_finetune: When true, ``PRETRAINED_VOCAB_PATH`` is copied to ``vocab.txt`` instead of
            writing ``text_vocab_set``.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        writer.finalize()

    # Save durations to JSON
    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # Handle vocab file - write only once based on finetune flag
    voca_out_path = out_dir / "vocab.txt"
    if is_finetune:
        file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
        shutil.copy2(file_vocab_finetune, voca_out_path)
    else:
        # Explicit utf-8 (same as duration.json): the vocabulary may contain non-ASCII symbols
        # and must not depend on the platform's default encoding.
        with open(voca_out_path.as_posix(), "w", encoding="utf-8") as f:
            for vocab in sorted(text_vocab_set):
                f.write(vocab + "\n")

    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
    """Prepare the dataset described by ``inp_dir`` and save it under ``out_dir``.

    In fine-tune mode the pretrained vocab file must already exist, because
    ``save_prepped_dataset`` is asked to reuse it.
    """
    if is_finetune:
        assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"

    prepared = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
    records, clip_durations, vocabulary = prepared
    save_prepped_dataset(out_dir, records, clip_durations, vocabulary, is_finetune)
def get_args():
    """Parse the command line: two positional paths plus ``--pretrain`` and ``--workers``."""
    parser = argparse.ArgumentParser(description="Prepare and save dataset.")

    # Positional arguments, registered in the order they must appear on the command line.
    positional = (
        ("inp_dir", "Input CSV with header 'audio_file|text' and absolute wav paths."),
        ("out_dir", "Output directory to save the prepared data."),
    )
    for dest, description in positional:
        parser.add_argument(dest, type=str, help=description)

    parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
    parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
    return parser.parse_args()
def cli():
    """Command-line entry point: parse arguments, run the preparation, handle Ctrl-C.

    On ``KeyboardInterrupt`` the module-level ``executor`` (when set) is shut down without
    waiting and with pending futures cancelled, then the process exits with status 1.
    """
    try:
        cli_args = get_args()
        prepare_and_save_set(
            cli_args.inp_dir,
            cli_args.out_dir,
            is_finetune=not cli_args.pretrain,
            num_workers=cli_args.workers,
        )
    except KeyboardInterrupt:
        print("\nOperation cancelled by user. Cleaning up...")
        pool = executor
        if pool is not None:
            pool.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)


if __name__ == "__main__":
    cli()
================================================
FILE: src/f5_tts/train/datasets/prepare_emilia.py
================================================
# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
# if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
# generate audio text map for Emilia ZH & EN
# evaluate for vocab size
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
# Utterance ids excluded for "zh" records; deal_with_audio_dir compares them with the second
# "/"-separated component of each record's "wav" path.
out_zh = {
    "ZH_B00041_S06226",
    "ZH_B00042_S09204",
    "ZH_B00065_S09430",
    "ZH_B00065_S09431",
    "ZH_B00066_S09327",
    "ZH_B00066_S09328",
}
# A "zh" transcript containing any of these (hiragana) characters is counted as a bad case and dropped.
zh_filters = ["い", "て"]
# Utterance ids excluded for "en" records: they seem to be synthesized audio or heavily code-switched.
out_en = {
    "EN_B00013_S00913",
    "EN_B00042_S00120",
    "EN_B00055_S04111",
    "EN_B00061_S00693",
    "EN_B00061_S01494",
    "EN_B00061_S03375",
    "EN_B00059_S00092",
    "EN_B00111_S04300",
    "EN_B00100_S03759",
    "EN_B00087_S03811",
    "EN_B00059_S00950",
    "EN_B00089_S00946",
    "EN_B00078_S05127",
    "EN_B00070_S04089",
    "EN_B00074_S09659",
    "EN_B00061_S06983",
    "EN_B00061_S07060",
    "EN_B00059_S08397",
    "EN_B00082_S06192",
    "EN_B00091_S01238",
    "EN_B00089_S07349",
    "EN_B00070_S04343",
    "EN_B00061_S02400",
    "EN_B00076_S01262",
    "EN_B00068_S06467",
    "EN_B00076_S02943",
    "EN_B00064_S05954",
    "EN_B00061_S05386",
    "EN_B00066_S06544",
    "EN_B00076_S06944",
    "EN_B00072_S08620",
    "EN_B00076_S07135",
    "EN_B00076_S09127",
    "EN_B00065_S00497",
    "EN_B00059_S06227",
    "EN_B00063_S02859",
    "EN_B00075_S01547",
    "EN_B00061_S08286",
    "EN_B00079_S02901",
    "EN_B00092_S03643",
    "EN_B00096_S08653",
    "EN_B00063_S04297",
    "EN_B00063_S04614",
    "EN_B00079_S04698",
    "EN_B00104_S01666",
    "EN_B00061_S09504",
    "EN_B00061_S09694",
    "EN_B00065_S05444",
    "EN_B00063_S06860",
    "EN_B00065_S05725",
    "EN_B00069_S07628",
    "EN_B00083_S03875",
    "EN_B00071_S07665",
    "EN_B00062_S04187",
    "EN_B00065_S09873",
    "EN_B00065_S09922",
    "EN_B00084_S02463",
    "EN_B00067_S05066",
    "EN_B00106_S08060",
    "EN_B00073_S06399",
    "EN_B00073_S09236",
    "EN_B00087_S00432",
    "EN_B00085_S05618",
    "EN_B00064_S01262",
    "EN_B00072_S01739",
    "EN_B00059_S03913",
    "EN_B00069_S04036",
    "EN_B00067_S05623",
    "EN_B00060_S05389",
    "EN_B00060_S07290",
    "EN_B00062_S08995",
}
# An "en" transcript containing any of these (Arabic / hiragana) characters is counted as a bad case.
en_filters = ["ا", "い", "て"]
def deal_with_audio_dir(audio_dir):
    """Filter and convert the records of one Emilia shard listed in ``<audio_dir>.jsonl``.

    Each jsonl line is a JSON object with ``text``, ``language``, ``wav`` and ``duration``.
    Records whose utterance id is block-listed, whose text contains a filtered character, or
    for which ``repetition_found`` reports a repetition are counted as bad cases and skipped.
    Reads the module globals ``tokenizer`` and ``polyphone``.

    Args:
        audio_dir: ``Path`` of the shard directory; the jsonl sits next to it with the same stem.

    Returns:
        ``(sub_result, durations, vocab_set, bad_case_zh, bad_case_en)`` where ``sub_result`` is a
        list of ``{"audio_path", "text", "duration"}`` dicts.
    """
    audio_jsonl = audio_dir.with_suffix(".jsonl")
    sub_result, durations = [], []
    vocab_set = set()
    bad_case_zh = 0
    bad_case_en = 0

    # ASCII -> full-width punctuation for zh text (previously every mark mapped to itself, a
    # no-op). Written as escapes so the characters survive re-encoding; not "。" because much
    # of the text is code-switched. Built once instead of on every line.
    zh_punct_table = str.maketrans({",": "\uff0c", "!": "\uff01", "?": "\uff1f"})

    # The transcripts are not ASCII, so do not rely on the platform default encoding.
    with open(audio_jsonl, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
        obj = json.loads(line)
        text = obj["text"]
        if obj["language"] == "zh":
            if (
                obj["wav"].split("/")[1] in out_zh
                or any(ch in text for ch in zh_filters)
                or repetition_found(text)
            ):
                bad_case_zh += 1
                continue
            text = text.translate(zh_punct_table)
        elif obj["language"] == "en":
            if (
                obj["wav"].split("/")[1] in out_en
                or any(ch in text for ch in en_filters)
                or repetition_found(text, length=4)
            ):
                bad_case_en += 1
                continue
        if tokenizer == "pinyin":
            text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
        duration = obj["duration"]
        sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
        durations.append(duration)
        vocab_set.update(list(text))
    return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
def _init_emilia_worker(worker_tokenizer, worker_polyphone):
    """Publish the run configuration as module globals inside a worker process.

    ``deal_with_audio_dir`` reads ``tokenizer`` and ``polyphone`` as globals, but they are only
    assigned under the ``__main__`` guard. Worker processes started with the spawn or forkserver
    method do not execute that guard, so the values are handed over explicitly.
    """
    global tokenizer, polyphone
    tokenizer = worker_tokenizer
    polyphone = worker_polyphone


def main():
    """Process every shard of every configured language in parallel and save the dataset.

    Reads the module globals set by the script entry point (``tokenizer``, ``polyphone``,
    ``max_workers``, ``langs``, ``dataset_dir``, ``save_dir``, ``dataset_name``) and writes
    ``raw.arrow``, ``duration.json`` and ``vocab.txt`` into ``save_dir``.
    """
    assert tokenizer in ["pinyin", "char"]
    result = []
    duration_list = []
    text_vocab_set = set()
    total_bad_case_zh = 0
    total_bad_case_en = 0

    # process raw data; the context manager shuts the pool down even when a worker raises
    with ProcessPoolExecutor(
        max_workers=max_workers, initializer=_init_emilia_worker, initargs=(tokenizer, polyphone)
    ) as executor:
        futures = [
            executor.submit(deal_with_audio_dir, audio_dir)
            for lang in langs
            for audio_dir in Path(os.path.join(dataset_dir, lang)).iterdir()
            if audio_dir.is_dir()
        ]
        for future in tqdm(futures, total=len(futures)):
            sub_result, durations, vocab_set, bad_case_zh, bad_case_en = future.result()
            result.extend(sub_result)
            duration_list.extend(durations)
            text_vocab_set.update(vocab_set)
            total_bad_case_zh += bad_case_zh
            total_bad_case_en += bad_case_en

    # save preprocessed dataset to disk
    os.makedirs(save_dir, exist_ok=True)
    print(f"\nSaving to {save_dir} ...")

    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
    # dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")
    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        writer.finalize()

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    # if tokenizer == "pinyin":
    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
    # utf-8 explicitly: the vocabulary holds non-ASCII symbols and must not depend on the locale
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
    if "ZH" in langs:
        print(f"Bad zh transcription case: {total_bad_case_zh}")
    if "EN" in langs:
        print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
    # Script configuration: main() and deal_with_audio_dir() read these names as module globals.
    max_workers = 32  # size of the ProcessPoolExecutor created in main()

    tokenizer = "pinyin"  # "pinyin" | "char"
    polyphone = True  # forwarded to convert_char_to_pinyin(polyphone=...)

    langs = ["ZH", "EN"]  # sub-directories of dataset_dir that main() iterates over
    dataset_dir = "/Emilia_Dataset/raw"
    dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
    # Output directory: data/<dataset_name>, two levels above the f5_tts package directory.
    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
    print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")

    main()
# Emilia ZH & EN
# samples count 37837916 (after removal)
# pinyin vocab size 2543 (polyphone)
# total duration 95281.87 (hours)
# bad zh asr cnt 230435 (samples)
# bad en asr cnt 37217 (samples)
# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same
================================================
FILE: src/f5_tts/train/datasets/prepare_emilia_v2.py
================================================
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
# prepares Emilia dataset with the new format w/ Emilia-YODAS
import json
import os
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import repetition_found
# Define filters for exclusion
out_en = set()
en_filters = ["ا", "い", "て"]
def process_audio_directory(audio_dir):
    """Collect the usable utterances described by the ``.json`` files of one directory.

    A JSON file is skipped (and counted as a bad case) when its text contains a character from
    ``en_filters`` or ``repetition_found(text, length=4)`` is true. A record is only kept when
    the sibling ``.mp3`` file exists.

    Args:
        audio_dir: ``Path`` of the directory to scan (non-recursive).

    Returns:
        ``(sub_result, durations, vocab_set, bad_case_en)`` where ``sub_result`` is a list of
        ``{"audio_path", "text", "duration"}`` dicts.
    """
    sub_result, durations, vocab_set = [], [], set()
    bad_case_en = 0
    for file in audio_dir.iterdir():
        if file.suffix != ".json":
            continue
        # Explicit utf-8 so parsing does not depend on the platform default encoding; the
        # handle is closed as soon as the JSON is parsed.
        with open(file, "r", encoding="utf-8") as f:
            obj = json.load(f)
        text = obj["text"]
        if any(ch in text for ch in en_filters) or repetition_found(text, length=4):
            bad_case_en += 1
            continue

        duration = obj["duration"]
        audio_file = file.with_suffix(".mp3")
        if audio_file.exists():
            sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration})
            durations.append(duration)
            vocab_set.update(list(text))

    return sub_result, durations, vocab_set, bad_case_en
def main():
    """Process every sub-directory of ``dataset_dir`` in parallel and save the dataset.

    Reads the module globals set by the script entry point (``tokenizer``, ``max_workers``,
    ``dataset_dir``, ``save_dir``, ``dataset_name``) and writes ``raw.arrow``,
    ``duration.json`` and ``vocab.txt`` into ``save_dir``.
    """
    assert tokenizer in ["pinyin", "char"]
    result, duration_list, text_vocab_set = [], [], set()
    total_bad_case_en = 0

    # The context manager shuts the pool down even when a worker raises.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_audio_directory, sub_dir)
            for sub_dir in Path(dataset_dir).iterdir()
            if sub_dir.is_dir()
        ]
        for future in tqdm(futures, total=len(futures)):
            sub_result, durations, vocab_set, bad_case_en = future.result()
            result.extend(sub_result)
            duration_list.extend(durations)
            text_vocab_set.update(vocab_set)
            total_bad_case_en += bad_case_en

    os.makedirs(save_dir, exist_ok=True)

    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        writer.finalize()

    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # utf-8 explicitly (same as duration.json): the vocabulary must not depend on the locale
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")

    print(f"For {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
    print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
    # Script configuration: main() reads these names as module globals.
    max_workers = 32  # size of the ProcessPoolExecutor created in main()
    tokenizer = "char"  # main() asserts this is "pinyin" or "char"
    dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN"
    dataset_name = f"Emilia_EN_{tokenizer}"
    # save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}")
    # Output directory: data/<dataset_name>, two levels above the f5_tts package directory.
    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"

    print(f"Prepare for {dataset_name}, will save to {save_dir}\n")
    main()
================================================
FILE: src/f5_tts/train/datasets/prepare_libritts.py
================================================
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def deal_with_audio_dir(audio_dir):
    """Collect the usable utterances found (recursively) under one LibriTTS directory.

    For every ``*.wav`` the transcript is read from the sibling ``*.normalized.txt`` file.
    Clips shorter than 0.4 s or longer than 30 s are skipped.

    Args:
        audio_dir: ``Path`` of the directory to scan.

    Returns:
        ``(sub_result, durations, vocab_set)`` where ``sub_result`` is a list of
        ``{"audio_path", "text", "duration"}`` dicts.
    """
    sub_result, durations = [], []
    vocab_set = set()
    for wav_path in audio_dir.rglob("*.wav"):
        text_path = wav_path.with_suffix(".normalized.txt")
        # Context manager closes the handle (the bare open().read() leaked one per clip);
        # utf-8 keeps decoding independent of the platform default encoding.
        with open(text_path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        duration = sf.info(wav_path).duration
        if duration < 0.4 or duration > 30:
            continue
        sub_result.append({"audio_path": str(wav_path), "text": text, "duration": duration})
        durations.append(duration)
        vocab_set.update(list(text))
    return sub_result, durations, vocab_set
def main():
    """Process every directory of every configured LibriTTS subset in parallel and save the dataset.

    Reads the module globals set by the script entry point (``max_workers``, ``SUB_SET``,
    ``dataset_dir``, ``save_dir``, ``dataset_name``) and writes ``raw.arrow``,
    ``duration.json`` and ``vocab.txt`` into ``save_dir``.
    """
    result = []
    duration_list = []
    text_vocab_set = set()

    # process raw data; the context manager shuts the pool down even when a worker raises
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for subset in tqdm(SUB_SET):
            dataset_path = Path(os.path.join(dataset_dir, subset))
            for audio_dir in dataset_path.iterdir():
                if audio_dir.is_dir():
                    futures.append(executor.submit(deal_with_audio_dir, audio_dir))
        for future in tqdm(futures, total=len(futures)):
            sub_result, durations, vocab_set = future.result()
            result.extend(sub_result)
            duration_list.extend(durations)
            text_vocab_set.update(vocab_set)

    # save preprocessed dataset to disk
    os.makedirs(save_dir, exist_ok=True)
    print(f"\nSaving to {save_dir} ...")

    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        writer.finalize()

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # utf-8 explicitly (same as duration.json): the vocabulary must not depend on the locale
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":
    # Script configuration: main() reads these names as module globals.
    max_workers = 36  # size of the ProcessPoolExecutor created in main()

    tokenizer = "char"  # "pinyin" | "char"

    SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"]  # sub-directories of dataset_dir
    dataset_dir = "/LibriTTS"
    # The "train-clean-" / "train-other-" prefixes are stripped, leaving e.g. LibriTTS_100_360_500_char.
    dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "")
    # Output directory: data/<dataset_name>, two levels above the f5_tts package directory.
    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
    print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
    main()
# For LibriTTS_100_360_500_char, sample count: 354218
# For LibriTTS_100_360_500_char, vocab size is: 78
# For LibriTTS_100_360_500_char, total 554.09 hours
================================================
FILE: src/f5_tts/train/datasets/prepare_ljspeech.py
================================================
import os
import sys
sys.path.append(os.getcwd())
import json
from importlib.resources import files
from pathlib import Path
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def main():
    """Build the LJSpeech dataset files from ``metadata.csv``.

    Each metadata line is ``<utterance id>|<text>|<normalized text>``; the normalized text is
    used as the transcript. Clips shorter than 0.4 s or longer than 30 s are skipped. Reads
    the module globals set by the script entry point (``meta_info``, ``dataset_dir``,
    ``save_dir``, ``dataset_name``) and writes ``raw.arrow``, ``duration.json`` and
    ``vocab.txt`` into ``save_dir``.
    """
    result = []
    duration_list = []
    text_vocab_set = set()

    # utf-8 explicitly so decoding does not depend on the platform default encoding
    with open(meta_info, "r", encoding="utf-8") as f:
        lines = f.readlines()
    for line in tqdm(lines):
        uttr, _raw_text, norm_text = line.split("|")
        norm_text = norm_text.strip()
        wav_path = Path(dataset_dir) / "wavs" / f"{uttr}.wav"
        duration = sf.info(wav_path).duration
        if duration < 0.4 or duration > 30:
            continue
        result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration})
        duration_list.append(duration)
        text_vocab_set.update(list(norm_text))

    # save preprocessed dataset to disk
    os.makedirs(save_dir, exist_ok=True)
    print(f"\nSaving to {save_dir} ...")

    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        writer.finalize()

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    # utf-8 explicitly (same as duration.json): the vocabulary must not depend on the locale
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":
    # Script configuration: main() reads these names as module globals.
    tokenizer = "char"  # "pinyin" | "char"

    dataset_dir = "/LJSpeech-1.1"
    dataset_name = f"LJSpeech_{tokenizer}"
    meta_info = os.path.join(dataset_dir, "metadata.csv")  # "|"-separated metadata read by main()
    # Output directory: data/<dataset_name>, two levels above the f5_tts package directory.
    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
    print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")

    main()
================================================
FILE: src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
================================================
# generate audio text map for WenetSpeech4TTS
# evaluate for vocab size
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
import torchaudio
from datasets import Dataset
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin
def deal_with_sub_path_files(dataset_path, sub_path):
    """Gather audio paths, transcripts and durations for one ``<dataset_path>/<sub_path>`` folder.

    Every file in ``txts`` contributes its first line, split on tabs: field 0 names the wav
    (without extension) inside ``wavs`` and field 1 is the transcript. Reads the module globals
    ``tokenizer`` and ``polyphone``.

    Returns:
        ``(audio_paths, texts, durations)`` as three lists.
    """
    print(f"Dealing with: {sub_path}")

    subset_root = os.path.join(dataset_path, sub_path)
    text_dir = os.path.join(subset_root, "txts")
    audio_dir = os.path.join(subset_root, "wavs")

    audio_paths, texts, durations = [], [], []
    for name in tqdm(os.listdir(text_dir)):
        with open(os.path.join(text_dir, name), "r", encoding="utf-8") as handle:
            fields = handle.readline().split("\t")

        wav_path = os.path.join(audio_dir, fields[0] + ".wav")
        transcript = fields[1].strip()
        audio_paths.append(wav_path)

        if tokenizer == "pinyin":
            texts.extend(convert_char_to_pinyin([transcript], polyphone=polyphone))
        elif tokenizer == "char":
            texts.append(transcript)

        # Duration = number of samples on the last axis divided by the sample rate.
        waveform, sample_rate = torchaudio.load(wav_path)
        durations.append(waveform.shape[-1] / sample_rate)
    return audio_paths, texts, durations
def _init_wenetspeech_worker(worker_tokenizer, worker_polyphone):
    """Publish the run configuration as module globals inside a worker process.

    ``deal_with_sub_path_files`` reads ``tokenizer`` and ``polyphone`` as globals, but they are
    only assigned under the ``__main__`` guard. Worker processes started with the spawn or
    forkserver method do not execute that guard, so the values are handed over explicitly.
    """
    global tokenizer, polyphone
    tokenizer = worker_tokenizer
    polyphone = worker_polyphone


def main():
    """Process every sub-folder of every configured dataset path in parallel and save the dataset.

    Reads the module globals set by the script entry point (``tokenizer``, ``polyphone``,
    ``max_workers``, ``dataset_paths``, ``save_dir``, ``dataset_name``) and writes ``raw``
    (arrow shards), ``duration.json`` and ``vocab.txt`` into ``save_dir``.
    """
    assert tokenizer in ["pinyin", "char"]

    audio_path_list, text_list, duration_list = [], [], []

    # The context manager shuts the pool down even when a worker raises.
    with ProcessPoolExecutor(
        max_workers=max_workers, initializer=_init_wenetspeech_worker, initargs=(tokenizer, polyphone)
    ) as executor:
        futures = []
        for dataset_path in dataset_paths:
            sub_items = os.listdir(dataset_path)
            sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
            for sub_path in sub_paths:
                futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
        for future in tqdm(futures, total=len(futures)):
            audio_paths, texts, durations = future.result()
            audio_path_list.extend(audio_paths)
            text_list.extend(texts)
            duration_list.extend(durations)

    # Create the real output directory (previously a "data" folder relative to the current
    # working directory was created instead, which is unrelated to save_dir).
    os.makedirs(save_dir, exist_ok=True)

    print(f"\nSaving to {save_dir} ...")
    dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
    dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")  # arrow format

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
    text_vocab_set = set()
    for text in tqdm(text_list):
        text_vocab_set.update(list(text))

    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    if tokenizer == "pinyin":
        text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])

    # utf-8 explicitly (same as duration.json): the vocabulary must not depend on the locale
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")
    print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
if __name__ == "__main__":
    # Script configuration: main() and deal_with_sub_path_files() read these names as module globals.
    max_workers = 32  # size of the ProcessPoolExecutor created in main()

    tokenizer = "pinyin"  # "pinyin" | "char"
    polyphone = True  # forwarded to convert_char_to_pinyin(polyphone=...)
    dataset_choice = 1  # 1: Premium, 2: Standard, 3: Basic

    dataset_name = (
        ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
        + "_"
        + tokenizer
    )
    # The slice keeps the last `dataset_choice` entries: 1 -> Premium only,
    # 2 -> Standard + Premium, 3 -> Basic + Standard + Premium.
    dataset_paths = [
        "/WenetSpeech4TTS/Basic",
        "/WenetSpeech4TTS/Standard",
        "/WenetSpeech4TTS/Premium",
    ][-dataset_choice:]
    # Output directory: data/<dataset_name>, two levels above the f5_tts package directory.
    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
    print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n")

    main()
# Results (if adding alphabets with accents and symbols):
# WenetSpeech4TTS Basic Standard Premium
# samples count 3932473 1941220 407494
# pinyin vocab size 1349 1348 1344 (no polyphone)
# - - 1459 (polyphone)
# char vocab size 5264 5219 5042
# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same
================================================
FILE: src/f5_tts/train/finetune_cli.py
================================================
import argparse
import os
import shutil
from importlib.resources import files
from cached_path import cached_path
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
# -------------------------- Dataset Settings --------------------------- #
# Mel-spectrogram settings; main() bundles them into mel_spec_kwargs, which is passed to
# both CFM and load_dataset (n_mel_channels is also used as the backbone's mel_dim).
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
# -------------------------- Argument Parsing --------------------------- #
def parse_args():
    """Parse the fine-tuning command line and return the resulting ``argparse.Namespace``."""
    parser = argparse.ArgumentParser(description="Train CFM Model")

    # (flag, add_argument keyword options), registered in this order.
    option_specs = [
        (
            "--exp_name",
            dict(
                type=str,
                default="F5TTS_v1_Base",
                choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
                help="Experiment name",
            ),
        ),
        ("--dataset_name", dict(type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")),
        ("--learning_rate", dict(type=float, default=1e-5, help="Learning rate for training")),
        ("--batch_size_per_gpu", dict(type=int, default=3200, help="Batch size per GPU")),
        (
            "--batch_size_type",
            dict(type=str, default="frame", choices=["frame", "sample"], help="Batch size type"),
        ),
        ("--max_samples", dict(type=int, default=64, help="Max sequences per batch")),
        ("--grad_accumulation_steps", dict(type=int, default=1, help="Gradient accumulation steps")),
        ("--max_grad_norm", dict(type=float, default=1.0, help="Max gradient norm for clipping")),
        ("--epochs", dict(type=int, default=100, help="Number of training epochs")),
        ("--num_warmup_updates", dict(type=int, default=20000, help="Warmup updates")),
        ("--save_per_updates", dict(type=int, default=50000, help="Save checkpoint every N updates")),
        (
            "--keep_last_n_checkpoints",
            dict(
                type=int,
                default=-1,
                help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
            ),
        ),
        ("--last_per_updates", dict(type=int, default=5000, help="Save last checkpoint every N updates")),
        ("--finetune", dict(action="store_true", help="Use Finetune")),
        ("--pretrain", dict(type=str, default=None, help="the path to the checkpoint")),
        (
            "--tokenizer",
            dict(type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"),
        ),
        (
            "--tokenizer_path",
            dict(
                type=str,
                default=None,
                help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
            ),
        ),
        ("--log_samples", dict(action="store_true", help="Log inferenced samples per ckpt save updates")),
        ("--logger", dict(type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")),
        ("--bnb_optimizer", dict(action="store_true", help="Use 8-bit Adam optimizer from bitsandbytes")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
# -------------------------- Training Settings -------------------------- #
def main():
    """Build the model, tokenizer and trainer from the command line and start training.

    In fine-tune mode the starting checkpoint (``--pretrain`` or the experiment's default
    download) is copied into the project checkpoint directory with a ``pretrained_`` prefix,
    unless a file of that name is already there.
    """
    args = parse_args()

    checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))

    # Model parameters based on experiment name:
    # exp_name -> (backbone class, backbone kwargs, default pretrained checkpoint)
    presets = {
        "F5TTS_v1_Base": (
            DiT,
            dict(
                dim=1024,
                depth=22,
                heads=16,
                ff_mult=2,
                text_dim=512,
                conv_layers=4,
            ),
            "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
        ),
        "F5TTS_Base": (
            DiT,
            dict(
                dim=1024,
                depth=22,
                heads=16,
                ff_mult=2,
                text_dim=512,
                text_mask_padding=False,
                conv_layers=4,
                pe_attn_head=1,
            ),
            "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt",
        ),
        "E2TTS_Base": (
            UNetT,
            dict(
                dim=1024,
                depth=24,
                heads=16,
                ff_mult=4,
                text_mask_padding=False,
                pe_attn_head=1,
            ),
            "hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt",
        ),
    }
    wandb_resume_id = None
    model_cls, model_cfg, default_ckpt = presets[args.exp_name]

    if args.finetune:
        # The default checkpoint is only resolved (downloaded) when no --pretrain path is given.
        ckpt_path = args.pretrain if args.pretrain is not None else str(cached_path(default_ckpt))

        if not os.path.isdir(checkpoint_path):
            os.makedirs(checkpoint_path, exist_ok=True)

        # Add 'pretrained_' prefix to the copied model's file name.
        ckpt_name = os.path.basename(ckpt_path)
        if not ckpt_name.startswith("pretrained_"):
            ckpt_name = "pretrained_" + ckpt_name
        file_checkpoint = os.path.join(checkpoint_path, ckpt_name)
        if not os.path.isfile(file_checkpoint):
            shutil.copy2(ckpt_path, file_checkpoint)
            print("copy checkpoint for finetune")

    # Use the tokenizer and tokenizer_path provided in the command line arguments
    tokenizer = args.tokenizer
    if tokenizer != "custom":
        tokenizer_path = args.dataset_name
    elif args.tokenizer_path:
        tokenizer_path = args.tokenizer_path
    else:
        raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")

    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    print("\nvocab : ", vocab_size)
    print("\nvocoder : ", mel_spec_type)

    mel_spec_kwargs = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    backbone = model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
    model = CFM(
        transformer=backbone,
        mel_spec_kwargs=mel_spec_kwargs,
        vocab_char_map=vocab_char_map,
    )

    trainer = Trainer(
        model,
        args.epochs,
        args.learning_rate,
        num_warmup_updates=args.num_warmup_updates,
        save_per_updates=args.save_per_updates,
        keep_last_n_checkpoints=args.keep_last_n_checkpoints,
        checkpoint_path=checkpoint_path,
        batch_size_per_gpu=args.batch_size_per_gpu,
        batch_size_type=args.batch_size_type,
        max_samples=args.max_samples,
        grad_accumulation_steps=args.grad_accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        logger=args.logger,
        wandb_project=args.dataset_name,
        wandb_run_name=args.exp_name,
        wandb_resume_id=wandb_resume_id,
        log_samples=args.log_samples,
        last_per_updates=args.last_per_updates,
        bnb_optimizer=args.bnb_optimizer,
    )

    train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)

    trainer.train(
        train_dataset,
        resumable_with_seed=666,  # seed for shuffling dataset
    )


if __name__ == "__main__":
    main()
================================================
FILE: src/f5_tts/train/finetune_gradio.py
================================================
import gc
import json
import os
import platform
import queue
import random
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
from glob import glob
from importlib.resources import files
import click
import gradio as gr
import librosa
import numpy as np
import psutil
import torch
import torchaudio
from cached_path import cached_path
from datasets import Dataset as Dataset_
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import load_file, save_file
from scipy.io import wavfile
from f5_tts.api import F5TTS
from f5_tts.infer.utils_infer import transcribe
from f5_tts.model.utils import convert_char_to_pinyin
# Module-level state; everything starts out unset at import time.
training_process = None
system = platform.system()
python_executable = sys.executable or "python"  # falls back to "python" when sys.executable is empty

tts_api = None
last_checkpoint = ""
last_device = ""
last_ema = None

# Locations resolved relative to the f5_tts package directory.
path_data = str(files("f5_tts").joinpath("../../data"))
path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))

# Device name picked in priority order: cuda, then xpu, then mps, otherwise cpu.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "xpu"
    if torch.xpu.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
# Save settings to a JSON file
def save_settings(
    project_name,
    exp_name,
    learning_rate,
    batch_size_per_gpu,
    batch_size_type,
    max_samples,
    grad_accumulation_steps,
    max_grad_norm,
    epochs,
    num_warmup_updates,
    save_per_updates,
    keep_last_n_checkpoints,
    last_per_updates,
    finetune,
    file_checkpoint_train,
    tokenizer_type,
    tokenizer_file,
    mixed_precision,
    logger,
    ch_8bit_adam,
):
    """Write the given training settings to ``<path_project_ckpts>/<project_name>/setting.json``.

    The project directory is created when missing. ``ch_8bit_adam`` is stored under the key
    ``"bnb_optimizer"``; every other value is stored under its parameter name.

    Returns:
        The status message ``"Settings saved!"``.
    """
    project_dir = os.path.join(path_project_ckpts, project_name)
    os.makedirs(project_dir, exist_ok=True)

    payload = dict(
        exp_name=exp_name,
        learning_rate=learning_rate,
        batch_size_per_gpu=batch_size_per_gpu,
        batch_size_type=batch_size_type,
        max_samples=max_samples,
        grad_accumulation_steps=grad_accumulation_steps,
        max_grad_norm=max_grad_norm,
        epochs=epochs,
        num_warmup_updates=num_warmup_updates,
        save_per_updates=save_per_updates,
        keep_last_n_checkpoints=keep_last_n_checkpoints,
        last_per_updates=last_per_updates,
        finetune=finetune,
        file_checkpoint_train=file_checkpoint_train,
        tokenizer_type=tokenizer_type,
        tokenizer_file=tokenizer_file,
        mixed_precision=mixed_precision,
        logger=logger,
        bnb_optimizer=ch_8bit_adam,
    )
    with open(os.path.join(project_dir, "setting.json"), "w") as f:
        json.dump(payload, f, indent=4)
    return "Settings saved!"
# Load settings from a JSON file
def load_settings(project_name):
    """Return the training settings of a project as a tuple, falling back to defaults.

    ``"_pinyin"`` and ``"_char"`` are removed from ``project_name`` before the project's
    ``setting.json`` is looked up under ``path_project_ckpts``. Values found in that file
    override the defaults; on the ``"mps"`` device the default mixed precision is ``"none"``.

    Returns:
        A 19-tuple in the order listed in ``field_order`` below.
    """
    project_name = project_name.replace("_pinyin", "").replace("_char", "")
    setting_path = os.path.join(path_project_ckpts, project_name, "setting.json")

    # Default settings
    settings = {
        "exp_name": "F5TTS_v1_Base",
        "learning_rate": 1e-5,
        "batch_size_per_gpu": 3200,
        "batch_size_type": "frame",
        "max_samples": 64,
        "grad_accumulation_steps": 1,
        "max_grad_norm": 1.0,
        "epochs": 100,
        "num_warmup_updates": 100,
        "save_per_updates": 500,
        "keep_last_n_checkpoints": -1,
        "last_per_updates": 100,
        "finetune": True,
        "file_checkpoint_train": "",
        "tokenizer_type": "pinyin",
        "tokenizer_file": "",
        "mixed_precision": "fp16",
        "logger": "none",
        "bnb_optimizer": False,
    }
    if device == "mps":
        settings["mixed_precision"] = "none"

    # Saved values win over the defaults when the settings file exists.
    if os.path.isfile(setting_path):
        with open(setting_path, "r") as f:
            settings.update(json.load(f))

    # Order of the returned tuple.
    field_order = (
        "exp_name",
        "learning_rate",
        "batch_size_per_gpu",
        "batch_size_type",
        "max_samples",
        "grad_accumulation_steps",
        "max_grad_norm",
        "epochs",
        "num_warmup_updates",
        "save_per_updates",
        "keep_last_n_checkpoints",
        "last_per_updates",
        "finetune",
        "file_checkpoint_train",
        "tokenizer_type",
        "tokenizer_file",
        "mixed_precision",
        "logger",
        "bnb_optimizer",
    )
    return tuple(settings[key] for key in field_order)
# Load metadata
def get_audio_duration(audio_path):
    """Return the length of an audio file in seconds.

    The file is loaded with ``torchaudio.load``; the size of axis 1 of the loaded tensor is
    divided by the sample rate.
    """
    waveform, sr = torchaudio.load(audio_path)
    num_samples = waveform.shape[1]
    return num_samples / sr
class Slicer:  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
    """Split a waveform into chunks at silent regions, based on frame RMS.

    Constructor arguments other than ``sr`` and ``threshold`` are given in
    milliseconds and converted internally to hop-sized frames.

    Args:
        sr: Sample rate of the waveforms that will be sliced.
        threshold: Silence threshold in dB; frames whose RMS is below it are
            treated as silent.
        min_length: Minimum chunk length (ms) required before a cut is made.
        min_interval: Minimum silence length (ms) that qualifies as a cut point.
        hop_size: Hop between RMS frames (ms).
        max_sil_kept: Maximum silence (ms) kept around a cut.

    Raises:
        ValueError: If ``min_length >= min_interval >= hop_size`` or
            ``max_sil_kept >= hop_size`` does not hold.
    """

    def __init__(
        self,
        sr: int,
        threshold: float = -40.0,
        min_length: int = 20000,  # 20 seconds
        min_interval: int = 300,
        hop_size: int = 20,
        max_sil_kept: int = 2000,
    ):
        if not min_length >= min_interval >= hop_size:
            raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
        if not max_sil_kept >= hop_size:
            raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
        min_interval = sr * min_interval / 1000
        # dB -> linear amplitude
        self.threshold = 10 ** (threshold / 20.0)
        self.hop_size = round(sr * hop_size / 1000)
        self.win_size = min(round(min_interval), 4 * self.hop_size)
        self.min_length = round(sr * min_length / 1000 / self.hop_size)
        self.min_interval = round(min_interval / self.hop_size)
        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

    def _apply_slice(self, waveform, begin, end):
        """Return the samples between frame indices ``begin`` and ``end`` (end clamped to the waveform length)."""
        if len(waveform.shape) > 1:
            return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
        else:
            return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]

    # @timeit
    def slice(self, waveform):
        """Cut ``waveform`` at silences.

        Returns:
            A list of ``[chunk, start, end]`` entries, where ``start`` and
            ``end`` are sample offsets into the input waveform.
        """
        if len(waveform.shape) > 1:
            samples = waveform.mean(axis=0)
        else:
            samples = waveform
        if samples.shape[0] <= self.min_length:
            # Too short to slice: return it whole, but in the same
            # [chunk, start, end] shape as every other return path so callers
            # that unpack three values keep working.
            return [[waveform, 0, int(samples.shape[0])]]
        rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
        sil_tags = []
        silence_start = None
        clip_start = 0
        for i, rms in enumerate(rms_list):
            # Keep looping while frame is silent.
            if rms < self.threshold:
                # Record start of silent frames.
                if silence_start is None:
                    silence_start = i
                continue
            # Keep looping while frame is not silent and silence start has not been recorded.
            if silence_start is None:
                continue
            # Clear recorded silence start if interval is not enough or clip is too short
            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue
            # Need slicing. Record the range of silent frames to be removed.
            if i - silence_start <= self.max_sil_kept:
                pos = rms_list[silence_start : i + 1].argmin() + silence_start
                if silence_start == 0:
                    sil_tags.append((0, pos))
                else:
                    sil_tags.append((pos, pos))
                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
                pos += i - self.max_sil_kept
                pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
                pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
                pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                else:
                    sil_tags.append((pos_l, pos_r))
                clip_start = pos_r
            silence_start = None
        # Deal with trailing silence.
        total_frames = rms_list.shape[0]
        if silence_start is not None and total_frames - silence_start >= self.min_interval:
            silence_end = min(total_frames, silence_start + self.max_sil_kept)
            pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
            sil_tags.append((pos, total_frames + 1))
        # Apply and return slices: [chunk, start, end]
        if len(sil_tags) == 0:
            return [[waveform, 0, int(total_frames * self.hop_size)]]
        else:
            chunks = []
            if sil_tags[0][0] > 0:
                chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
            for i in range(len(sil_tags) - 1):
                chunks.append(
                    [
                        self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
                        int(sil_tags[i][1] * self.hop_size),
                        int(sil_tags[i + 1][0] * self.hop_size),
                    ]
                )
            if sil_tags[-1][1] < total_frames:
                chunks.append(
                    [
                        self._apply_slice(waveform, sil_tags[-1][1], total_frames),
                        int(sil_tags[-1][1] * self.hop_size),
                        int(total_frames * self.hop_size),
                    ]
                )
            return chunks
# Process termination helpers
def terminate_process_tree(pid, including_parent=True):
    """Send SIGTERM to every descendant of ``pid`` and, optionally, to ``pid`` itself.

    Does nothing when the process no longer exists. ``OSError`` raised while
    signalling an individual process is ignored.
    """
    try:
        root = psutil.Process(pid)
    except psutil.NoSuchProcess:
        # Process already terminated
        return

    # Children first, parent last.
    target_pids = [proc.pid for proc in root.children(recursive=True)]
    if including_parent:
        target_pids.append(root.pid)

    for target_pid in target_pids:
        try:
            os.kill(target_pid, signal.SIGTERM)  # or signal.SIGKILL
        except OSError:
            pass
def terminate_process(pid):
    """Terminate ``pid`` and its children.

    Uses ``taskkill /t /f`` when the module-level ``system`` is "Windows",
    otherwise delegates to ``terminate_process_tree``.
    """
    if system != "Windows":
        terminate_process_tree(pid)
        return
    os.system(f"taskkill /t /f /pid {pid}")
def start_training(
    dataset_name,
    exp_name,
    learning_rate,
    batch_size_per_gpu,
    batch_size_type,
    max_samples,
    grad_accumulation_steps,
    max_grad_norm,
    epochs,
    num_warmup_updates,
    save_per_updates,
    keep_last_n_checkpoints,
    last_per_updates,
    finetune,
    file_checkpoint_train,
    tokenizer_type,
    tokenizer_file,
    mixed_precision,
    stream,
    logger,
    ch_8bit_adam,
):
    """Launch the training script in a subprocess and report progress.

    This is a generator: every yielded item is a 3-tuple of
    ``(status text, update for the first button, update for the second button)``.
    Judging by the ``interactive`` flags these are the start and stop buttons
    of the UI - confirm against the event wiring.

    When ``stream`` is true, the subprocess stdout/stderr are read on
    background threads and forwarded line by line; tqdm-style progress lines
    are condensed into a short summary. Otherwise the function just waits for
    the subprocess to finish.

    The settings used for the run are persisted through ``save_settings``
    before the process is started.
    """
    global training_process, tts_api, stop_signal

    # Drop any loaded inference model before training starts.
    if tts_api is not None:
        tts_api = None
        gc.collect()
        torch.cuda.empty_cache()

    path_project = os.path.join(path_data, dataset_name)

    if not os.path.isdir(path_project):
        yield (
            f"There is not project with name {dataset_name}",
            gr.update(interactive=True),
            gr.update(interactive=False),
        )
        return

    file_raw = os.path.join(path_project, "raw.arrow")
    if not os.path.isfile(file_raw):
        yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
        return

    # Check if a training process is already running
    if training_process is not None:
        # A value passed to `return` inside a generator never reaches the
        # consumer, so the message has to be yielded before returning.
        yield "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
        return

    yield "start train", gr.update(interactive=False), gr.update(interactive=False)

    # Command to run the training script with the specified arguments

    if tokenizer_file == "":
        if dataset_name.endswith("_pinyin"):
            tokenizer_type = "pinyin"
        elif dataset_name.endswith("_char"):
            tokenizer_type = "char"
    else:
        tokenizer_type = "custom"

    dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")

    if mixed_precision != "none":
        fp16 = f"--mixed_precision={mixed_precision}"
    else:
        fp16 = ""

    cmd = (
        f'accelerate launch {fp16} "{file_train}" --exp_name {exp_name}'
        f" --learning_rate {learning_rate}"
        f" --batch_size_per_gpu {batch_size_per_gpu}"
        f" --batch_size_type {batch_size_type}"
        f" --max_samples {max_samples}"
        f" --grad_accumulation_steps {grad_accumulation_steps}"
        f" --max_grad_norm {max_grad_norm}"
        f" --epochs {epochs}"
        f" --num_warmup_updates {num_warmup_updates}"
        f" --save_per_updates {save_per_updates}"
        f" --keep_last_n_checkpoints {keep_last_n_checkpoints}"
        f" --last_per_updates {last_per_updates}"
        f" --dataset_name {dataset_name}"
    )

    if finetune:
        cmd += " --finetune"

    if file_checkpoint_train != "":
        cmd += f' --pretrain "{file_checkpoint_train}"'

    if tokenizer_file != "":
        # Quoted like the other path arguments so paths containing spaces survive the shell.
        cmd += f' --tokenizer_path "{tokenizer_file}"'

    cmd += f" --tokenizer {tokenizer_type}"

    if logger != "none":
        cmd += f" --logger {logger}"

    cmd += " --log_samples"

    if ch_8bit_adam:
        cmd += " --bnb_optimizer"

    print("run command : \n" + cmd + "\n")

    save_settings(
        dataset_name,
        exp_name,
        learning_rate,
        batch_size_per_gpu,
        batch_size_type,
        max_samples,
        grad_accumulation_steps,
        max_grad_norm,
        epochs,
        num_warmup_updates,
        save_per_updates,
        keep_last_n_checkpoints,
        last_per_updates,
        finetune,
        file_checkpoint_train,
        tokenizer_type,
        tokenizer_file,
        mixed_precision,
        logger,
        ch_8bit_adam,
    )

    try:
        if not stream:
            # Start the training process
            training_process = subprocess.Popen(cmd, shell=True)

            time.sleep(5)
            yield "train start", gr.update(interactive=False), gr.update(interactive=True)

            # Wait for the training process to finish
            training_process.wait()
        else:

            def stream_output(pipe, output_queue):
                """Forward every line read from ``pipe`` into ``output_queue`` until EOF."""
                try:
                    for line in iter(pipe.readline, ""):
                        output_queue.put(line)
                except Exception as e:
                    output_queue.put(f"Error reading pipe: {str(e)}")
                finally:
                    pipe.close()

            env = os.environ.copy()
            env["PYTHONUNBUFFERED"] = "1"

            training_process = subprocess.Popen(
                cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
            )
            yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True)

            stdout_queue = queue.Queue()
            stderr_queue = queue.Queue()

            stdout_thread = threading.Thread(target=stream_output, args=(training_process.stdout, stdout_queue))
            stderr_thread = threading.Thread(target=stream_output, args=(training_process.stderr, stderr_queue))
            stdout_thread.daemon = True
            stderr_thread.daemon = True
            stdout_thread.start()
            stderr_thread.start()
            stop_signal = False
            while True:
                if stop_signal:
                    training_process.terminate()
                    time.sleep(0.5)
                    if training_process.poll() is None:
                        training_process.kill()
                    yield "Training stopped by user.", gr.update(interactive=True), gr.update(interactive=False)
                    break

                # Sample the exit status BEFORE draining the queues, so output
                # produced right before exit is still forwarded on this pass.
                process_status = training_process.poll()

                # Handle stdout
                try:
                    while True:
                        output = stdout_queue.get_nowait()
                        print(output, end="")

                        match = re.search(
                            r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), update=(\d+)", output
                        )
                        if match:
                            current_epoch = match.group(1)
                            total_epochs = match.group(2)
                            percent_complete = match.group(3)
                            elapsed_time = match.group(4)
                            loss = match.group(5)
                            current_update = match.group(6)
                            message = (
                                f"Epoch: {current_epoch}/{total_epochs}, "
                                f"Progress: {percent_complete}%, "
                                f"Elapsed Time: {elapsed_time}, "
                                f"Loss: {loss}, "
                                f"Update: {current_update}"
                            )
                            yield message, gr.update(interactive=False), gr.update(interactive=True)
                        elif output.strip():
                            yield output, gr.update(interactive=False), gr.update(interactive=True)

                except queue.Empty:
                    pass

                # Handle stderr
                try:
                    while True:
                        error_output = stderr_queue.get_nowait()
                        print(error_output, end="")
                        if error_output.strip():
                            yield f"{error_output.strip()}", gr.update(interactive=False), gr.update(interactive=True)
                except queue.Empty:
                    pass

                if process_status is not None and stdout_queue.empty() and stderr_queue.empty():
                    if process_status != 0:
                        yield (
                            f"Process crashed with exit code {process_status}!",
                            gr.update(interactive=False),
                            gr.update(interactive=True),
                        )
                    else:
                        yield (
                            "Training complete or paused ...",
                            gr.update(interactive=False),
                            gr.update(interactive=True),
                        )
                    break

                # Small sleep to prevent CPU thrashing
                time.sleep(0.1)

            # Clean up
            training_process.stdout.close()
            training_process.stderr.close()
            training_process.wait()

        time.sleep(1)

        if training_process is None:
            text_info = "Train stopped !"
        else:
            text_info = "Train complete at end !"

    except Exception as e:  # Catch all exceptions
        # Ensure that we reset the training process variable in case of an error
        text_info = f"An error occurred: {str(e)}"

    training_process = None

    yield text_info, gr.update(interactive=True), gr.update(interactive=False)
def stop_training():
    """Ask the running training process to stop.

    When a process is registered in ``training_process``, its process tree is
    sent SIGTERM and ``stop_signal`` is raised for the streaming loop. Returns
    a status message plus two button updates in either case.
    """
    global training_process, stop_signal

    if training_process is not None:
        terminate_process_tree(training_process.pid)
        # training_process = None
        stop_signal = True
        status = "Train stopped !"
    else:
        status = "Train not running !"

    return status, gr.update(interactive=True), gr.update(interactive=False)
def get_list_projects():
    """List the project folders found directly under ``path_data``.

    The bundled ``Emilia_ZH_EN_pinyin`` folder is excluded (matched
    case-insensitively). Folder names are returned exactly as they exist on
    disk, so they can be joined back onto ``path_data`` on case-sensitive
    file systems.

    Returns:
        ``(project_list, selected)`` where ``selected`` is the last listed
        project, or ``None`` when there are no projects.
    """
    project_list = []
    for folder in os.listdir(path_data):
        if not os.path.isdir(os.path.join(path_data, folder)):
            continue
        # Lowercase only for the comparison; keep the real name for the result.
        if folder.lower() == "emilia_zh_en_pinyin":
            continue
        project_list.append(folder)

    projects_selelect = project_list[-1] if project_list else None

    return project_list, projects_selelect
def create_data_project(name, tokenizer_type):
    """Create the folder layout for a new project and refresh the project dropdown.

    The project folder is named ``<name>_<tokenizer_type>`` and receives a
    ``dataset`` sub-folder. Returns a Gradio update whose choices are the
    refreshed project list and whose value is the new project name.
    """
    name = name + "_" + tokenizer_type
    project_dir = os.path.join(path_data, name)
    for folder in (project_dir, os.path.join(project_dir, "dataset")):
        os.makedirs(folder, exist_ok=True)
    available_projects, _ = get_list_projects()
    return gr.update(choices=available_projects, value=name)
def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
    """Slice audio into segments, transcribe each one and write ``metadata.csv``.

    Any existing ``wavs`` folder and ``metadata.csv`` of the project are
    removed first. Each source file is loaded at 24 kHz mono, cut with
    ``Slicer`` and every chunk is written to ``wavs/segment_<n>.wav`` before
    being passed to ``transcribe``.

    Args:
        name_project: Project folder name under ``path_data``.
        audio_files: Audio file paths to process when ``user`` is false.
        language: Language passed through to ``transcribe``.
        user: When true, process the audio files found in the project's
            ``dataset`` folder instead of ``audio_files``.
        progress: Gradio progress tracker.

    Returns:
        A human readable summary, or an error message when there is no input.
    """
    path_project = os.path.join(path_data, name_project)
    path_dataset = os.path.join(path_project, "dataset")
    path_project_wavs = os.path.join(path_project, "wavs")
    file_metadata = os.path.join(path_project, "metadata.csv")

    if not user and audio_files is None:
        return "You need to load an audio file."

    if os.path.isdir(path_project_wavs):
        shutil.rmtree(path_project_wavs)

    if os.path.isfile(file_metadata):
        os.remove(file_metadata)

    os.makedirs(path_project_wavs, exist_ok=True)

    if user:
        file_audios = [
            file
            for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
            for file in glob(os.path.join(path_dataset, format))
        ]
        if file_audios == []:
            return "No audio file was found in the dataset."
    else:
        file_audios = audio_files

    alpha = 0.5
    _max = 1.0
    slicer = Slicer(24000)

    num = 0
    error_num = 0
    metadata_lines = []
    for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len(file_audios)):
        audio, _ = librosa.load(file_audio, sr=24000, mono=True)

        list_slicer = slicer.slice(audio)
        for chunk, _start, _end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
            name_segment = f"segment_{num}"
            file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")

            tmp_max = np.abs(chunk).max()
            # A fully silent chunk has tmp_max == 0; skip the loudness blend
            # there, otherwise the division below fills the chunk with NaN.
            if tmp_max > 0:
                if tmp_max > 1:
                    chunk /= tmp_max
                chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
            wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))

            try:
                text = transcribe(file_segment, language)
                text = text.strip()

                metadata_lines.append(f"{name_segment}|{text}\n")

                num += 1
            except Exception as e:
                # Best effort: count the failure and continue with the next
                # chunk. KeyboardInterrupt/SystemExit are no longer swallowed.
                print(f"Error transcribing {file_segment}: {e}")
                error_num += 1

    with open(file_metadata, "w", encoding="utf-8-sig") as f:
        f.write("".join(metadata_lines))

    # error_num is a counter; only mention it when something actually failed.
    if error_num != 0:
        error_text = f"\nerror files : {error_num}"
    else:
        error_text = ""

    return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
def format_seconds_to_hms(seconds):
    """Format a duration given in seconds as ``HH:MM:SS`` (fractions are truncated)."""
    whole_hours = int(seconds / 3600)
    whole_minutes = int((seconds % 3600) / 60)
    whole_seconds = int(seconds % 60)
    return f"{whole_hours:02d}:{whole_minutes:02d}:{whole_seconds:02d}"
def get_correct_audio_path(
    audio_input,
    base_path="wavs",
    supported_formats=("wav", "mp3", "aac", "flac", "m4a", "alac", "ogg", "aiff", "wma", "amr"),
):
    """Resolve a metadata audio entry to a file path.

    Rules:
      * A name with a supported extension (matched case-insensitively) is used
        as-is when absolute, otherwise joined onto ``base_path``.
      * A name without a supported extension is probed with each supported
        extension in order; the first existing file wins. If none exists, the
        name with the first supported extension is returned.

    A string path is returned in every case, including absolute names without
    an extension.
    """
    known_suffixes = tuple(f".{ext.lower()}" for ext in supported_formats)

    if audio_input.lower().endswith(known_suffixes):
        if os.path.isabs(audio_input):
            return audio_input
        return os.path.join(base_path, audio_input)

    # No recognised extension. os.path.join discards base_path when
    # audio_input is absolute, so the same probing covers both cases.
    for ext in supported_formats:
        candidate = os.path.join(base_path, f"{audio_input}.{ext}")
        if os.path.exists(candidate):
            return candidate

    return os.path.join(base_path, f"{audio_input}.{supported_formats[0]}")
def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
    """Build ``raw.arrow``, ``duration.json`` and ``vocab.txt`` from ``metadata.csv``.

    Each ``name|text`` line of the metadata file is validated (file exists,
    duration between 1 and 30 seconds, text of at least 3 characters),
    converted with ``convert_char_to_pinyin`` and written to the arrow file.

    Args:
        name_project: Project folder name under ``path_data``.
        ch_tokenizer: When true, a new ``vocab.txt`` is written from the
            characters seen in the converted text. When false, the project's
            ``vocab.txt`` is used, copied from ``Emilia_ZH_EN_pinyin`` if it
            does not exist yet.
        progress: Gradio progress tracker.

    Returns:
        ``(report, new_vocab)``: a human readable report (or error message)
        and the newly written vocabulary text (empty unless ``ch_tokenizer``).
    """
    path_project = os.path.join(path_data, name_project)
    path_project_wavs = os.path.join(path_project, "wavs")
    file_metadata = os.path.join(path_project, "metadata.csv")
    file_raw = os.path.join(path_project, "raw.arrow")
    file_duration = os.path.join(path_project, "duration.json")
    file_vocab = os.path.join(path_project, "vocab.txt")

    if not os.path.isfile(file_metadata):
        return "The file was not found in " + file_metadata, ""

    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        data = f.read()

    text_list = []
    duration_list = []

    lines = data.split("\n")
    total_duration = 0
    result = []
    error_files = []
    text_vocab_set = set()
    # `total` must be a number of items, not the list of lines itself.
    for line in progress.tqdm(lines, total=len(lines)):
        sp_line = line.split("|")
        if len(sp_line) != 2:
            continue
        name_audio, text = sp_line[:2]

        file_audio = get_correct_audio_path(name_audio, path_project_wavs)

        if not os.path.isfile(file_audio):
            error_files.append([file_audio, "error path"])
            continue

        try:
            duration = get_audio_duration(file_audio)
        except Exception as e:
            error_files.append([file_audio, "duration"])
            print(f"Error processing {file_audio}: {e}")
            continue

        if duration < 1 or duration > 30:
            if duration > 30:
                error_files.append([file_audio, "duration > 30 sec"])
            if duration < 1:
                error_files.append([file_audio, "duration < 1 sec "])
            continue
        if len(text) < 3:
            error_files.append([file_audio, "very short text length 3"])
            continue

        text = text.strip()
        text = convert_char_to_pinyin([text], polyphone=True)[0]

        duration_list.append(duration)
        text_list.append(text)

        result.append({"audio_path": file_audio, "text": text, "duration": duration})
        if ch_tokenizer:
            text_vocab_set.update(list(text))

        total_duration += duration

    if duration_list == []:
        return f"Error: No audio files found in the specified path : {path_project_wavs}", ""

    min_second = round(min(duration_list), 2)
    max_second = round(max(duration_list), 2)

    with ArrowWriter(path=file_raw) as writer:
        for line in progress.tqdm(result, total=len(result), desc="prepare data"):
            writer.write(line)
        writer.finalize()

    with open(file_duration, "w") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    new_vocal = ""
    if not ch_tokenizer:
        if not os.path.isfile(file_vocab):
            file_vocab_finetune = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
            if not os.path.isfile(file_vocab_finetune):
                return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", ""
            shutil.copy2(file_vocab_finetune, file_vocab)

        with open(file_vocab, "r", encoding="utf-8-sig") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                # Drop the trailing newline of each vocabulary line.
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)

    else:
        with open(file_vocab, "w", encoding="utf-8-sig") as f:
            for vocab in sorted(text_vocab_set):
                f.write(vocab + "\n")
                new_vocal += vocab + "\n"
        vocab_size = len(text_vocab_set)

    if error_files != []:
        error_text = "\n".join([" = ".join(item) for item in error_files])
    else:
        error_text = ""

    return (
        f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(total_duration)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\nvocab : {vocab_size}\n{error_text}",
        new_vocal,
    )
def check_user(value):
    """Return two visibility updates: the first hidden when ``value`` is truthy, the second shown."""
    hide_first = gr.update(visible=not value)
    show_second = gr.update(visible=value)
    return hide_first, show_second
def calculate_train(
    name_project,
    epochs,
    learning_rate,
    batch_size_per_gpu,
    batch_size_type,
    max_samples,
    num_warmup_updates,
    finetune,
):
    """Suggest training hyper-parameters from the dataset size and available memory.

    Reads the project's ``duration.json`` and derives a rough batch size from
    the average accelerator memory, then an epoch count that corresponds to a
    budget of 1.2M updates.

    Returns:
        ``(epochs, learning_rate, batch_size_per_gpu, max_samples,
        num_warmup_updates, total_samples)``. When ``duration.json`` is
        missing, the inputs are returned unchanged and the last item is the
        message ``"project not found !"``.
    """
    path_project = os.path.join(path_data, name_project)
    file_duration = os.path.join(path_project, "duration.json")

    hop_length = 256
    sampling_rate = 24000

    if not os.path.isfile(file_duration):
        return (
            epochs,
            learning_rate,
            batch_size_per_gpu,
            max_samples,
            num_warmup_updates,
            "project not found !",
        )

    with open(file_duration, "r") as file:
        data = json.load(file)

    duration_list = data["duration"]
    # Longest sample expressed in mel frames.
    max_sample_length = max(duration_list) * sampling_rate / hop_length
    total_samples = len(duration_list)
    total_duration = sum(duration_list)

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        total_memory = 0
        for i in range(gpu_count):
            gpu_properties = torch.cuda.get_device_properties(i)
            total_memory += gpu_properties.total_memory / (1024**3)  # in GB
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # `torch.xpu` does not exist on every torch build, hence the hasattr guard.
        gpu_count = torch.xpu.device_count()
        total_memory = 0
        for i in range(gpu_count):
            gpu_properties = torch.xpu.get_device_properties(i)
            total_memory += gpu_properties.total_memory / (1024**3)
    else:
        # MPS, or no accelerator at all: estimate from the available system
        # memory so the computation below always has values to work with.
        gpu_count = 1
        total_memory = psutil.virtual_memory().available / (1024**3)

    avg_gpu_memory = total_memory / gpu_count

    # rough estimate of batch size
    if batch_size_type == "frame":
        batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
    elif batch_size_type == "sample":
        batch_size_per_gpu = int(200 / (total_duration / total_samples))

    if total_samples < 64:
        max_samples = int(total_samples * 0.25)

    num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05))

    # take 1.2M updates as the maximum
    max_updates = 1200000

    if batch_size_type == "frame":
        mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
        updates_per_epoch = total_duration / mini_batch_duration
    elif batch_size_type == "sample":
        updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count

    epochs = int(max_updates / updates_per_epoch)

    if finetune:
        learning_rate = 1e-5
    else:
        learning_rate = 7.5e-5

    return (
        epochs,
        learning_rate,
        batch_size_per_gpu,
        max_samples,
        num_warmup_updates,
        total_samples,
    )
def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str:
    """Write a reduced copy of a checkpoint that keeps a single state dict.

    Args:
        checkpoint_path: Checkpoint to read.
        new_checkpoint_path: Destination; its ``.pt`` / ``.safetensors``
            extension is rewritten to match ``safetensors``.
        save_ema: Keep ``ema_model_state_dict`` when true, otherwise
            ``model_state_dict``.
        safetensors: Save with ``save_file`` when true, otherwise with
            ``torch.save`` under the ``ema_model_state_dict`` key.

    Returns:
        A status message; any exception is reported in the message rather
        than raised.
    """
    try:
        checkpoint = torch.load(checkpoint_path, weights_only=True)
        print("Original Checkpoint Keys:", checkpoint.keys())

        to_retain = "model_state_dict" if not save_ema else "ema_model_state_dict"
        if to_retain not in checkpoint:
            return f"{to_retain} not found in the checkpoint."
        kept_state = checkpoint[to_retain]

        if not safetensors:
            new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
            # The retained weights are always stored under the EMA key.
            torch.save({"ema_model_state_dict": kept_state}, new_checkpoint_path)
        else:
            new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
            save_file(kept_state, new_checkpoint_path)

        return f"New checkpoint saved at: {new_checkpoint_path}"

    except Exception as e:
        return f"An error occurred: {e}"
def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
    """Grow the text-embedding table of a checkpoint by ``num_new_tokens`` rows.

    The existing rows are copied unchanged; the new rows are drawn from a
    standard normal distribution with a fixed seed. Note that this seeds the
    global Python and torch RNGs and sets cuDNN determinism flags as a side
    effect.

    Args:
        ckpt_path: Source checkpoint, ``.safetensors`` or ``.pt``.
        new_ckpt_path: Destination, ``.safetensors`` or ``.pt``.
        num_new_tokens: Number of embedding rows to append.

    Returns:
        The new vocabulary size (row count of the expanded embedding).

    Raises:
        ValueError: If either path has an unsupported extension.
    """
    supported = (".safetensors", ".pt")
    # Validate up front: previously an unknown input extension crashed with an
    # unbound variable, and an unknown output extension silently saved nothing.
    if not ckpt_path.endswith(supported):
        raise ValueError(f"Unsupported checkpoint format: {ckpt_path}")
    if not new_ckpt_path.endswith(supported):
        raise ValueError(f"Unsupported checkpoint format: {new_ckpt_path}")

    seed = 666
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if ckpt_path.endswith(".safetensors"):
        ckpt = load_file(ckpt_path, device="cpu")
        ckpt = {"ema_model_state_dict": ckpt}
    else:
        ckpt = torch.load(ckpt_path, map_location="cpu")

    ema_sd = ckpt.get("ema_model_state_dict", {})
    embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
    old_embed_ema = ema_sd[embed_key_ema]

    vocab_old = old_embed_ema.size(0)
    embed_dim = old_embed_ema.size(1)
    vocab_new = vocab_old + num_new_tokens

    def expand_embeddings(old_embeddings):
        """Return a (vocab_new, embed_dim) table: old rows first, random rows after."""
        new_embeddings = torch.zeros((vocab_new, embed_dim))
        new_embeddings[:vocab_old] = old_embeddings
        new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim))
        return new_embeddings

    # ema_sd is the same dict object stored in ckpt, so the .pt branch below
    # saves the expanded table as well.
    ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])

    if new_ckpt_path.endswith(".safetensors"):
        save_file(ema_sd, new_ckpt_path)
    else:
        torch.save(ckpt, new_ckpt_path)

    return vocab_new
def vocab_count(text):
    """Return, as a string, how many comma-separated fields ``text`` contains."""
    # Splitting on "," always yields one more field than there are commas.
    field_total = text.count(",") + 1
    return str(field_total)
def vocab_extend(project_name, symbols, model_type):
    """Add new symbols to a project's vocabulary and expand the pretrained embeddings.

    The base vocabulary is read from ``Emilia_ZH_EN_pinyin/vocab.txt``. Symbols
    from the comma-separated ``symbols`` string that are not already in it are
    appended, the result is written to the project's ``vocab.txt``, and the
    pretrained checkpoint for ``model_type`` is copied with its text embedding
    grown by the number of added symbols.

    Returns:
        A human readable report, or an error / no-op message.
    """
    if symbols == "":
        return "Symbols empty!"

    name_project = project_name
    path_project = os.path.join(path_data, name_project)
    file_vocab_project = os.path.join(path_project, "vocab.txt")

    file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
    if not os.path.isfile(file_vocab):
        return f"the file {file_vocab} not found !"

    with open(file_vocab, "r", encoding="utf-8-sig") as f:
        data = f.read()
        vocab = data.split("\n")
    known_symbols = set(vocab)

    # Collect unknown symbols once each, in input order. Duplicates would
    # otherwise be appended to the vocabulary (and the embedding) repeatedly.
    miss_symbols = []
    for item in symbols.split(","):
        item = item.replace(" ", "")
        if item == "" or item in known_symbols:
            continue
        known_symbols.add(item)
        miss_symbols.append(item)

    if miss_symbols == []:
        return "Symbols are okay no need to extend."

    pretrained_urls = {
        "F5TTS_v1_Base": "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
        "F5TTS_Base": "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt",
        "E2TTS_Base": "hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt",
    }
    # Reject unknown model types before anything is written to disk.
    if model_type not in pretrained_urls:
        return f"Unknown model type : {model_type}"

    size_vocab = len(vocab)
    # Drop the empty entry produced by the file's trailing newline, but only
    # if it is really there, so a real last token is never discarded.
    if vocab and vocab[-1] == "":
        vocab.pop()
    for item in miss_symbols:
        vocab.append(item)

    vocab.append("")

    with open(file_vocab_project, "w", encoding="utf-8") as f:
        f.write("\n".join(vocab))

    ckpt_path = str(cached_path(pretrained_urls[model_type]))

    vocab_size_new = len(miss_symbols)

    dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
    new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
    os.makedirs(new_ckpt_path, exist_ok=True)

    # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
    new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path))

    size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)

    vocab_new = "\n".join(miss_symbols)
    return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"
def vocab_check(project_name, tokenizer_type):
    """Report which characters of a project's transcripts are missing from the base vocabulary.

    The transcripts come from the project's ``metadata.csv``; when
    ``tokenizer_type`` is "pinyin" each one is first passed through
    ``convert_char_to_pinyin``. The base vocabulary is
    ``Emilia_ZH_EN_pinyin/vocab.txt``.

    Returns:
        ``(info, missing)`` where ``missing`` is a comma-joined string of the
        missing symbols in first-seen order (empty when nothing is missing).
    """
    project_dir = os.path.join(path_data, project_name)
    file_metadata = os.path.join(project_dir, "metadata.csv")
    file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")

    if not os.path.isfile(file_vocab):
        return f"the file {file_vocab} not found !", ""
    with open(file_vocab, "r", encoding="utf-8-sig") as f:
        vocab = set(f.read().split("\n"))

    if not os.path.isfile(file_metadata):
        return f"the file {file_metadata} not found !", ""
    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        metadata_lines = f.read().split("\n")

    # A dict doubles as an insertion-ordered set of the missing symbols.
    missing = {}
    for row in metadata_lines:
        fields = row.split("|")
        if len(fields) != 2:
            continue

        text = fields[1].strip()
        if tokenizer_type == "pinyin":
            text = convert_char_to_pinyin([text], polyphone=True)[0]

        for symbol in text:
            if symbol not in vocab:
                missing.setdefault(symbol, symbol)

    miss_symbols = list(missing)
    if not miss_symbols:
        return "You can train using your language !", ""

    info = f"The following {len(miss_symbols)} symbols are missing in your language\n\n"
    return info, ",".join(miss_symbols)
def get_random_sample_prepare(project_name):
    """Pick a random row from the project's ``raw.arrow``.

    Returns:
        ``(text, audio_path)`` where ``text`` is a bracketed, quoted rendering
        of the items of the row's ``text`` field, or ``("", None)`` when the
        arrow file does not exist.
    """
    file_arrow = os.path.join(path_data, project_name, "raw.arrow")
    if not os.path.isfile(file_arrow):
        return "", None

    dataset = Dataset_.from_file(file_arrow)
    shuffle_seed = random.randint(0, 1000)
    picked = dataset.shuffle(seed=shuffle_seed).select([0])

    quoted_tokens = [f"' {token} '" for token in picked["text"][0]]
    text = "[" + " , ".join(quoted_tokens) + "]"
    return text, picked["audio_path"][0]
def get_random_sample_transcribe(project_name):
    """Pick a random ``(text, audio_path)`` entry from the project's ``metadata.csv``.

    Returns ``("", None)`` when the file is missing or holds no valid
    ``name|text`` line.
    """
    project_dir = os.path.join(path_data, project_name)
    file_metadata = os.path.join(project_dir, "metadata.csv")
    if not os.path.isfile(file_metadata):
        return "", None

    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        rows = f.read().split("\n")

    wavs_dir = os.path.join(project_dir, "wavs")
    candidates = []
    for row in rows:
        fields = row.split("|")
        if len(fields) == 2:
            # fixed audio when it is absolute
            candidates.append([get_correct_audio_path(fields[0], wavs_dir), fields[1]])

    if not candidates:
        return "", None

    audio_path, text = random.choice(candidates)
    return text, audio_path
def get_random_sample_infer(project_name):
    """Return a random transcript twice, followed by its audio path.

    The sample comes from ``get_random_sample_transcribe``.
    """
    sample_text, sample_audio = get_random_sample_transcribe(project_name)
    return sample_text, sample_text, sample_audio
def infer(
    project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema, speed, seed, remove_silence
):
    """Synthesize ``gen_text`` with a project checkpoint.

    The ``F5TTS`` instance is cached in ``tts_api`` and rebuilt only when the
    checkpoint, the device or the EMA flag changes. While a training process
    is running, inference is forced onto the CPU.

    Returns:
        ``(wav_path, device, seed)``. When the checkpoint file does not exist,
        ``(None, "checkpoint not found!", "")`` is returned, so both paths
        hand back three values.
    """
    global last_checkpoint, last_device, tts_api, last_ema

    if not os.path.isfile(file_checkpoint):
        return None, "checkpoint not found!", ""

    if training_process is not None:
        device_test = "cpu"
    else:
        device_test = None

    if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema or tts_api is None:
        last_checkpoint = file_checkpoint
        last_device = device_test
        last_ema = use_ema

        vocab_file = os.path.join(path_data, project, "vocab.txt")

        tts_api = F5TTS(
            model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
        )

        print("update >> ", device_test, file_checkpoint, use_ema)

    if seed == -1:  # -1 used for random
        seed = None

    # delete=False: the file must outlive this function so the caller can use the returned path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        tts_api.infer(
            ref_file=ref_audio,
            ref_text=ref_text.strip(),
            gen_text=gen_text.strip(),
            nfe_step=nfe_step,
            speed=speed,
            remove_silence=remove_silence,
            file_wave=f.name,
            seed=seed,
        )
        return f.name, tts_api.device, str(tts_api.seed)
def check_finetune(finetune):
    """Return three component updates whose ``interactive`` flag equals ``finetune``."""
    return tuple(gr.update(interactive=finetune) for _ in range(3))
def get_checkpoints_project(project_name, is_gradio=True):
    """List the ``.pt`` checkpoints of a project.

    The ordering is: files whose name contains ``pretrained_``, then the
    remaining checkpoints sorted by the number after the first underscore,
    then ``model_last.pt``.

    Returns:
        ``([], "")`` when ``project_name`` is ``None``. Otherwise a Gradio
        dropdown update when ``is_gradio`` is true, or
        ``(checkpoint_paths, first_or_None)`` when it is false.
    """
    if project_name is None:
        return [], ""

    base_name = project_name.replace("_pinyin", "").replace("_char", "")

    def update_number(path):
        # "model_1200.pt" -> 1200
        return int(os.path.basename(path).split("_")[1].split(".")[0])

    files_checkpoints = []
    if os.path.isdir(path_project_ckpts):
        pretrained, numbered, latest = [], [], []
        for ckpt in glob(os.path.join(path_project_ckpts, base_name, "*.pt")):
            file_name = os.path.basename(ckpt)
            is_pretrained = "pretrained_" in file_name
            is_latest = "model_last.pt" in file_name
            # The two flags are tested independently: a file matching both
            # ends up in both groups.
            if is_pretrained:
                pretrained.append(ckpt)
            if is_latest:
                latest.append(ckpt)
            if not is_pretrained and not is_latest:
                numbered.append(ckpt)

        numbered.sort(key=update_number)

        # Combine in order: pretrained, regular, last
        files_checkpoints = pretrained + numbered + latest

    selelect_checkpoint = files_checkpoints[0] if files_checkpoints else None

    if is_gradio:
        return gr.update(choices=files_checkpoints, value=selelect_checkpoint)

    return files_checkpoints, selelect_checkpoint
def get_audio_project(project_name, is_gradio=True):
    """List the generated sample prefixes stored in a project's ``samples`` folder.

    All ``*.wav`` files are sorted by the number after the first underscore in
    their name; entries ending in ``_gen.wav`` are kept with that suffix
    removed.

    Returns:
        ``([], "")`` when ``project_name`` is ``None``. Otherwise a Gradio
        dropdown update when ``is_gradio`` is true, or
        ``(sample_prefixes, first_or_None)`` when it is false.
    """
    if project_name is None:
        return [], ""

    base_name = project_name.replace("_pinyin", "").replace("_char", "")

    def sample_number(path):
        return int(os.path.basename(path).split("_")[1].split(".")[0])

    files_audios = []
    if os.path.isdir(path_project_ckpts):
        wav_files = glob(os.path.join(path_project_ckpts, base_name, "samples", "*.wav"))
        # Sort every wav first, then keep only the generated ones.
        wav_files = sorted(wav_files, key=sample_number)
        for wav in wav_files:
            if wav.endswith("_gen.wav"):
                files_audios.append(wav.replace("_gen.wav", ""))

    selelect_checkpoint = files_audios[0] if files_audios else None

    if is_gradio:
        return gr.update(choices=files_audios, value=selelect_checkpoint)

    return files_audios, selelect_checkpoint
def get_gpu_stats():
    """Return a text report of accelerator memory usage.

    CUDA is checked first, then XPU, then MPS. For CUDA/XPU one section is
    produced per device; for MPS the total system memory is reported with
    zero allocated/reserved values. Without any of them the report is
    ``"No GPU available"``.
    """
    if torch.cuda.is_available():
        backend = torch.cuda
    elif torch.xpu.is_available():
        backend = torch.xpu
    elif torch.backends.mps.is_available():
        # Total system memory (MPS doesn't have its own memory)
        total_memory = psutil.virtual_memory().total / (1024**3)
        allocated_memory = 0
        reserved_memory = 0
        return (
            "MPS GPU\n"
            f"Total system memory: {total_memory:.2f} GB\n"
            f"Allocated GPU memory (MPS): {allocated_memory:.2f} MB\n"
            f"Reserved GPU memory (MPS): {reserved_memory:.2f} MB\n"
        )
    else:
        return "No GPU available"

    # torch.cuda and torch.xpu are queried through the same set of calls.
    sections = []
    for index in range(backend.device_count()):
        gpu_name = backend.get_device_name(index)
        total_memory = backend.get_device_properties(index).total_memory / (1024**3)  # in GB
        allocated_memory = backend.memory_allocated(index) / (1024**2)  # in MB
        reserved_memory = backend.memory_reserved(index) / (1024**2)  # in MB
        sections.append(
            f"GPU {index} Name: {gpu_name}\n"
            f"Total GPU memory (GPU {index}): {total_memory:.2f} GB\n"
            f"Allocated GPU memory (GPU {index}): {allocated_memory:.2f} MB\n"
            f"Reserved GPU memory (GPU {index}): {reserved_memory:.2f} MB\n\n"
        )
    return "".join(sections)
def get_cpu_stats():
    """Summarise CPU load, system memory use and this process's nice value.

    CPU usage is sampled via `psutil.cpu_percent(interval=1)`, so the call
    blocks for roughly one second.

    Returns:
        str: Three newline-separated lines (CPU usage, system memory, process
        priority); memory figures are expressed in MB.
    """
    mib = 1024**2
    usage = psutil.cpu_percent(interval=1)
    vm = psutil.virtual_memory()
    used_mb, total_mb = vm.used / mib, vm.total / mib
    used_pct = vm.percent
    # Priority of the current process, looked up through its own PID.
    niceness = psutil.Process(os.getpid()).nice()
    return (
        f"CPU Usage: {usage:.2f}%\n"
        f"System Memory: {used_mb:.2f} MB used / {total_mb:.2f} MB total ({used_pct}% used)\n"
        f"Process Priority (Nice value): {niceness}"
    )
def get_combined_stats():
    """Return one report with a "GPU Stats" section followed by a "CPU Stats" section.

    Returns:
        str: Markdown-style text; each section starts with a `###` heading and
        the two sections are separated by a blank line.
    """
    # GPU stats are collected first, then CPU stats.
    gpu_section, cpu_section = get_gpu_stats(), get_cpu_stats()
    return f"### GPU Stats\n{gpu_section}\n\n### CPU Stats\n{cpu_section}"
def get_audio_select(file_sample):
    """Derive the reference/generated audio paths for a selected sample stem.

    Args:
        file_sample: Sample path without its `_ref.wav` / `_gen.wav` suffix,
            or None when nothing is selected.

    Returns:
        tuple: `(ref_path, gen_path)`; both are None when `file_sample` is None.
    """
    if file_sample is None:
        return None, None
    return file_sample + "_ref.wav", file_sample + "_gen.wav"
# Root Gradio layout for the finetuning UI; the `app` object created here is what main() queues and launches.
with gr.Blocks() as app:
# Intro banner rendered as Markdown at the top of the page.
gr.Markdown(
"""
# F5 TTS Automatic Finetune
This is a local web UI for F5 TTS finetuning support. This app supports the following TTS models:
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
The pretrained checkpoints support English and Chinese.
For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
"""
)
# First row: tokenizer choice, new-project name and the create button.
with gr.Row():
# Project list and preselected project, computed once while the layout is built; they seed the dropdown below.
projects, projects_selelect = get_list_projects()
tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char", "custom"], value="pinyin")
project_name = gr.Textbox(label="Project Name", value="my_speak")
bt_create = gr.Button("Create a New Project")
# Second row: project picker (custom values allowed) and a refresh button whose handler is attached in the Train Model tab.
with gr.Row():
cm_project = gr.Dropdown(
choices=projects, value=projects_selelect, label="Project", allow_custom_value=True, scale=6
)
ch_refresh_project = gr.Button("Refresh", scale=1)
# The result of create_data_project is written back to the project dropdown.
bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])
# Container for the workflow tabs defined below.
with gr.Tabs():
# Tab: run transcribe_all on uploaded (or on-disk) audio for the selected project.
with gr.TabItem("Transcribe Data"):
gr.Markdown("""```plaintext
Skip this step if you have your dataset, metadata.csv, and a folder wavs with all the audio files.
```""")
# Checkbox passed to transcribe_all; its change handler below updates the uploader and the hidden help text.
ch_manual = gr.Checkbox(label="Audio from Path", value=False)
# Help text for the "Audio from Path" mode; created hidden (visible=False).
mark_info_transcribe = gr.Markdown(
"""```plaintext
Place your 'wavs' folder and 'metadata.csv' file in the '{your_project_name}' directory.
my_speak/
│
└── dataset/
├── audio1.wav
└── audio2.wav
...
```""",
visible=False,
)
audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple")
txt_lang = gr.Textbox(label="Language", value="English")
# NOTE(review): the chained assignment also rebinds `bt_create` to this button. The "Create a New Project"
# click handler was registered on the earlier Button object, so this looks unintentional but harmless.
bt_transcribe = bt_create = gr.Button("Transcribe")
txt_info_transcribe = gr.Textbox(label="Info", value="")
# transcribe_all receives project, uploaded files, language and the manual flag; its result goes to the Info box.
bt_transcribe.click(
fn=transcribe_all,
inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
outputs=[txt_info_transcribe],
)
# check_user drives both the file uploader and the help Markdown (presumably toggling visibility - confirm in check_user).
ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
# Spot check: show one text/audio pair returned by get_random_sample_transcribe for the project.
random_sample_transcribe = gr.Button("Random Sample")
with gr.Row():
random_text_transcribe = gr.Textbox(label="Text")
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
random_sample_transcribe.click(
fn=get_random_sample_transcribe,
inputs=[cm_project],
outputs=[random_text_transcribe, random_audio_transcribe],
)
# Tab: check the project's vocabulary and optionally extend a base model with new symbols.
with gr.TabItem("Vocab Check"):
gr.Markdown("""```plaintext
Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. For fine-tuning a new language.
```""")
check_button = gr.Button("Check Vocab")
txt_info_check = gr.Textbox(label="Info", value="")
gr.Markdown("""```plaintext
Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
```""")
# Base model passed to vocab_extend.
exp_name_extend = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
)
with gr.Row():
# Symbols to add; vocab_check also writes into this box (see check_button.click below).
txt_extend = gr.Textbox(
label="Symbols",
value="",
placeholder="To add new symbols, make sure to use ',' for each symbol",
scale=6,
)
txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1)
extend_button = gr.Button("Extend")
txt_info_extend = gr.Textbox(label="Info", value="")
# Recompute the "New Vocab Size" field whenever the symbol list changes.
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
# vocab_check reports into the Info box and fills the Symbols box.
check_button.click(
fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend]
)
# vocab_extend receives the project, the symbol list and the chosen base model.
extend_button.click(
fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
)
# Tab: run create_metadata for the selected project and preview a random prepared sample.
with gr.TabItem("Prepare Data"):
gr.Markdown("""```plaintext
Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
```""")
gr.Markdown(
"""```plaintext
Place all your "wavs" folder and your "metadata.csv" file in your project name directory.
Supported audio formats: "wav", "mp3", "aac", "flac", "m4a", "alac", "ogg", "aiff", "wma", "amr"
Example wav format:
my_speak/
│
├── wavs/
│ ├── audio1.wav
│ └── audio2.wav
| ...
│
└── metadata.csv
File format metadata.csv:
audio1|text1 or audio1.wav|text1 or your_path/audio1.wav|text1
audio2|text1 or audio2.wav|text1 or your_path/audio2.wav|text1
...
```"""
)
# Hidden checkbox (visible=False, default False) that is still passed to create_metadata as its second input.
ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
# NOTE(review): the chained assignment also rebinds `bt_create`; the earlier create-project handler is unaffected.
bt_prepare = bt_create = gr.Button("Prepare")
txt_info_prepare = gr.Textbox(label="Info", value="")
txt_vocab_prepare = gr.Textbox(label="Vocab", value="")
# create_metadata fills both the Info and the Vocab text boxes.
bt_prepare.click(
fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
)
# Spot check: show one text/audio pair returned by get_random_sample_prepare for the project.
random_sample_prepare = gr.Button("Random Sample")
with gr.Row():
random_text_prepare = gr.Textbox(label="Tokenizer")
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
random_sample_prepare.click(
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
)
# Tab: training hyper-parameter widgets. They are created without initial values here;
# values are filled from load_settings further down and on project change.
with gr.TabItem("Train Model"):
gr.Markdown("""```plaintext
The auto-setting is still experimental. Set a large value of epoch if not sure; and keep last N checkpoints if limited disk space.
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
```""")
# Model choice, tokenizer file and pretrained checkpoint path.
with gr.Row():
exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"])
tokenizer_file = gr.Textbox(label="Tokenizer File")
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint")
# Finetune toggle, sample-count label and the "Auto Settings" button.
with gr.Row():
# NOTE(review): both chained assignments below also rebind `bt_create` (first to a Checkbox, then to a Button);
# the name is not used again afterwards, so this appears to be leftover copy/paste.
ch_finetune = bt_create = gr.Checkbox(label="Finetune")
lb_samples = gr.Label(label="Samples")
bt_calculate = bt_create = gr.Button("Auto Settings")
# Optimisation schedule.
with gr.Row():
epochs = gr.Number(label="Epochs")
learning_rate = gr.Number(label="Learning Rate", step=0.5e-5)
max_grad_norm = gr.Number(label="Max Gradient Norm")
num_warmup_updates = gr.Number(label="Warmup Updates")
# Batching controls.
with gr.Row():
batch_size_type = gr.Radio(
label="Batch Size Type",
choices=["frame", "sample"],
info="frame is calculated as seconds * sampling_rate / hop_length",
)
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples")
grad_accumulation_steps = gr.Number(
label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value"
)
max_samples = gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch")
# Checkpoint saving cadence and retention.
with gr.Row():
save_per_updates = gr.Number(
label="Save per Updates",
info="Save intermediate checkpoints every N updates",
minimum=10,
)
keep_last_n_checkpoints = gr.Number(
label="Keep Last N Checkpoints",
step=1,
precision=0,
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N",
minimum=-1,
)
last_per_updates = gr.Number(
label="Last per Updates",
info="Save latest checkpoint with suffix _last.pt every N updates",
minimum=10,
)
gr.Radio(label="") # placeholder
# Optimizer variant, mixed precision mode, logger backend and the start/stop buttons.
with gr.Row():
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"])
cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"])
with gr.Column():
start_button = gr.Button("Start Training")
# Stop is created non-interactive; start_training/stop_training list both buttons as outputs.
stop_button = gr.Button("Stop Training", interactive=False)
# Pre-fill the training widgets from load_settings for the project that was preselected at startup.
# The tuple is unpacked positionally into 19 names, so load_settings must return exactly 19 values in this order.
if projects_selelect is not None:
(
exp_name_value,
learning_rate_value,
batch_size_per_gpu_value,
batch_size_type_value,
max_samples_value,
grad_accumulation_steps_value,
max_grad_norm_value,
epochs_value,
num_warmup_updates_value,
save_per_updates_value,
keep_last_n_checkpoints_value,
last_per_updates_value,
finetune_value,
file_checkpoint_train_value,
tokenizer_type_value,
tokenizer_file_value,
mixed_precision_value,
logger_value,
bnb_optimizer_value,
) = load_settings(projects_selelect)
# Assigning values to the respective components
exp_name.value = exp_name_value
learning_rate.value = learning_rate_value
batch_size_per_gpu.value = batch_size_per_gpu_value
batch_size_type.value = batch_size_type_value
max_samples.value = max_samples_value
grad_accumulation_steps.value = grad_accumulation_steps_value
max_grad_norm.value = max_grad_norm_value
epochs.value = epochs_value
num_warmup_updates.value = num_warmup_updates_value
save_per_updates.value = save_per_updates_value
keep_last_n_checkpoints.value = keep_last_n_checkpoints_value
last_per_updates.value = last_per_updates_value
ch_finetune.value = finetune_value
file_checkpoint_train.value = file_checkpoint_train_value
# tokenizer_type is the radio created in the top project row, not a widget of this tab.
tokenizer_type.value = tokenizer_type_value
tokenizer_file.value = tokenizer_file_value
mixed_precision.value = mixed_precision_value
cd_logger.value = logger_value
ch_8bit_adam.value = bnb_optimizer_value
# Stream toggle (passed to start_training) and the training Info box.
ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
txt_info_train = gr.Textbox(label="Info", value="")
# Second argument False: get_audio_project returns a (list, selected) pair here rather than a gr.update
# (presumably this is its `is_gradio` flag - confirm against the function signature).
list_audios, select_audio = get_audio_project(projects_selelect, False)
# Build the initial reference/generated file paths the same way get_audio_select does.
select_audio_ref = select_audio
select_audio_gen = select_audio
if select_audio is not None:
select_audio_ref += "_ref.wav"
select_audio_gen += "_gen.wav"
with gr.Row():
ch_list_audio = gr.Dropdown(
choices=list_audios,
value=select_audio,
label="Audios",
allow_custom_value=True,
scale=6,
interactive=True,
)
bt_stream_audio = gr.Button("Refresh", scale=1)
# Both the refresh button and a project change re-run get_audio_project and update the audio dropdown.
bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
with gr.Row():
audio_ref_stream = gr.Audio(label="Original", type="filepath", value=select_audio_ref)
audio_gen_stream = gr.Audio(label="Generate", type="filepath", value=select_audio_gen)
# Selecting a sample loads its `_ref.wav` / `_gen.wav` pair into the two players.
ch_list_audio.change(
fn=get_audio_select,
inputs=[ch_list_audio],
outputs=[audio_ref_stream, audio_gen_stream],
)
# Start training: the inputs are passed positionally to start_training, so their order must match its parameters.
# The handler's outputs update the Info box and both start/stop buttons.
start_button.click(
fn=start_training,
inputs=[
cm_project,
exp_name,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
grad_accumulation_steps,
max_grad_norm,
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
ch_finetune,
file_checkpoint_train,
tokenizer_type,
tokenizer_file,
mixed_precision,
ch_stream,
cd_logger,
ch_8bit_adam,
],
outputs=[txt_info_train, start_button, stop_button],
)
# Stop training: no inputs; updates the same three components.
stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
# "Auto Settings": calculate_train reads the current values and writes suggested ones back,
# plus the sample-count label.
bt_calculate.click(
fn=calculate_train,
inputs=[
cm_project,
epochs,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
num_warmup_updates,
ch_finetune,
],
outputs=[
epochs,
learning_rate,
batch_size_per_gpu,
max_samples,
num_warmup_updates,
lb_samples,
],
)
# Toggling "Finetune" lets check_finetune update the checkpoint path, tokenizer file and tokenizer type widgets.
ch_finetune.change(
check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
)
def setup_load_settings():
# Returns the widgets that load_settings populates, in the same order as the 19-value
# unpacking used for the startup pre-fill above.
output_components = [
exp_name,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
grad_accumulation_steps,
max_grad_norm,
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
ch_finetune,
file_checkpoint_train,
tokenizer_type,
tokenizer_file,
mixed_precision,
cd_logger,
ch_8bit_adam,
]
return output_components
# The list is built once, here, so it holds the Train Model tab's components.
outputs = setup_load_settings()
# Reload the saved settings whenever the selected project changes ...
cm_project.change(
fn=load_settings,
inputs=[cm_project],
outputs=outputs,
)
# ... and when the project row's "Refresh" button is clicked.
ch_refresh_project.click(
fn=load_settings,
inputs=[cm_project],
outputs=outputs,
)
# Tab: run inference with a checkpoint of the selected project.
with gr.TabItem("Test Model"):
gr.Markdown("""```plaintext
Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random.
```""")
# NOTE(review): this rebinds `exp_name`, which previously named the Train Model radio. Handlers registered
# earlier already hold the earlier component object, so only the code below sees this new radio.
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
)
# Second argument False: a (list, selected) pair is unpacked here, while the event handlers below
# use the same function's return value as a dropdown update (presumably an `is_gradio` flag - confirm).
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
# Sampling controls.
with gr.Row():
nfe_step = gr.Number(label="NFE Step", value=32)
speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
remove_silence = gr.Checkbox(label="Remove Silence")
# EMA toggle, checkpoint picker and its refresh button.
with gr.Row():
ch_use_ema = gr.Checkbox(
label="Use EMA", value=True, info="Turn off at early stage might offer better results"
)
cm_checkpoint = gr.Dropdown(
choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True
)
bt_checkpoint_refresh = gr.Button("Refresh")
# Reference text/audio and the text to synthesise; "Random Sample" fills all three from get_random_sample_infer.
random_sample_infer = gr.Button("Random Sample")
ref_text = gr.Textbox(label="Reference Text")
ref_audio = gr.Audio(label="Reference Audio", type="filepath")
gen_text = gr.Textbox(label="Text to Generate")
random_sample_infer.click(
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
)
with gr.Row():
txt_info_gpu = gr.Textbox("", label="Inference on Device :")
seed_info = gr.Textbox(label="Used Random Seed :")
check_button_infer = gr.Button("Inference")
gen_audio = gr.Audio(label="Generated Audio", type="filepath")
# Inputs are passed positionally to infer; it returns the audio, the device info and the seed that was used.
check_button_infer.click(
fn=infer,
inputs=[
cm_project,
cm_checkpoint,
exp_name,
ref_text,
ref_audio,
gen_text,
nfe_step,
ch_use_ema,
speed,
seed,
remove_silence,
],
outputs=[gen_audio, txt_info_gpu, seed_info],
)
# Refresh the checkpoint dropdown on demand and whenever the project changes.
bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
# Tab: call prune_checkpoint on a checkpoint path and write the result to an output path.
with gr.TabItem("Prune Checkpoint"):
gr.Markdown("""```plaintext
Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining.
```""")
txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:")
txt_path_checkpoint_small = gr.Textbox(label="Path to Output:")
# Both options default to enabled.
with gr.Row():
ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True)
ch_safetensors = gr.Checkbox(label="Save with safetensors format", value=True)
txt_info_reduse = gr.Textbox(label="Info", value="")
reduse_button = gr.Button("Prune")
# prune_checkpoint receives (input path, output path, save-EMA flag, safetensors flag) and reports into the Info box.
reduse_button.click(
fn=prune_checkpoint,
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors],
outputs=[txt_info_reduse],
)
# Tab: on-demand report of GPU and CPU statistics.
with gr.TabItem("System Info"):
    output_box = gr.Textbox(label="GPU and CPU Information", lines=20)

    def update_stats():
        """Return the combined GPU/CPU report to display in the output box."""
        return get_combined_stats()

    update_button = gr.Button("Update Stats")
    # Stats are refreshed only when the button is clicked. `gr.update(...)` merely builds a
    # property-update payload and registers no event, so it cannot be used to schedule a refresh.
    update_button.click(fn=update_stats, outputs=output_box)
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option(
    "--share",
    "-s",
    default=False,
    is_flag=True,
    help="Share the app via Gradio share link",
)
# Declared as an on/off pair: a lone flag whose default is already True cannot express both
# states (what passing it does depends on the Click version). `--api` / `-a` keep working and
# `--no-api` switches API access off.
@click.option("--api/--no-api", "-a", default=True, help="Allow API access")
def main(port, host, share, api):
    """Launch the F5 TTS finetuning web UI."""
    # port: server port, or None to use Gradio's default.
    # host: server name/interface, or None to use Gradio's default.
    # share: request a public Gradio share link.
    # api: controls both `api_open` on the queue and `show_api` on launch.
    print("Starting app...")
    app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)


if __name__ == "__main__":
    main()
================================================
FILE: src/f5_tts/train/train.py
================================================
# training script.
import os
from importlib.resources import files
import hydra
from omegaconf import OmegaConf
from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
# Executed at import time: relative paths used afterwards resolve from two directory levels above
# the installed `f5_tts` package directory.
os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to root of project (local editable)
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(model_cfg):
    """Build the CFM model and Trainer described by the Hydra config, then run training.

    Args:
        model_cfg: Hydra/OmegaConf config object; its `model`, `datasets`,
            `optim` and `ckpts` sections are read here.
    """
    net_cfg = model_cfg.model
    model_cls = hydra.utils.get_class(f"f5_tts.model.{net_cfg.backbone}")
    model_arc = net_cfg.arch
    tokenizer = net_cfg.tokenizer
    mel_spec_type = net_cfg.mel_spec.mel_spec_type

    ckpt_cfg = model_cfg.ckpts
    data_cfg = model_cfg.datasets

    # wandb settings are optional in the config; defaults are used when a key is absent
    wandb_project = ckpt_cfg.get("wandb_project", "CFM-TTS")
    fallback_run_name = f"{net_cfg.name}_{mel_spec_type}_{net_cfg.tokenizer}_{data_cfg.name}"
    wandb_run_name = ckpt_cfg.get("wandb_run_name", fallback_run_name)
    wandb_resume_id = ckpt_cfg.get("wandb_resume_id", None)

    # set text tokenizer: "custom" reads an explicit tokenizer path, anything else uses the dataset name
    tokenizer_path = net_cfg.tokenizer_path if tokenizer == "custom" else data_cfg.name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    # set model
    backbone = model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=net_cfg.mel_spec.n_mel_channels)
    model = CFM(
        transformer=backbone,
        mel_spec_kwargs=net_cfg.mel_spec,
        vocab_char_map=vocab_char_map,
    )

    # init trainer
    optim_cfg = model_cfg.optim
    trainer_options = dict(
        epochs=optim_cfg.epochs,
        learning_rate=optim_cfg.learning_rate,
        num_warmup_updates=optim_cfg.num_warmup_updates,
        save_per_updates=ckpt_cfg.save_per_updates,
        keep_last_n_checkpoints=ckpt_cfg.keep_last_n_checkpoints,
        checkpoint_path=str(files("f5_tts").joinpath(f"../../{ckpt_cfg.save_dir}")),
        batch_size_per_gpu=data_cfg.batch_size_per_gpu,
        batch_size_type=data_cfg.batch_size_type,
        max_samples=data_cfg.max_samples,
        grad_accumulation_steps=optim_cfg.grad_accumulation_steps,
        max_grad_norm=optim_cfg.max_grad_norm,
        logger=ckpt_cfg.logger,
        wandb_project=wandb_project,
        wandb_run_name=wandb_run_name,
        wandb_resume_id=wandb_resume_id,
        last_per_updates=ckpt_cfg.last_per_updates,
        log_samples=ckpt_cfg.log_samples,
        bnb_optimizer=optim_cfg.bnb_optimizer,
        mel_spec_type=mel_spec_type,
        is_local_vocoder=net_cfg.vocoder.is_local,
        local_vocoder_path=net_cfg.vocoder.local_path,
        # plain-container copy of the whole config, with interpolations resolved
        model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
    )
    trainer = Trainer(model, **trainer_options)

    train_dataset = load_dataset(data_cfg.name, tokenizer, mel_spec_kwargs=net_cfg.mel_spec)
    trainer.train(
        train_dataset,
        num_workers=data_cfg.num_workers,
        resumable_with_seed=666,  # seed for shuffling dataset
    )


if __name__ == "__main__":
    main()