Repository: SWivid/F5-TTS
Branch: main
Commit: 623c96c29496
Files: 90
Total size: 571.3 KB

Directory structure:
gitextract_4jvakfwh/

├── .github/
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.yml
│   │   ├── config.yml
│   │   ├── feature_request.yml
│   │   ├── help_wanted.yml
│   │   └── question.yml
│   └── workflows/
│       ├── pre-commit.yaml
│       ├── publish-docker-image.yaml
│       └── publish-pypi.yaml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── README.md
├── pyproject.toml
├── ruff.toml
└── src/
    └── f5_tts/
        ├── api.py
        ├── configs/
        │   ├── E2TTS_Base.yaml
        │   ├── E2TTS_Small.yaml
        │   ├── F5TTS_Base.yaml
        │   ├── F5TTS_Small.yaml
        │   └── F5TTS_v1_Base.yaml
        ├── eval/
        │   ├── README.md
        │   ├── ecapa_tdnn.py
        │   ├── eval_infer_batch.py
        │   ├── eval_infer_batch.sh
        │   ├── eval_infer_batch_example.sh
        │   ├── eval_librispeech_test_clean.py
        │   ├── eval_seedtts_testset.py
        │   ├── eval_utmos.py
        │   └── utils_eval.py
        ├── infer/
        │   ├── README.md
        │   ├── SHARED.md
        │   ├── examples/
        │   │   ├── basic/
        │   │   │   └── basic.toml
        │   │   ├── multi/
        │   │   │   ├── country.flac
        │   │   │   ├── main.flac
        │   │   │   ├── story.toml
        │   │   │   ├── story.txt
        │   │   │   └── town.flac
        │   │   └── vocab.txt
        │   ├── infer_cli.py
        │   ├── infer_gradio.py
        │   ├── speech_edit.py
        │   └── utils_infer.py
        ├── model/
        │   ├── __init__.py
        │   ├── backbones/
        │   │   ├── README.md
        │   │   ├── dit.py
        │   │   ├── mmdit.py
        │   │   └── unett.py
        │   ├── cfm.py
        │   ├── dataset.py
        │   ├── modules.py
        │   ├── trainer.py
        │   └── utils.py
        ├── runtime/
        │   └── triton_trtllm/
        │       ├── .gitignore
        │       ├── Dockerfile.server
        │       ├── README.md
        │       ├── benchmark.py
        │       ├── client_grpc.py
        │       ├── client_http.py
        │       ├── docker-compose.yml
        │       ├── model_repo_f5_tts/
        │       │   ├── f5_tts/
        │       │   │   ├── 1/
        │       │   │   │   ├── f5_tts_trtllm.py
        │       │   │   │   └── model.py
        │       │   │   └── config.pbtxt
        │       │   └── vocoder/
        │       │       ├── 1/
        │       │       │   └── .gitkeep
        │       │       └── config.pbtxt
        │       ├── patch/
        │       │   ├── __init__.py
        │       │   └── f5tts/
        │       │       ├── model.py
        │       │       └── modules.py
        │       ├── run.sh
        │       └── scripts/
        │           ├── conv_stft.py
        │           ├── convert_checkpoint.py
        │           ├── export_vocoder_to_onnx.py
        │           ├── export_vocos_trt.sh
        │           └── fill_template.py
        ├── scripts/
        │   ├── count_max_epoch.py
        │   ├── count_max_epoch_precise.py
        │   └── count_params_gflops.py
        ├── socket_client.py
        ├── socket_server.py
        └── train/
            ├── README.md
            ├── datasets/
            │   ├── prepare_csv_wavs.py
            │   ├── prepare_emilia.py
            │   ├── prepare_emilia_v2.py
            │   ├── prepare_libritts.py
            │   ├── prepare_ljspeech.py
            │   └── prepare_wenetspeech4tts.py
            ├── finetune_cli.py
            ├── finetune_gradio.py
            └── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yml
================================================
name: "Bug Report"
description: |
  Please provide as much detail as possible to help address the issue efficiently, including input, output, logs, and screenshots.
labels:
  - bug
body:
  - type: checkboxes
    attributes:
      label: Checks
      description: "To ensure timely help, please confirm the following:"
      options:
        - label: This template is only for bug reports; usage problems go under 'Help Wanted'.
          required: true
        - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
          required: true
        - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
          required: true
        - label: I am using English to submit this issue to facilitate community communication.
          required: true
  - type: textarea
    attributes:
      label: Environment Details
      description: "Provide details including OS, GPU info, Python version, any relevant software or dependencies, and trainer setting."
      placeholder: e.g., CentOS Linux 7, 4 * RTX 3090, Python 3.10, torch==2.3.0+cu118, cuda 11.8, config yaml is ...
    validations:
      required: true
  - type: textarea
    attributes:
      label: Steps to Reproduce
      description: |
        Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
      placeholder: |
        1. Create a new conda environment.
        2. Clone the repository, install as local editable and properly set up.
        3. Run the command: `accelerate launch src/f5_tts/train/train.py`.
        4. Get the following error message... (attach logs).
    validations:
      required: true
  - type: textarea
    attributes:
      label: ✔️ Expected Behavior
      placeholder: Describe in detail what you expected to happen.
    validations:
      required: false
  - type: textarea
    attributes:
      label: ❌ Actual Behavior
      placeholder: Describe in detail what actually happened.
    validations:
      required: false

================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
blank_issues_enabled: false


================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.yml
================================================
name: "Feature Request"
description: |
  Constructive suggestions and new ideas regarding the current repo.
labels:
  - enhancement
body:
  - type: checkboxes
    attributes:
      label: Checks
      description: "To help us understand your request quickly, please confirm the following:"
      options:
        - label: This template is only for feature requests.
          required: true
        - label: I have thoroughly reviewed the project documentation but couldn't find any relevant information that meets my needs.
          required: true
        - label: I have searched for existing issues, including closed ones, and found no existing discussion.
          required: true
        - label: I am using English to submit this issue to facilitate community communication.
          required: true
  - type: textarea
    attributes:
      label: 1. Is this request related to a challenge you're experiencing? Tell us your story.
      description: |
        Describe the specific problem or scenario you're facing in detail. For example:
        *"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because...."*
      placeholder: Please describe the situation in as much detail as possible.
    validations:
      required: true

  - type: textarea
    attributes:
      label: 2. What is your suggested solution?
      description: |
        Provide a clear description of the feature or enhancement you'd like to propose. 
        How would this feature solve your issue or improve the project?
      placeholder: Describe your idea or proposed solution here.
    validations:
      required: true

  - type: textarea
    attributes:
      label: 3. Additional context or comments
      description: |
        Any other relevant information, links, documents, or screenshots that provide clarity. 
        Use this section for anything not covered above.
      placeholder: Add any extra details here.
    validations:
      required: false

  - type: checkboxes
    attributes:
      label: 4. Can you help us with this feature?
      description: |
        Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration.
      options:
        - label: I am interested in contributing to this feature.
          required: false

  - type: markdown
    attributes:
      value: |
        **Note:** Please submit only one request per issue to keep discussions focused and manageable.

================================================
FILE: .github/ISSUE_TEMPLATE/help_wanted.yml
================================================
name: "Help Wanted"
description: |
  Please provide as much detail as possible to help address the issue efficiently, including input, output, logs, and screenshots.
labels:
  - help wanted
body:
  - type: checkboxes
    attributes:
      label: Checks
      description: "To ensure timely help, please confirm the following:"
      options:
        - label: This template is only for usage issues encountered.
          required: true
        - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
          required: true
        - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
          required: true
        - label: I am using English to submit this issue to facilitate community communication.
          required: true
  - type: textarea
    attributes:
      label: Environment Details
      description: "Provide details such as OS, Python version, and any relevant software or dependencies."
      placeholder: |
        e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
        If training or finetuning related, provide detailed configuration including GPU info and training setup.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Steps to Reproduce
      description: |
        Include detailed steps, screenshots, and logs. Provide used prompt wav and text. Use the correct markdown syntax for code blocks.
      placeholder: |
        1. Create a new conda environment.
        2. Clone the repository and install as pip package.
        3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
        4. It hangs with the following message... (attach logs, and the error message, e.g. after Ctrl-C).
        5. Prompt and generated WAVs are attached [change the suffix to .mp4 to enable direct upload, or pack them all into a .zip].
        6. Reference audio's transcription or provided ref_text is `xxx`, and text to generate is `xxx`.
    validations:
      required: true
  - type: textarea
    attributes:
      label: ✔️ Expected Behavior
      placeholder: Describe what you expected to happen in detail, e.g. output a generated audio.
    validations:
      required: false
  - type: textarea
    attributes:
      label: ❌ Actual Behavior
      placeholder: Describe what actually happened in detail, failure messages, etc.
    validations:
      required: false

================================================
FILE: .github/ISSUE_TEMPLATE/question.yml
================================================
name: "Question"
description: |
  Research questions or general inquiries about the project; usage issues go under "Help Wanted".
labels:
  - question
body:
  - type: checkboxes
    attributes:
      label: Checks
      description: "To help us understand your question quickly, please confirm the following:"
      options:
        - label: This template is only for research questions, not usage problems, feature requests or bug reports.
          required: true
        - label: I have thoroughly reviewed the project documentation and read the related paper(s).
          required: true
        - label: I have searched for existing issues, including closed ones, and found no similar questions.
          required: true
        - label: I am using English to submit this issue to facilitate community communication.
          required: true
  - type: textarea
    attributes:
      label: Question details
      description: |
        Question details, clearly stated using proper markdown syntax.
    validations:
      required: true


================================================
FILE: .github/workflows/pre-commit.yaml
================================================
name: pre-commit

on:
  pull_request:
  push:
    branches: [main]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
      - uses: pre-commit/action@v3.0.1


================================================
FILE: .github/workflows/publish-docker-image.yaml
================================================
name: Create and publish a Docker image

# Configures this workflow to run every time a change is pushed to the branch called `main`.
on:
  push:
    branches: ['main']

# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
jobs:
  build-and-push-image:
    runs-on: ubuntu-latest
    # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
    permissions:
      contents: read
      packages: write
      # 
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
        uses: jlumbroso/free-disk-space@main
        with:
          # If set to "true", this frees about 6 GB but might remove tools that are actually needed
          tool-cache: false

          # All of these default to true, but feel free to set to "false" if necessary for your workflow
          android: true
          dotnet: true
          haskell: true
          large-packages: false
          swap-storage: false
          docker-images: false
      # Uses the `docker/login-action` action to log in to the Container registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
      - name: Log in to the Container registry
        uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
      # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
      # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
      # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
      - name: Build and push Docker image
        uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
        with:
          context: .
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}


================================================
FILE: .github/workflows/publish-pypi.yaml
================================================
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

# GitHub recommends pinning actions to a commit SHA.
# To get a newer version, you will need to update the SHA.
# You can also reference a tag or branch, but the action may change without warning.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  release-build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.x"

      - name: Build release distributions
        run: |
          # NOTE: put your own distribution build steps here.
          python -m pip install build
          python -m build

      - name: Upload distributions
        uses: actions/upload-artifact@v4
        with:
          name: release-dists
          path: dist/

  pypi-publish:
    runs-on: ubuntu-latest

    needs:
      - release-build

    permissions:
      # IMPORTANT: this permission is mandatory for trusted publishing
      id-token: write

    # Dedicated environments with protections for publishing are strongly recommended.
    environment:
      name: pypi
      # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
      # url: https://pypi.org/p/YOURPROJECT

    steps:
      - name: Retrieve release distributions
        uses: actions/download-artifact@v4
        with:
          name: release-dists
          path: dist/

      - name: Publish release distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1


================================================
FILE: .gitignore
================================================
# Custom
.vscode/
tests/
runs/
data/
ckpts/
wandb/
results/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


================================================
FILE: .gitmodules
================================================
[submodule "src/third_party/BigVGAN"]
	path = src/third_party/BigVGAN
	url = https://github.com/NVIDIA/BigVGAN.git


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.11.2
    hooks:
      - id: ruff
        name: ruff linter
        args: [--fix]
      - id: ruff-format
        name: ruff formatter
      - id: ruff
        name: ruff sorter
        args: [--select, I, --fix]
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: check-yaml


================================================
FILE: Dockerfile
================================================
FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel

USER root

ARG DEBIAN_FRONTEND=noninteractive

LABEL github_repo="https://github.com/SWivid/F5-TTS"

RUN set -x \
    && apt-get update \
    && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
    && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
    && apt-get install -y librdmacm1 libibumad3 librdmacm-dev libibverbs1 libibverbs-dev ibverbs-utils ibverbs-providers \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean
    
WORKDIR /workspace

RUN git clone https://github.com/SWivid/F5-TTS.git \
    && cd F5-TTS \
    && git submodule update --init --recursive \
    && pip install -e . --no-cache-dir

ENV SHELL=/bin/bash

VOLUME /root/.cache/huggingface/hub/

EXPOSE 7860

WORKDIR /workspace/F5-TTS


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2024 Yushen CHEN

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching

[![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
[![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
[![demo](https://img.shields.io/badge/GitHub-Demo-orange.svg)](https://swivid.github.io/F5-TTS/)
[![hfspace](https://img.shields.io/badge/🤗-HF%20Space-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
[![msspace](https://img.shields.io/badge/🤖-MS%20Space-blue)](https://modelscope.cn/studios/AI-ModelScope/E2-F5-TTS)
[![lab](https://img.shields.io/badge/🏫-X--LANCE-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
[![lab](https://img.shields.io/badge/🏫-SII-grey?labelColor=lightgrey)](https://www.sii.edu.cn/)
[![lab](https://img.shields.io/badge/🏫-PCL-grey?labelColor=lightgrey)](https://www.pcl.ac.cn)
<!-- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto"> -->

**F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster training and inference.

**E2 TTS**: Flat-UNet Transformer, closest reproduction of the [paper](https://arxiv.org/abs/2406.18009).

**Sway Sampling**: Inference-time flow step sampling strategy that greatly improves performance.
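Sway Sampling changes only the inference-time schedule of flow steps. The sketch below assumes the warping function described in the F5-TTS paper, f(u; s) = u + s·(cos(πu/2) - 1 + u), applied to a uniform grid of NFE steps on [0, 1]. The function names and the default coefficient s = -1 are illustrative assumptions, not this repository's code. With s < 0 the points crowd toward t = 0, which spends more of the step budget early in the flow. With s = 0 the uniform grid is unchanged.

```python
import math


def sway(u: float, s: float = -1.0) -> float:
    """Warp one flow step u in [0, 1]; s < 0 moves steps toward t = 0 (s is an assumed default)."""
    return u + s * (math.cos(math.pi / 2 * u) - 1 + u)


def sway_schedule(nfe: int, s: float = -1.0) -> list[float]:
    """Return nfe + 1 time points from 0 to 1: a uniform grid, then swayed."""
    if nfe < 1:
        raise ValueError("nfe must be >= 1")
    uniform = [i / nfe for i in range(nfe + 1)]
    return [sway(u, s) for u in uniform]


if __name__ == "__main__":
    # With s = -1 the warp reduces to 1 - cos(pi * u / 2), so early steps are denser.
    for t in sway_schedule(nfe=8):
        print(f"{t:.4f}")
```

The endpoints stay fixed: f(0) = 0 and f(1) = 1 for any s. Only the spacing between steps changes, so the number of function evaluations is the same as with a uniform schedule.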

### Thanks to all the contributors!

## News
- **2025/03/12**: 🔥 F5-TTS v1 base model with better training and inference performance. [A few demos](https://swivid.github.io/F5-TTS_updates).
- **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).

## Installation

### Create a separate environment if needed

```bash
# Create a conda env with python_version>=3.10  (you could also use virtualenv)
conda create -n f5-tts python=3.11
conda activate f5-tts

# Install FFmpeg if you haven't yet
conda install ffmpeg
```

### Install PyTorch with matched device

<details>
<summary>NVIDIA GPU</summary>

> ```bash
> # Install pytorch with your CUDA version, e.g.
> pip install torch==2.8.0+cu128 torchaudio==2.8.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
> 
> # And also possible previous versions, e.g.
> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
> # etc.
> ```

</details>

<details>
<summary>AMD GPU</summary>

> ```bash
> # Install pytorch with your ROCm version (Linux only), e.g.
> pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
> ```

</details>

<details>
<summary>Intel GPU</summary>

> ```bash
> # Install pytorch with your XPU version, e.g.
> # Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed
> pip install torch torchaudio --index-url https://download.pytorch.org/whl/test/xpu
> 
> # Intel GPU support is also available through IPEX (Intel® Extension for PyTorch)
> # IPEX does not require the Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit
> # See: https://pytorch-extension.intel.com/installation?request=platform
> ```

</details>

<details>
<summary>Apple Silicon</summary>

> ```bash
> # Install the stable pytorch, e.g.
> pip install torch torchaudio
> ```

</details>
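Once the PyTorch build that matches your hardware is installed, code usually needs to pick a device at runtime. The sketch below assumes a CUDA, then XPU, then MPS, then CPU priority. ROCm builds of PyTorch report AMD GPUs through the `torch.cuda` API. The helper names are hypothetical, and this is not the project's own device-selection code. The pure `pick_device` function is kept separate from the torch probing so the selection logic can be checked without PyTorch installed.

```python
def pick_device(cuda: bool, xpu: bool, mps: bool) -> str:
    """Choose a device string from availability flags (CUDA/ROCm > XPU > MPS > CPU)."""
    if cuda:
        return "cuda"  # ROCm builds of PyTorch also expose AMD GPUs as "cuda"
    if xpu:
        return "xpu"  # Intel GPUs
    if mps:
        return "mps"  # Apple Silicon
    return "cpu"


def detect_device() -> str:
    """Probe the installed PyTorch build; fall back to CPU if torch is missing."""
    try:
        import torch
    except ImportError:
        return "cpu"
    # torch.xpu only exists in newer PyTorch releases, so probe it defensively.
    xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
    mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return pick_device(torch.cuda.is_available(), xpu, mps)


if __name__ == "__main__":
    print(detect_device())
```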

### Then choose one of the following:

> ### 1. As a pip package (inference only)
> 
> ```bash
> pip install f5-tts
> ```
> 
> ### 2. Local editable (if you also train or finetune)
> 
> ```bash
> git clone https://github.com/SWivid/F5-TTS.git
> cd F5-TTS
> # git submodule update --init --recursive  # (optional, if using BigVGAN as the vocoder)
> pip install -e .
> ```

### Docker usage also available
```bash
# Build from Dockerfile
docker build -t f5tts:v1 .

# Run from GitHub Container Registry
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main

# Quickstart if you want to just run the web interface (not CLI)
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
```

### Runtime

Deployment solution with Triton and TensorRT-LLM.

#### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.

| Model               | Concurrency    | Avg Latency | RTF    | Mode            |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2              | 253 ms      | 0.0394 | Client-Server   |
| F5-TTS Base (Vocos) | 1 (Batch_size) | -           | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | -           | 0.1467 | Offline Pytorch |

See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
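The RTF column is the real-time factor: wall-clock processing time divided by the duration of the generated audio. Lower is better, and values below 1 mean faster than real time. The small helpers below compute it and compare the two offline modes in the table. The function names are illustrative, and the 0.4 s / 10 s measurement is a made-up example.

```python
def real_time_factor(processing_seconds: float, audio_seconds: float) -> float:
    """RTF = time spent synthesizing / duration of the synthesized audio."""
    if audio_seconds <= 0:
        raise ValueError("audio duration must be positive")
    return processing_seconds / audio_seconds


def speedup(rtf_baseline: float, rtf_optimized: float) -> float:
    """How many times faster the optimized path is, as a ratio of RTFs."""
    return rtf_baseline / rtf_optimized


if __name__ == "__main__":
    # Hypothetical measurement: 0.4 s of compute for 10 s of audio.
    print(real_time_factor(0.4, 10.0))
    # Offline PyTorch vs. offline TRT-LLM, using the RTF values from the table.
    print(f"{speedup(0.1467, 0.0402):.2f}x")
```

By this measure, the offline TRT-LLM path in the table is roughly 3.6 times faster than offline PyTorch at batch size 1.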


## Inference

- To achieve the desired performance, take a moment to read the [detailed guidance](src/f5_tts/infer).
- Searching [issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) for keywords related to your problem is often very helpful.

### 1. Gradio App

Currently supported features:

- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct
- [Custom inference with more language support](src/f5_tts/infer/SHARED.md)

```bash
# Launch a Gradio app (web interface)
f5-tts_infer-gradio

# Specify the port/host
f5-tts_infer-gradio --port 7860 --host 0.0.0.0

# Launch a share link
f5-tts_infer-gradio --share
```

<details>
<summary>NVIDIA device docker compose file example</summary>

```yaml
services:
  f5-tts:
    image: ghcr.io/swivid/f5-tts:main
    ports:
      - "7860:7860"
    environment:
      GRADIO_SERVER_PORT: 7860
    entrypoint: ["f5-tts_infer-gradio", "--port", "7860", "--host", "0.0.0.0"]
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

volumes:
  f5-tts:
    driver: local
```

</details>

### 2. CLI Inference

```bash
# Run with flags
# Leaving --ref_text "" makes an ASR model transcribe the reference audio (extra GPU memory usage)
f5-tts_infer-cli --model F5TTS_v1_Base \
--ref_audio "provide_prompt_wav_path_here.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want the TTS model to generate for you."

# Run with the default settings from src/f5_tts/infer/examples/basic/basic.toml
f5-tts_infer-cli
# Or with your own .toml file
f5-tts_infer-cli -c custom.toml

# Multi voice. See src/f5_tts/infer/README.md
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
```
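The multi-voice example drives several reference voices from one script. The sketch below shows how such a script could be split into per-voice segments. It assumes `[voice_name]` tags in the text, such as `[main]` or `[town]` to match the audio files in `examples/multi/`, and assigns untagged text to `main`. The tag syntax and the `split_by_voice` name are assumptions here. Check `src/f5_tts/infer/README.md` for the format the CLI actually expects.

```python
import re

# Zero-width lookahead: split *before* each "[name]" tag so the tag stays
# attached to the text it introduces (assumed tag syntax).
_SPLIT_BEFORE_TAG = re.compile(r"(?=\[\w+\])")
_TAG = re.compile(r"\[(\w+)\]")


def split_by_voice(text: str, default_voice: str = "main") -> list[tuple[str, str]]:
    """Return (voice, text) pairs for a script using hypothetical [voice] tags."""
    segments = []
    for chunk in _SPLIT_BEFORE_TAG.split(text):
        if not chunk.strip():
            continue  # skip empty pieces produced at split boundaries
        match = _TAG.match(chunk)
        if match:
            voice = match.group(1)
            chunk = chunk[match.end():]  # drop the tag itself
        else:
            voice = default_voice  # untagged leading text uses the default voice
        chunk = chunk.strip()
        if chunk:
            segments.append((voice, chunk))
    return segments


if __name__ == "__main__":
    script = "[main] Hello there. [town] Hi! [main] Bye."
    for voice, line in split_by_voice(script):
        print(f"{voice}: {line}")
```

Each `(voice, text)` pair could then be synthesized with that voice's reference audio and the results concatenated in order.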


## Training

### 1. With Hugging Face Accelerate

Refer to [training & finetuning guidance](src/f5_tts/train) for best practice.

### 2. With Gradio App

```bash
# Quick start with Gradio web interface
f5-tts_finetune-gradio
```

Read [training & finetuning guidance](src/f5_tts/train) for more instructions.


## [Evaluation](src/f5_tts/eval)


## Development

Use pre-commit to ensure code quality (will run linters and formatters automatically):

```bash
pip install pre-commit
pre-commit install
```

When making a pull request, run the following before each commit:

```bash
pre-commit run --all-files
```

Note: Some model components have linting exceptions for F722 (syntax error in forward annotation) to accommodate tensor shape notation.
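
Shape-style annotations are what trip the linter here: pyflakes/ruff parse a string such as `"b n d"` inside an annotation as a forward reference, fail to parse it as a type expression, and report it under rule F722. A minimal sketch of that pattern, assuming `from __future__ import annotations` so the annotations are never evaluated at runtime; `mean_over_time` is a hypothetical function, and plain nested lists stand in for tensors.

```python
from __future__ import annotations  # annotations are stored as strings, not evaluated at definition time


def mean_over_time(x: float["b n d"]) -> float["b d"]:  # noqa: F722
    """Average a batch of sequences over the time axis n.

    The shape strings only document the layout (batch b, time n, feature d);
    nested Python lists stand in for tensors in this sketch.
    """
    out = []
    for seq in x:  # seq has layout [n, d]
        n = len(seq)
        d = len(seq[0])
        out.append([sum(frame[j] for frame in seq) / n for j in range(d)])
    return out


print(mean_over_time([[[1.0, 3.0], [3.0, 5.0]]]))
```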


## Acknowledgements

- [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
- [lucidrains](https://github.com/lucidrains) initial CFM structure, and [bfs18](https://github.com/bfs18) for discussion
- [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder
- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
- [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
- [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
- [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
- [Yuekai Zhang](https://github.com/yuekaizhang) Triton and TensorRT-LLM support ~

## Citation
If our work and codebase are useful to you, please cite:
```bibtex
@article{chen-etal-2024-f5tts,
      title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching}, 
      author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
      journal={arXiv preprint arXiv:2410.06885},
      year={2024},
}
```

## License

Our code is released under the MIT License. The pre-trained models are licensed under CC-BY-NC because of the training data Emilia, an in-the-wild dataset. Sorry for any inconvenience this may cause.


================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
build-backend = "setuptools.build_meta"

[project]
name = "f5-tts"
version = "1.1.17"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
classifiers = [
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
]
dependencies = [
    "accelerate>=0.33.0",
    "bitsandbytes>0.37.0; platform_machine!='arm64' and platform_system!='Darwin'",
    "cached_path",
    "click",
    "datasets",
    "ema_pytorch>=0.5.2",
    "gradio>=6.0.0",
    "hydra-core>=1.3.0",
    "librosa",
    "matplotlib",
    "numpy<=1.26.4; python_version<='3.10'",
    "pydub",
    "pypinyin",
    "rjieba",
    "safetensors",
    "soundfile",
    "tomli",
    "torch>=2.0.0",
    "torchaudio>=2.0.0",
    "torchcodec",
    "torchdiffeq",
    "tqdm>=4.65.0",
    "transformers",
    "transformers_stream_generator",
    "unidecode",
    "vocos",
    "wandb",
    "x_transformers>=1.31.14",
]

[project.optional-dependencies]
eval = [
    "faster_whisper==0.10.1",
    "funasr",
    "jiwer",
    "modelscope",
    "zhconv",
    "zhon",
]

[project.urls]
Homepage = "https://github.com/SWivid/F5-TTS"

[project.scripts]
"f5-tts_infer-cli" = "f5_tts.infer.infer_cli:main"
"f5-tts_infer-gradio" = "f5_tts.infer.infer_gradio:main"
"f5-tts_finetune-cli" = "f5_tts.train.finetune_cli:main"
"f5-tts_finetune-gradio" = "f5_tts.train.finetune_gradio:main"


================================================
FILE: ruff.toml
================================================
line-length = 120
target-version = "py310"

[lint]
# Only ignore variables with names starting with "_".
dummy-variable-rgx = "^_.*$"

[lint.isort]
force-single-line = false
lines-after-imports = 2


================================================
FILE: src/f5_tts/api.py
================================================
import random
import sys
from importlib.resources import files

import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
    save_spectrogram,
    transcribe,
)
from f5_tts.model.utils import seed_everything


class F5TTS:
    def __init__(
        self,
        model="F5TTS_v1_Base",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        vocoder_local_path=None,
        device=None,
        hf_cache_dir=None,
    ):
        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate

        self.ode_method = ode_method
        self.use_ema = use_ema

        if device is not None:
            self.device = device
        else:
            import torch

            self.device = (
                "cuda"
                if torch.cuda.is_available()
                else "xpu"
                if torch.xpu.is_available()
                else "mps"
                if torch.backends.mps.is_available()
                else "cpu"
            )

        # Load models
        self.vocoder = load_vocoder(
            self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
        )

        repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

        # override for previous models
        if model == "F5TTS_Base":
            if self.mel_spec_type == "vocos":
                ckpt_step = 1200000
            elif self.mel_spec_type == "bigvgan":
                model = "F5TTS_Base_bigvgan"
                ckpt_type = "pt"
        elif model == "E2TTS_Base":
            repo_name = "E2-TTS"
            ckpt_step = 1200000

        if not ckpt_file:
            ckpt_file = str(
                cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
            )
        self.ema_model = load_model(
            model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
        )

    def transcribe(self, ref_audio, language=None):
        return transcribe(ref_audio, language)

    def export_wav(self, wav, file_wave, remove_silence=False):
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spec, file_spec):
        save_spectrogram(spec, file_spec)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spec=None,
        seed=None,
    ):
        if seed is None:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, show_info=show_info)

        wav, sr, spec = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spec is not None:
            self.export_spectrogram(spec, file_spec)

        return wav, sr, spec


if __name__ == "__main__":
    f5tts = F5TTS()

    wav, sr, spec = f5tts.infer(
        ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
        ref_text="Some call me nature, others call me mother nature.",
        gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
        file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
        file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
        seed=None,
    )

    print("seed :", f5tts.seed)


================================================
FILE: src/f5_tts/configs/E2TTS_Base.yaml
================================================
hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN  # dataset name
  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # frame | sample
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 11
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup updates
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not

model:
  name: E2TTS_Base
  tokenizer: pinyin
  tokenizer_path: null  # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
  backbone: UNetT
  arch:
    dim: 1024
    depth: 24
    heads: 16
    ff_mult: 4
    text_mask_padding: False
    pe_attn_head: 1
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # vocos | bigvgan
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: null  # local vocoder path

ckpts:
  logger: wandb  # wandb | tensorboard | null
  wandb_project: CFM-TTS  # wandb project name
  wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}  # wandb run name
  wandb_resume_id: null  # wandb run id for resuming, null to auto-detect from checkpoint
  log_samples: True  # infer random sample per save checkpoint. wip, normal to fail with extra long samples
  save_per_updates: 50000  # save checkpoint per updates
  keep_last_n_checkpoints: -1  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 5000  # save last checkpoint per updates
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

================================================
FILE: src/f5_tts/configs/E2TTS_Small.yaml
================================================
hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN
  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # frame | sample
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 11
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup updates
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not

model:
  name: E2TTS_Small
  tokenizer: pinyin
  tokenizer_path: null  # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
  backbone: UNetT
  arch:
    dim: 768
    depth: 20
    heads: 12
    ff_mult: 4
    text_mask_padding: False
    pe_attn_head: 1
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # vocos | bigvgan
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: null  # local vocoder path

ckpts:
  logger: wandb  # wandb | tensorboard | null
  wandb_project: CFM-TTS  # wandb project name
  wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}  # wandb run name
  wandb_resume_id: null  # wandb run id for resuming, null to auto-detect from checkpoint
  log_samples: True  # infer random sample per save checkpoint. wip, normal to fail with extra long samples
  save_per_updates: 50000  # save checkpoint per updates
  keep_last_n_checkpoints: -1  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 5000  # save last checkpoint per updates
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

================================================
FILE: src/f5_tts/configs/F5TTS_Base.yaml
================================================
hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN  # dataset name
  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # frame | sample
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 11
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup updates
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not

model:
  name: F5TTS_Base  # model name
  tokenizer: pinyin  # tokenizer type
  tokenizer_path: null  # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
  backbone: DiT
  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    text_mask_padding: False
    conv_layers: 4
    pe_attn_head: 1
    attn_backend: torch  # torch | flash_attn
    attn_mask_enabled: False
    checkpoint_activations: False  # recompute activations and save memory for extra compute
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # vocos | bigvgan
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: null  # local vocoder path

ckpts:
  logger: wandb  # wandb | tensorboard | null
  wandb_project: CFM-TTS  # wandb project name
  wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}  # wandb run name
  wandb_resume_id: null  # wandb run id for resuming, null to auto-detect from checkpoint
  log_samples: True  # infer random sample per save checkpoint. wip, normal to fail with extra long samples
  save_per_updates: 50000  # save checkpoint per updates
  keep_last_n_checkpoints: -1  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 5000  # save last checkpoint per updates
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

================================================
FILE: src/f5_tts/configs/F5TTS_Small.yaml
================================================
hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN
  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # frame | sample
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 11  # only suitable for Emilia; to train on LibriTTS, set epochs to 686
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup updates
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not

model:
  name: F5TTS_Small
  tokenizer: pinyin
  tokenizer_path: null  # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
  backbone: DiT
  arch:
    dim: 768
    depth: 18
    heads: 12
    ff_mult: 2
    text_dim: 512
    text_mask_padding: False
    conv_layers: 4
    pe_attn_head: 1
    attn_backend: torch  # torch | flash_attn
    attn_mask_enabled: False
    checkpoint_activations: False  # recompute activations and save memory for extra compute
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # vocos | bigvgan
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: null  # local vocoder path

ckpts:
  logger: wandb  # wandb | tensorboard | null
  wandb_project: CFM-TTS  # wandb project name
  wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}  # wandb run name
  wandb_resume_id: null  # wandb run id for resuming, null to auto-detect from checkpoint
  log_samples: True  # infer random sample per save checkpoint. wip, normal to fail with extra long samples
  save_per_updates: 50000  # save checkpoint per updates
  keep_last_n_checkpoints: -1  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 5000  # save last checkpoint per updates
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}


================================================
FILE: src/f5_tts/configs/F5TTS_v1_Base.yaml
================================================
hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN  # dataset name
  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # frame | sample
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 11
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup updates
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not

model:
  name: F5TTS_v1_Base  # model name
  tokenizer: pinyin  # tokenizer type
  tokenizer_path: null  # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
  backbone: DiT
  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    text_mask_padding: True
    qk_norm: null  # null | rms_norm
    conv_layers: 4
    pe_attn_head: null
    attn_backend: torch  # torch | flash_attn
    attn_mask_enabled: False
    checkpoint_activations: False  # recompute activations and save memory for extra compute
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # vocos | bigvgan
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: null  # local vocoder path

ckpts:
  logger: wandb  # wandb | tensorboard | null
  wandb_project: CFM-TTS  # wandb project name
  wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}  # wandb run name
  wandb_resume_id: null  # wandb run id for resuming, null to auto-detect from checkpoint
  log_samples: True  # infer random sample per save checkpoint. wip, normal to fail with extra long samples
  save_per_updates: 50000  # save checkpoint per updates
  keep_last_n_checkpoints: -1  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 5000  # save last checkpoint per updates
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

================================================
FILE: src/f5_tts/eval/README.md
================================================

# Evaluation

Install packages for evaluation:

```bash
pip install -e .[eval]
```

> [!IMPORTANT]
> For [faster-whisper](https://github.com/SYSTRAN/faster-whisper), install the `ctranslate2` version that matches your CUDA and cuDNN setup:  
> `pip install ctranslate2==4.5.0` if CUDA 12 and cuDNN 9;  
> `pip install ctranslate2==4.4.0` if CUDA 12 and cuDNN 8;  
> `pip install ctranslate2==3.24.0` if CUDA 11 and cuDNN 8.
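
The compatibility mapping above can also be expressed as a small lookup, for example in an environment setup script. This is a minimal sketch; `ctranslate2_pin` is a hypothetical helper, and it only knows the three CUDA/cuDNN combinations listed in this note.

```python
def ctranslate2_pin(cuda_major: int, cudnn_major: int) -> str:
    """Return the pip requirement for ctranslate2 matching a CUDA/cuDNN pair.

    The table mirrors the compatibility note in this README; combinations
    not listed there raise an error instead of guessing.
    """
    table = {
        (12, 9): "ctranslate2==4.5.0",
        (12, 8): "ctranslate2==4.4.0",
        (11, 8): "ctranslate2==3.24.0",
    }
    try:
        return table[(cuda_major, cudnn_major)]
    except KeyError:
        raise ValueError(
            f"No ctranslate2 pin listed for CUDA {cuda_major} + cuDNN {cudnn_major}"
        ) from None


if __name__ == "__main__":
    print("pip install", ctranslate2_pin(12, 9))
```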

## Generating Samples for Evaluation

### Prepare Test Datasets

1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
3. Unzip the downloaded datasets and place them in the `data/` directory.
4. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`

### Batch Inference for Test Set

To run batch inference for evaluations, execute the following commands:

```bash
# if accelerate has not been configured yet
accelerate config

# inference only
bash src/f5_tts/eval/eval_infer_batch.sh --infer-only

# inference followed by evaluation (set up the evaluation tools below first)
bash src/f5_tts/eval/eval_infer_batch.sh
```

## Objective Evaluation on Generated Results

### Download Evaluation Model Checkpoints

1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).

> [!NOTE]  
> The ASR models are downloaded automatically unless `--local` is set for the evaluation scripts.  
> With `--local`, update the `asr_ckpt_dir` path values in `eval_librispeech_test_clean.py` or `eval_seedtts_testset.py`.
> 
> The WavLM model must be downloaded manually, and the `wavlm_ckpt_dir` path updated in both `eval_librispeech_test_clean.py` and `eval_seedtts_testset.py`.
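
SIM scores how close the generated voice is to the reference speaker. The usual recipe, and the one the WavLM checkpoint serves here, is to embed both utterances with a speaker-verification model (the WavLM-based `ECAPA_TDNN` in `ecapa_tdnn.py`) and take the cosine similarity of the two embeddings. A minimal pure-Python sketch of that final step, assuming the embeddings are already available as float lists; the vectors below are toy values for illustration, not real model outputs.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two speaker embeddings of equal length."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("zero-norm embedding")
    return dot / (norm_a * norm_b)


# Toy 4-dimensional vectors; ECAPA_TDNN embeddings default to emb_dim=192.
ref_emb = [0.2, -0.1, 0.7, 0.3]
gen_emb = [0.25, -0.05, 0.6, 0.35]
print(f"SIM = {cosine_similarity(ref_emb, gen_emb):.4f}")
```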

### Objective Evaluation Examples

Point the paths to your batch inference results, then run the WER / SIM / UTMOS evaluations:
```bash
# Evaluation [WER] for Seed-TTS test [ZH] set
python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8

# Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>

# Evaluation [UTMOS]. --ext: Audio extension
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
```
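
WER is the word-level edit distance (substitutions + deletions + insertions) between the ASR transcript of the generated audio and the target text, divided by the number of reference words. The evaluation extras install `jiwer` for this; below is a minimal dependency-free sketch of the metric itself, assuming plain whitespace tokenization and no text normalization. The actual scripts may normalize case and punctuation differently, and Chinese is typically scored per character.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word error rate via Levenshtein distance over whitespace-split tokens."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev holds the previous DP row: distances from ref[:i-1] to every prefix hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion (reference word missing)
                curr[j - 1] + 1,     # insertion (extra hypothesis word)
                prev[j - 1] + cost,  # match or substitution
            )
        prev = curr
    return prev[len(hyp)] / len(ref)


print(word_error_rate("the cat sat on the mat", "the cat sat on mat"))
```

WER can exceed 1.0 when the hypothesis contains many inserted words, which is why it is reported as a rate rather than an accuracy.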

> [!NOTE]  
> Evaluation results can also be found in `_*_results.jsonl` files saved in `<GEN_WAV_DIR>`/`<WAV_DIR>`.


================================================
FILE: src/f5_tts/eval/ecapa_tdnn.py
================================================
# just for speaker similarity evaluation, third-party code

# From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN

import os

import torch
import torch.nn as nn
import torch.nn.functional as F


""" Res2Conv1d + BatchNorm1d + ReLU
"""


class Res2Conv1dReluBn(nn.Module):
    """
    in_channels == out_channels == channels
    """

    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1

        self.convs = []
        self.bns = []
        for i in range(self.nums):
            self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
            self.bns.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(self.convs)
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            # Order: conv -> relu -> bn
            sp = self.convs[i](sp)
            sp = self.bns[i](F.relu(sp))
            out.append(sp)
        if self.scale != 1:
            out.append(spx[self.nums])
        out = torch.cat(out, dim=1)

        return out


""" Conv1d + BatchNorm1d + ReLU
"""


class Conv1dReluBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))


""" The SE connection of 1D case.
"""


class SE_Connect(nn.Module):
    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
        out = x * out.unsqueeze(2)

        return out


""" SE-Res2Block of the ECAPA-TDNN architecture.
"""

# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
#     return nn.Sequential(
#         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
#         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
#         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
#         SE_Connect(channels)
#     )


class SE_Res2Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
        self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.Conv1dReluBn1(x)
        x = self.Res2Conv1dReluBn(x)
        x = self.Conv1dReluBn2(x)
        x = self.SE_Connect(x)

        return x + residual


""" Attentive weighted mean and standard deviation pooling.
"""


class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x_in))
        # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


class ECAPA_TDNN(nn.Module):
    def __init__(
        self,
        feat_dim=80,
        channels=512,
        emb_dim=192,
        global_context_att=False,
        feat_type="wavlm_large",
        sr=16000,
        feature_selection="hidden_states",
        update_extract=False,
        config_path=None,
    ):
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        try:
            local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
            self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
        except:  # noqa: E722
            self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)

        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
        ):
            self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
        ):
            self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False

        self.feat_num = self.get_feat_num()
        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != "fbank" and feat_type != "mfcc":
            freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(
            self.channels[0],
            self.channels[1],
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer3 = SE_Res2Block(
            self.channels[1],
            self.channels[2],
            kernel_size=3,
            stride=1,
            padding=3,
            dilation=3,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer4 = SE_Res2Block(
            self.channels[2],
            self.channels[3],
            kernel_size=3,
            stride=1,
            padding=4,
            dilation=4,
            scale=8,
            se_bottleneck_dim=128,
        )

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(
            self.channels[-1], attention_channels=128, global_context_att=global_context_att
        )
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == "fbank" or self.feat_type == "mfcc":
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == "fbank":
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out
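

# Illustrative sketch (hypothetical helper, not part of F5-TTS): the learnable
# layer weighting in ECAPA_TDNN.get_feat, written in plain Python. get_feat
# softmax-normalises self.feature_weight and takes a weighted sum over the
# stacked hidden states. Here each "layer" is a flat list of floats rather than
# a (B, T, D) tensor. Because feature_weight is initialised to zeros, the
# weighting starts out as a plain average over layers.
import math


def _sketch_weighted_layer_sum(layers, weights):
    """Return sum_k softmax(weights)[k] * layers[k], computed elementwise."""
    assert len(layers) == len(weights), "expected one weight per layer"
    peak = max(weights)  # subtract the max before exp() for numerical stability
    exps = [math.exp(w - peak) for w in weights]
    total = sum(exps)
    probs = [e / total for e in exps]
    return [sum(p * layer[d] for p, layer in zip(probs, layers)) for d in range(len(layers[0]))]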


def ECAPA_TDNN_SMALL(
    feat_dim,
    emb_dim=256,
    feat_type="wavlm_large",
    sr=16000,
    feature_selection="hidden_states",
    update_extract=False,
    config_path=None,
):
    return ECAPA_TDNN(
        feat_dim=feat_dim,
        channels=512,
        emb_dim=emb_dim,
        feat_type=feat_type,
        sr=sr,
        feature_selection=feature_selection,
        update_extract=update_extract,
        config_path=config_path,
    )


================================================
FILE: src/f5_tts/eval/eval_infer_batch.py
================================================
import os
import sys


sys.path.append(os.getcwd())

import argparse
import time
from importlib.resources import files

import torch
import torchaudio
from accelerate import Accelerator
from hydra.utils import get_class
from omegaconf import OmegaConf
from tqdm import tqdm

from f5_tts.eval.utils_eval import (
    get_inference_prompt,
    get_librispeech_test_clean_metainfo,
    get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer


accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"


use_ema = True
target_rms = 0.1
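

# Illustrative sketch (hypothetical helpers, not part of F5-TTS): the loudness
# handling around target_rms. get_inference_prompt() boosts a quiet reference
# prompt up to target_rms before computing its mel, and main() scales the
# generated waveform back by ref_rms / target_rms so the output returns to the
# prompt's original level. Plain lists of samples stand in for tensors here.
import math


def _sketch_rms(samples):
    """Root-mean-square level of a non-empty list of samples."""
    return math.sqrt(sum(s * s for s in samples) / len(samples))


def _sketch_rms_roundtrip(ref, generated, target_rms=0.1):
    """Return (prompt fed to the model, generated audio rescaled to the prompt's level)."""
    ref_rms = _sketch_rms(ref)
    if ref_rms >= target_rms:  # prompts that are already loud enough are left untouched
        return list(ref), list(generated)
    boosted_ref = [s * target_rms / ref_rms for s in ref]
    restored = [s * ref_rms / target_rms for s in generated]
    return boosted_ref, restored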


rel_path = str(files("f5_tts").joinpath("../../"))
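

# Illustrative sketch (hypothetical helper, not part of F5-TTS): the checkpoint
# lookup order used in main(). It tries ckpts/{exp_name}/model_{step} first and
# then the config's ckpts.save_dir, preferring ".pt" over ".safetensors" within
# each prefix. `exists` is injectable so the search order can be checked
# without touching the filesystem.
import os


def _sketch_resolve_ckpt(prefixes, exists=os.path.exists):
    for prefix in prefixes:
        for ext in (".pt", ".safetensors"):
            candidate = prefix + ext
            if exists(candidate):
                return candidate
    raise ValueError("The checkpoint does not exist or cannot be found in given location.")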


def main():
    parser = argparse.ArgumentParser(description="batch inference")

    parser.add_argument("-s", "--seed", default=None, type=int)
    parser.add_argument("-n", "--expname", required=True)
    parser.add_argument("-c", "--ckptstep", default=1250000, type=int)

    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
    parser.add_argument("-o", "--odemethod", default="euler")
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

    parser.add_argument("-t", "--testset", required=True)
    parser.add_argument(
        "-p", "--librispeech_test_clean_path", default=f"{rel_path}/data/LibriSpeech/test-clean", type=str
    )

    parser.add_argument("--local", action="store_true", help="Use local vocoder checkpoint directory")

    args = parser.parse_args()

    seed = args.seed
    exp_name = args.expname
    ckpt_step = args.ckptstep

    nfe_step = args.nfestep
    ode_method = args.odemethod
    sway_sampling_coef = args.swaysampling

    testset = args.testset

    infer_batch_size = 1  # max accumulated mel frames per batch; 1 yields one utterance per batch (recommended for DDP)
    cfg_strength = 2.0
    speed = 1.0
    use_truth_duration = False
    no_ref_audio = False

    model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch

    dataset_name = model_cfg.datasets.name
    tokenizer = model_cfg.model.tokenizer

    mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
    target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
    n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
    hop_length = model_cfg.model.mel_spec.hop_length
    win_length = model_cfg.model.mel_spec.win_length
    n_fft = model_cfg.model.mel_spec.n_fft

    if testset == "ls_pc_test_clean":
        metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
        librispeech_test_clean_path = args.librispeech_test_clean_path
        metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)

    elif testset == "seedtts_test_zh":
        metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    elif testset == "seedtts_test_en":
        metalst = rel_path + "/data/seedtts_testset/en/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    else:
        raise ValueError(f"Unknown testset: {testset}")

    # path to save generated wavs
    output_dir = (
        f"{rel_path}/"
        f"results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
        f"_cfg{cfg_strength}_speed{speed}"
        f"{'_gt-dur' if use_truth_duration else ''}"
        f"{'_no-ref-audio' if no_ref_audio else ''}"
    )

    # -------------------------------------------------#

    prompts_all = get_inference_prompt(
        metainfo,
        speed=speed,
        tokenizer=tokenizer,
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
        mel_spec_type=mel_spec_type,
        target_rms=target_rms,
        use_truth_duration=use_truth_duration,
        infer_batch_size=infer_batch_size,
    )

    # Vocoder model
    local = args.local
    if mel_spec_type == "vocos":
        vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    elif mel_spec_type == "bigvgan":
        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
    else:
        raise ValueError(f"Unsupported mel_spec_type: {mel_spec_type}")
    vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)

    # Tokenizer
    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

    # Model
    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
    if os.path.exists(ckpt_prefix + ".pt"):
        ckpt_path = ckpt_prefix + ".pt"
    elif os.path.exists(ckpt_prefix + ".safetensors"):
        ckpt_path = ckpt_prefix + ".safetensors"
    else:
        print("Loading from self-organized training checkpoints rather than released pretrained.")
        ckpt_prefix = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}"
        if os.path.exists(ckpt_prefix + ".pt"):
            ckpt_path = ckpt_prefix + ".pt"
        elif os.path.exists(ckpt_prefix + ".safetensors"):
            ckpt_path = ckpt_prefix + ".safetensors"
        else:
            raise ValueError("The checkpoint does not exist or cannot be found in given location.")

    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

    if not os.path.exists(output_dir) and accelerator.is_main_process:
        os.makedirs(output_dir)

    # start batch inference
    accelerator.wait_for_everyone()
    start = time.time()

    with accelerator.split_between_processes(prompts_all) as prompts:
        for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
            utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
            ref_mels = ref_mels.to(device)
            ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
            total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

            # Inference
            with torch.inference_mode():
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=final_text_list,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                    no_ref_audio=no_ref_audio,
                    seed=seed,
                )
                # Final result
                for i, gen in enumerate(generated):
                    gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                    gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
                    if mel_spec_type == "vocos":
                        generated_wave = vocoder.decode(gen_mel_spec).cpu()
                    elif mel_spec_type == "bigvgan":
                        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

                    if ref_rms_list[i] < target_rms:
                        generated_wave = generated_wave * ref_rms_list[i] / target_rms
                    torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        timediff = time.time() - start
        print(f"Done batch inference in {timediff / 60:.2f} minutes.")
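

# Illustrative sketch (hypothetical helper, not part of F5-TTS): how the leaf
# directory of output_dir in main() is named. argparse applies `type=float`
# only to string defaults, so the default -ss value stays the int -1 and yields
# "_ss-1" (the suffix eval_infer_batch.sh expects), whereas passing "-ss -1"
# explicitly yields "_ss-1.0". A coefficient of 0 is falsy and drops the suffix.
def _sketch_output_subdir(
    seed,
    ode_method,
    nfe_step,
    mel_spec_type,
    sway_sampling_coef,
    cfg_strength=2.0,
    speed=1.0,
    use_truth_duration=False,
    no_ref_audio=False,
):
    name = f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
    if sway_sampling_coef:  # 0, 0.0 and None omit the sway-sampling suffix
        name += f"_ss{sway_sampling_coef}"
    name += f"_cfg{cfg_strength}_speed{speed}"
    if use_truth_duration:
        name += "_gt-dur"
    if no_ref_audio:
        name += "_no-ref-audio"
    return name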


if __name__ == "__main__":
    main()


================================================
FILE: src/f5_tts/eval/eval_infer_batch.sh
================================================
#!/bin/bash
set -e
export PYTHONWARNINGS="ignore::UserWarning,ignore::FutureWarning"

# Configuration parameters
MODEL_NAME="F5TTS_v1_Base"
SEEDS=(0 1 2)
CKPTSTEPS=(1250000)
TASKS=("seedtts_test_zh" "seedtts_test_en" "ls_pc_test_clean")
LS_TEST_CLEAN_PATH="data/LibriSpeech/test-clean"
GPUS="[0,1,2,3,4,5,6,7]"
OFFLINE_MODE=false

# Pass --local to the evaluation scripts when running in offline mode
if [ "$OFFLINE_MODE" = true ]; then
    LOCAL="--local"
else
    LOCAL=""
fi

# Parse arguments
INFER_ONLY=false
while [[ $# -gt 0 ]]; do
    case $1 in
        --infer-only)
            INFER_ONLY=true
            shift
            ;;
        *)
            echo "======== Unknown parameter: $1"
            exit 1
            ;;
    esac
done

echo "======== Starting F5-TTS batch evaluation task..."
if [ "$INFER_ONLY" = true ]; then
    echo "======== Mode: Execute infer tasks only"
else
    echo "======== Mode: Execute full pipeline (infer + eval)"
fi

# Function: Execute eval tasks
execute_eval_tasks() {
    local ckptstep=$1
    local seed=$2
    local task_name=$3
    
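    # Must match the output_dir naming in eval_infer_batch.py for its default NFE, CFG strength, and speed
    local gen_wav_dir="results/${MODEL_NAME}_${ckptstep}/${task_name}/seed${seed}_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0"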
    
    echo ">>>>>>>> Starting eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
    
    case $task_name in
        "seedtts_test_zh")
            python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
            python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
            python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
            ;;
        "seedtts_test_en")
            python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
            python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
            python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
            ;;
        "ls_pc_test_clean")
            python src/f5_tts/eval/eval_librispeech_test_clean.py -e wer -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
            python src/f5_tts/eval/eval_librispeech_test_clean.py -e sim -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
            python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
            ;;
    esac
    
    echo ">>>>>>>> Completed eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
}

# Main execution loop
for ckptstep in "${CKPTSTEPS[@]}"; do
    echo "======== Processing ckptstep: ${ckptstep}"
    
    for seed in "${SEEDS[@]}"; do
        echo "-------- Processing seed: ${seed}"
        
        # Store eval task PIDs for current seed (if not infer-only mode)
        if [ "$INFER_ONLY" = false ]; then
            declare -a eval_pids
        fi
        
        # Execute each infer task sequentially
        for task in "${TASKS[@]}"; do
            echo ">>>>>>>> Executing infer task: accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n \"${MODEL_NAME}\" -t \"${task}\" -c ${ckptstep} $LOCAL"
            
            # Execute infer task (foreground execution, wait for completion)
            accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n "${MODEL_NAME}" -t "${task}" -c ${ckptstep} -p "${LS_TEST_CLEAN_PATH}" $LOCAL
            
            # If not infer-only mode, launch corresponding eval task
            if [ "$INFER_ONLY" = false ]; then
                # Launch corresponding eval task (background execution, non-blocking for next infer)
                execute_eval_tasks $ckptstep $seed $task &
                eval_pids+=($!)
            fi
        done
        
        # If not infer-only mode, wait for all eval tasks of current seed to complete
        if [ "$INFER_ONLY" = false ]; then
            echo ">>>>>>>> All infer tasks for seed ${seed} completed, waiting for corresponding eval tasks to finish..."
            
            for pid in "${eval_pids[@]}"; do
                wait $pid
            done
            
            unset eval_pids  # Clean up array
        fi
        echo "-------- All tasks for seed ${seed} completed"
    done
    
    echo "======== Completed ckptstep: ${ckptstep}"
    echo
done

echo "======== All tasks completed!"

================================================
FILE: src/f5_tts/eval/eval_infer_batch_example.sh
================================================
#!/bin/bash

# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 -p data/LibriSpeech/test-clean

# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 -p data/LibriSpeech/test-clean

# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0

# etc.


================================================
FILE: src/f5_tts/eval/eval_librispeech_test_clean.py
================================================
# Evaluate on LibriSpeech test-clean: a ~3 s prompt is used to generate 4-10 s of audio (following the VALL-E/Voicebox evaluation protocol)

import argparse
import ast
import json
import os
import sys


sys.path.append(os.getcwd())

import multiprocessing as mp
from importlib.resources import files

import numpy as np

from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim


rel_path = str(files("f5_tts").joinpath("../../"))


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
    parser.add_argument("-l", "--lang", type=str, default="en")
    parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
    parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
    parser.add_argument(
        "-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
    )
    parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
    return parser.parse_args()


def parse_gpu_nums(gpu_nums_str):
    try:
        if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
            gpu_list = ast.literal_eval(gpu_nums_str)
            if isinstance(gpu_list, list):
                return gpu_list
        return list(range(int(gpu_nums_str)))
    except (ValueError, SyntaxError):
        raise argparse.ArgumentTypeError(
            f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
        )


def main():
    args = get_args()
    eval_task = args.eval_task
    lang = args.lang
    librispeech_test_clean_path = args.librispeech_test_clean_path  # test-clean path
    gen_wav_dir = args.gen_wav_dir
    metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"

    gpus = parse_gpu_nums(args.gpu_nums)
    test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)

    ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
    ## leading to a low similarity for the ground truth in some cases.
    # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth

    local = args.local
    if local:  # use local custom checkpoint dir
        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
    else:
        asr_ckpt_dir = ""  # auto download to cache dir
    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

    # --------------------------------------------------------------------------

    full_results = []
    metrics = []

    if eval_task == "wer":
        with mp.Pool(processes=len(gpus)) as pool:
            args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
            results = pool.map(run_asr_wer, args)
            for r in results:
                full_results.extend(r)
    elif eval_task == "sim":
        with mp.Pool(processes=len(gpus)) as pool:
            args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
            results = pool.map(run_sim, args)
            for r in results:
                full_results.extend(r)
    else:
        raise ValueError(f"Unknown metric type: {eval_task}")

    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
    with open(result_path, "w") as f:
        for line in full_results:
            metrics.append(line[eval_task])
            f.write(json.dumps(line, ensure_ascii=False) + "\n")
        metric = round(np.mean(metrics), 5)
        f.write(f"\n{eval_task.upper()}: {metric}\n")

    print(f"\nTotal {len(metrics)} samples")
    print(f"{eval_task.upper()}: {metric}")
    print(f"{eval_task.upper()} results saved to {result_path}")
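

# Illustrative sketch (hypothetical helper, not part of F5-TTS): reading back a
# results file, assuming the layout that main() writes: one JSON object per
# line, a blank line, then a "<TASK>: <value>" summary line. eval_utmos.py
# writes the same layout with an "UTMOS: " summary, so eval_task="utmos" works too.
import json


def _sketch_read_results(lines, eval_task="wer"):
    """Return (per-utterance records, summary metric or None)."""
    records, summary = [], None
    prefix = f"{eval_task.upper()}: "
    for raw in lines:
        line = raw.strip()
        if not line:
            continue  # the blank separator before the summary line
        if line.startswith(prefix):
            summary = float(line[len(prefix) :])
        else:
            records.append(json.loads(line))
    return records, summary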


if __name__ == "__main__":
    main()


================================================
FILE: src/f5_tts/eval/eval_seedtts_testset.py
================================================
# Evaluate with Seed-TTS testset

import argparse
import ast
import json
import os
import sys


sys.path.append(os.getcwd())

import multiprocessing as mp
from importlib.resources import files

import numpy as np

from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim


rel_path = str(files("f5_tts").joinpath("../../"))


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
    parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
    parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
    parser.add_argument(
        "-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
    )
    parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
    return parser.parse_args()


def parse_gpu_nums(gpu_nums_str):
    try:
        if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
            gpu_list = ast.literal_eval(gpu_nums_str)
            if isinstance(gpu_list, list):
                return gpu_list
        return list(range(int(gpu_nums_str)))
    except (ValueError, SyntaxError):
        raise argparse.ArgumentTypeError(
            f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
        )


def main():
    args = get_args()
    eval_task = args.eval_task
    lang = args.lang
    gen_wav_dir = args.gen_wav_dir
    metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset

    # NOTE: paraformer-zh results vary slightly with the number of GPUs, because the batch size differs.
    #       The zh WER of 1.254 appears to come from running wer_seed_tts with 4 workers.
    gpus = parse_gpu_nums(args.gpu_nums)
    test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)

    local = args.local
    if local:  # use local custom checkpoint dir
        if lang == "zh":
            asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
        elif lang == "en":
            asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
    else:
        asr_ckpt_dir = ""  # auto download to cache dir
    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

    # --------------------------------------------------------------------------

    full_results = []
    metrics = []

    if eval_task == "wer":
        with mp.Pool(processes=len(gpus)) as pool:
            args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
            results = pool.map(run_asr_wer, args)
            for r in results:
                full_results.extend(r)
    elif eval_task == "sim":
        with mp.Pool(processes=len(gpus)) as pool:
            args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
            results = pool.map(run_sim, args)
            for r in results:
                full_results.extend(r)
    else:
        raise ValueError(f"Unknown metric type: {eval_task}")

    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
    with open(result_path, "w") as f:
        for line in full_results:
            metrics.append(line[eval_task])
            f.write(json.dumps(line, ensure_ascii=False) + "\n")
        metric = round(np.mean(metrics), 5)
        f.write(f"\n{eval_task.upper()}: {metric}\n")

    print(f"\nTotal {len(metrics)} samples")
    print(f"{eval_task.upper()}: {metric}")
    print(f"{eval_task.upper()} results saved to {result_path}")


if __name__ == "__main__":
    main()


================================================
FILE: src/f5_tts/eval/eval_utmos.py
================================================
import argparse
import json
from pathlib import Path

import librosa
import torch
from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser(description="UTMOS Evaluation")
    parser.add_argument(
        "--audio_dir", type=str, required=True, help="Directory containing the audio files (searched recursively)."
    )
    parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"

    predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
    predictor = predictor.to(device)

    audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
    utmos_score = 0

    utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
    with open(utmos_result_path, "w", encoding="utf-8") as f:
        for audio_path in tqdm(audio_paths, desc="Processing"):
            wav, sr = librosa.load(audio_path, sr=None, mono=True)
            wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
            score = predictor(wav_tensor, sr)
            line = {}
            line["wav"], line["utmos"] = str(audio_path.stem), score.item()
            utmos_score += score.item()
            f.write(json.dumps(line, ensure_ascii=False) + "\n")
        avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
        f.write(f"\nUTMOS: {avg_score:.4f}\n")

    print(f"UTMOS: {avg_score:.4f}")
    print(f"UTMOS results saved to {utmos_result_path}")


if __name__ == "__main__":
    main()


================================================
FILE: src/f5_tts/eval/utils_eval.py
================================================
import math
import os
import random
import string
from pathlib import Path

import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm

from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin


# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_seedtts_testset_metainfo(metalst):
    with open(metalst) as f:
        lines = f.readlines()
    metainfo = []
    for line in lines:
        if not line.strip():
            continue  # skip blank lines instead of reusing the previous entry's fields
        fields = line.strip().split("|")
        if len(fields) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = fields
        elif len(fields) == 4:
            utt, prompt_text, prompt_wav, gt_text = fields
            gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
        else:
            raise ValueError(f"Unexpected number of fields ({len(fields)}) in {metalst}: {line!r}")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
        metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
    return metainfo


# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
    with open(metalst) as f:
        lines = f.readlines()
    metainfo = []
    for line in lines:
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")

        # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

        # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
        gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")

        metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))

    return metainfo


# padded to max length mel batch
def padded_mel_batch(ref_mels):
    max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
    padded_ref_mels = []
    for mel in ref_mels:
        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
    return padded_ref_mels
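

# Illustrative sketch (hypothetical helpers, not part of F5-TTS): the frame-count
# estimate and length bucketing in get_inference_prompt(). Without ground-truth
# duration, the target length extrapolates the prompt's mel frames per UTF-8
# byte to the text to generate, so a CJK character (3 bytes in UTF-8) counts
# three times as much as an ASCII letter. The total is then mapped to one of
# num_buckets equal-width buckets between min_secs and max_secs. The "+ 1" in
# the denominator keeps the longest allowed length in the last bucket.
import math


def _sketch_estimate_total_mel_len(ref_mel_len, prompt_text, gen_text, speed=1.0):
    ref_text_len = len(prompt_text.encode("utf-8"))
    gen_text_len = len(gen_text.encode("utf-8"))
    return ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)


def _sketch_bucket_index(
    total_mel_len, target_sample_rate=24000, hop_length=256, min_secs=3, max_secs=40, num_buckets=200
):
    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length
    if not min_tokens <= total_mel_len <= max_tokens:  # the original asserts this instead
        raise ValueError(f"{total_mel_len} frames is outside [{min_tokens}, {max_tokens}]")
    return math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)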


# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav


def get_inference_prompt(
    metainfo,
    speed=1.0,
    tokenizer="pinyin",
    polyphone=True,
    target_sample_rate=24000,
    n_fft=1024,
    win_length=1024,
    n_mel_channels=100,
    hop_length=256,
    mel_spec_type="vocos",
    target_rms=0.1,
    use_truth_duration=False,
    infer_batch_size=1,
    num_buckets=200,
    min_secs=3,
    max_secs=40,
):
    prompts_all = []

    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length

    batch_accum = [0] * num_buckets
    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
        [[] for _ in range(num_buckets)] for _ in range(6)
    )

    mel_spectrogram = MelSpec(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
        # Audio
        ref_audio, ref_sr = torchaudio.load(prompt_wav)
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        if ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio)

        # Text
        if len(prompt_text[-1].encode("utf-8")) == 1:
            prompt_text = prompt_text + " "
        text = [prompt_text + gt_text]
        if tokenizer == "pinyin":
            text_list = convert_char_to_pinyin(text, polyphone=polyphone)
        else:
            text_list = text

        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        # Duration, mel frame length
        ref_mel_len = ref_mel.shape[-1]

        if use_truth_duration:
            gt_audio, gt_sr = torchaudio.load(gt_wav)
            if gt_sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
                gt_audio = resampler(gt_audio)
            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)

            # # test vocoder resynthesis
            # ref_audio = gt_audio
        else:
            ref_text_len = len(prompt_text.encode("utf-8"))
            gen_text_len = len(gt_text.encode("utf-8"))
            total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)

        # deal with batch
        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
        assert min_tokens <= total_mel_len <= max_tokens, (
            f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
        )
        bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)

        utts[bucket_i].append(utt)
        ref_rms_list[bucket_i].append(ref_rms)
        ref_mels[bucket_i].append(ref_mel)
        ref_mel_lens[bucket_i].append(ref_mel_len)
        total_mel_lens[bucket_i].append(total_mel_len)
        final_text_list[bucket_i].extend(text_list)

        batch_accum[bucket_i] += total_mel_len

        if batch_accum[bucket_i] >= infer_batch_size:
            # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
            batch_accum[bucket_i] = 0
            (
                utts[bucket_i],
                ref_rms_list[bucket_i],
                ref_mels[bucket_i],
                ref_mel_lens[bucket_i],
                total_mel_lens[bucket_i],
                final_text_list[bucket_i],
            ) = [], [], [], [], [], []

    # add residual
    for bucket_i, bucket_frames in enumerate(batch_accum):
        if bucket_frames > 0:
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
    # shuffle so that the small residual batches are not all left to the last workers
    random.seed(666)
    random.shuffle(prompts_all)

    return prompts_all


# get wav_res_ref_text of seed-tts test metalst
# https://github.com/BytedanceSpeech/seed-tts-eval


def get_seed_tts_test(metalst, gen_wav_dir, gpus):
    with open(metalst) as f:
        lines = f.readlines()

    test_set_ = []
    for line in tqdm(lines):
        fields = line.strip().split("|")
        if len(fields) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = fields
        elif len(fields) == 4:
            utt, prompt_text, prompt_wav, gt_text = fields
        elif line.strip() == "":
            continue  # skip blank lines instead of reusing the previous entry
        else:
            raise ValueError(f"Expected 4 or 5 '|'-separated fields, got {len(fields)}: {line!r}")

        if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
            continue
        gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)

        test_set_.append((gen_wav, prompt_wav, gt_text))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set


# get librispeech test-clean cross sentence test


def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
    with open(metalst) as f:
        lines = f.readlines()

    test_set_ = []
    for line in tqdm(lines):
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")

        if eval_ground_truth:
            gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
            gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
        else:
            if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
                raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
            gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")

        ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

        test_set_.append((gen_wav, ref_wav, gen_txt))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set


# load asr model


def load_asr_model(lang, ckpt_dir=""):
    if lang == "zh":
        from funasr import AutoModel

        model = AutoModel(
            model=os.path.join(ckpt_dir, "paraformer-zh"),
            # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
            # punc_model = os.path.join(ckpt_dir, "ct-punc"),
            # spk_model = os.path.join(ckpt_dir, "cam++"),
            disable_update=True,
        )  # following seed-tts setting
    elif lang == "en":
        from faster_whisper import WhisperModel

        model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
        model = WhisperModel(model_size, device="cuda", compute_type="float16")
    else:
        raise NotImplementedError(f"Unsupported lang: {lang!r}; only 'zh' and 'en' are supported.")
    return model


# WER Evaluation, the way Seed-TTS does


def run_asr_wer(args):
    rank, lang, test_set, ckpt_dir = args

    if lang == "zh":
        import zhconv

        torch.cuda.set_device(rank)
    elif lang == "en":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    else:
        raise NotImplementedError(
            "Only 'zh' (funasr paraformer-zh) and 'en' (faster-whisper-large-v3) are supported for now."
        )

    asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)

    from zhon.hanzi import punctuation

    punctuation_all = punctuation + string.punctuation
    wer_results = []

    from jiwer import process_words

    for gen_wav, prompt_wav, truth in tqdm(test_set):
        if lang == "zh":
            res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
            hypo = res[0]["text"]
            hypo = zhconv.convert(hypo, "zh-cn")
        elif lang == "en":
            segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
            hypo = ""
            for segment in segments:
                hypo = hypo + " " + segment.text

        raw_truth = truth
        raw_hypo = hypo

        for x in punctuation_all:
            truth = truth.replace(x, "")
            hypo = hypo.replace(x, "")

        truth = truth.replace("  ", " ")
        hypo = hypo.replace("  ", " ")

        if lang == "zh":
            truth = " ".join([x for x in truth])
            hypo = " ".join([x for x in hypo])
        elif lang == "en":
            truth = truth.lower()
            hypo = hypo.lower()

        measures = process_words(truth, hypo)
        wer = measures.wer

        # ref_list = truth.split(" ")
        # subs = measures.substitutions / len(ref_list)
        # dele = measures.deletions / len(ref_list)
        # inse = measures.insertions / len(ref_list)

        wer_results.append(
            {
                "wav": Path(gen_wav).stem,
                "truth": raw_truth,
                "hypo": raw_hypo,
                "wer": wer,
            }
        )

    return wer_results


# SIM Evaluation


def run_sim(args):
    rank, test_set, ckpt_dir = args
    device = f"cuda:{rank}"

    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict["model"], strict=False)

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model = model.cuda(device)
    model.eval()

    sim_results = []
    for gen_wav, prompt_wav, truth in tqdm(test_set):
        wav1, sr1 = torchaudio.load(gen_wav)
        wav2, sr2 = torchaudio.load(prompt_wav)

        if use_gpu:
            wav1 = wav1.cuda(device)
            wav2 = wav2.cuda(device)

        if sr1 != 16000:
            resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
            if use_gpu:
                resample1 = resample1.cuda(device)
            wav1 = resample1(wav1)
        if sr2 != 16000:
            resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
            if use_gpu:
                resample2 = resample2.cuda(device)
            wav2 = resample2(wav2)

        with torch.no_grad():
            emb1 = model(wav1)
            emb2 = model(wav2)

        sim = F.cosine_similarity(emb1, emb2)[0].item()
        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
        sim_results.append(
            {
                "wav": Path(gen_wav).stem,
                "sim": sim,
            }
        )

    return sim_results


================================================
FILE: src/f5_tts/infer/README.md
================================================
# Inference

The pretrained model checkpoints can be found at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be downloaded automatically when running the inference scripts.

**More community-contributed checkpoints, supporting more languages, can be found in [SHARED.md](SHARED.md).**

A single generation currently supports up to **30s**, which is the **total length** of prompt and output audio combined (the same logic applies with `fix_duration`). For longer text, `infer_cli` and `infer_gradio` automatically perform chunked generation. Long reference audio will be **clipped to ~12s**.
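As a rough check against the 30s limit, the evaluation utilities in this repository estimate the total length by scaling the reference mel length with the ratio of UTF-8 byte lengths of the text to generate and the reference text, divided by `speed`. The sketch below reproduces that heuristic under assumed mel settings (24 kHz sample rate, hop length 256); the function names are illustrative and not part of the F5-TTS API.

```python
def estimate_total_frames(ref_audio_secs: float, ref_text: str, gen_text: str,
                          speed: float = 1.0, sample_rate: int = 24000,
                          hop_length: int = 256) -> int:
    """Estimate mel frames for prompt + output (assumed 24 kHz, hop 256)."""
    ref_frames = int(ref_audio_secs * sample_rate / hop_length)
    ref_bytes = len(ref_text.encode("utf-8"))
    gen_bytes = len(gen_text.encode("utf-8"))
    # Output length scales with the text-length ratio, measured in UTF-8 bytes
    return ref_frames + int(ref_frames / ref_bytes * gen_bytes / speed)


def frames_to_secs(frames: int, sample_rate: int = 24000, hop_length: int = 256) -> float:
    return frames * hop_length / sample_rate


total = estimate_total_frames(6.0, "a" * 64, "b" * 128)
print(total, frames_to_secs(total), frames_to_secs(total) <= 30)
```

Because lengths are counted in UTF-8 bytes, each Chinese character weighs three times as much as an ASCII letter in this estimate.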

To avoid possible inference failures, make sure you have read through the following instructions.

- Use reference audio shorter than 12s and leave some silence (e.g. 1s) at the end. Otherwise the audio may be truncated in the middle of a word, leading to suboptimal generation.
- <ins>Uppercase letters</ins> (best in a form like K.F.C.) will be spoken letter by letter; use lowercase letters for common words.
- Add some spaces (" ") or punctuation marks (e.g. "," ".") <ins>to explicitly introduce pauses</ins>.
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise, it will not be treated as a sentence boundary when chunking.
- <ins>Preprocess numbers</ins> into Chinese characters if you want them read in Chinese; otherwise they will be read in English.
- If the generation output is blank (pure silence), <ins>check your FFmpeg installation</ins>.
- Try <ins>turning off `use_ema` if using an early-stage</ins> finetuned checkpoint (one that has gone through only a few updates).
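The sentence-boundary rule above can be illustrated with a small regex splitter. This is a hypothetical sketch, not the chunking code used by `infer_cli` or `infer_gradio`; it assumes ASCII punctuation only ends a sentence when followed by whitespace, while full-width CJK punctuation ends one immediately.

```python
import re

# Split after ASCII punctuation only when whitespace follows,
# or right after full-width CJK punctuation (zero-width split).
_BOUNDARY = re.compile(r"(?<=[.!?;:,])\s+|(?<=[。！？；：，])")


def split_sentences(text: str) -> list[str]:
    """Hypothetical splitter illustrating the 'space after punctuation' rule."""
    return [s for s in _BOUNDARY.split(text) if s]


print(split_sentences("Hello there. How are you?I am fine."))
print(split_sentences("你好。再见"))
```

In the first call, `you?I` stays in one piece because no space follows the question mark.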


## Gradio App

Currently supported features:

- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct
- [Custom inference with more language support](SHARED.md)

The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.

The script loads model checkpoints from Hugging Face. You can also download the files manually and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS model is loaded at startup; the ASR model is loaded to transcribe the reference audio if `ref_text` is not provided, and the LLM is loaded when Voice Chat is used.

More flag options:

```bash
# Automatically launch the interface in the default web browser
f5-tts_infer-gradio --inbrowser

# Set the root path of the application, if it's not served from the root ("/") of the domain
# For example, if the application is served at "https://example.com/myapp"
f5-tts_infer-gradio --root_path "/myapp"
```

It can also be used as a component of a larger application:
```python
import gradio as gr
from f5_tts.infer.infer_gradio import app

with gr.Blocks() as main_app:
    gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")

    # ... other Gradio components

    app.render()

main_app.launch()
```


## CLI Inference

The CLI command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, which is a command-line tool for inference.

The script loads model checkpoints from Hugging Face. You can also download the files manually and use `--ckpt_file` to specify the model to load, or update the path directly in `infer_cli.py`.

To use a different vocabulary, pass your own `vocab.txt` file with `--vocab_file`.

Basic inference with flags:
```bash
# Leaving --ref_text "" lets the ASR model transcribe the reference audio (extra GPU memory usage)
f5-tts_infer-cli \
--model F5TTS_v1_Base \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want the TTS model to generate for you."

# Use BigVGAN as the vocoder. Currently only supported with F5TTS_Base.
f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local

# Use custom path checkpoint, e.g.
f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors

# More instructions
f5-tts_infer-cli --help
```

A `.toml` config file allows more flexible usage:

```bash
f5-tts_infer-cli -c custom.toml
```

For example, you can use a `.toml` file to pass in variables; see `src/f5_tts/infer/examples/basic/basic.toml`:

```toml
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
# File with text to generate. Ignores the text above.
gen_file = ""
remove_silence = false
output_dir = "tests"
```
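The `gen_file` comment means a non-empty `gen_file` takes precedence over `gen_text`. A minimal sketch of that rule, assuming Python 3.11+ `tomllib` (or the `tomli` backport) and paths resolved relative to a caller-chosen base directory; `load_gen_text` is a hypothetical helper, not the CLI's implementation:

```python
from pathlib import Path

try:
    import tomllib  # Python 3.11+
except ModuleNotFoundError:  # older Pythons: pip install tomli
    import tomli as tomllib


def load_gen_text(toml_text: str, base_dir: Path = Path(".")) -> str:
    """Return the text to synthesize: gen_file wins over gen_text when non-empty."""
    cfg = tomllib.loads(toml_text)
    gen_file = cfg.get("gen_file", "")
    if gen_file:
        # Assumption: relative paths are resolved against base_dir
        return (base_dir / gen_file).read_text(encoding="utf-8")
    return cfg.get("gen_text", "")


config = '''
gen_text = "I don't really care what you call me."
gen_file = ""
'''
print(load_gen_text(config))
```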

You can also use a `.toml` file for multi-style generation; see `src/f5_tts/infer/examples/multi/story.toml`.

```toml
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
gen_text = ""
# File with text to generate. Ignores the text above.
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"

[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""

[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""
```
Mark the voice with `[main]`, `[town]`, or `[country]` whenever you want to switch voices; see `src/f5_tts/infer/examples/multi/story.txt`.
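To illustrate how tagged story text maps onto voices, here is a hypothetical splitter (not the parser used by `infer_cli`). It assumes untagged leading text belongs to `main`, the voice defined by the top-level `ref_audio`, and that each `[name]` tag switches the voice until the next tag.

```python
import re

# A voice tag is a word in square brackets, e.g. [town]
VOICE_TAG = re.compile(r"\[(\w+)\]")


def split_by_voice(text: str, default_voice: str = "main") -> list[tuple[str, str]]:
    """Split story text into (voice, text) segments at [voice] tags."""
    segments = []
    voice, pos = default_voice, 0
    for match in VOICE_TAG.finditer(text):
        chunk = text[pos:match.start()].strip()
        if chunk:
            segments.append((voice, chunk))
        voice, pos = match.group(1), match.end()
    tail = text[pos:].strip()
    if tail:
        segments.append((voice, tail))
    return segments


story = 'He broke out with [town] "My poor dear friend!" [main] So he returned.'
for voice, chunk in split_by_voice(story):
    print(voice, "->", chunk)
```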

## API Usage

```python
from importlib.resources import files
from f5_tts.api import F5TTS

f5tts = F5TTS()
wav, sr, spec = f5tts.infer(
    ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
    ref_text="some call me nature, others call me mother nature.",
    gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
    file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
    file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
    seed=None,
)
```
Check [api.py](../api.py) for more details.

## TensorRT-LLM Deployment

See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.

## Socket Real-time Service

Real-time voice output with chunked streaming:

```bash
# Start socket server
python src/f5_tts/socket_server.py

# If PyAudio is not installed
sudo apt-get install portaudio19-dev
pip install pyaudio

# Communicate with socket client
python src/f5_tts/socket_client.py
```

## Speech Editing

To test speech editing capabilities, use the following command:

```bash
python src/f5_tts/infer/speech_edit.py
```



================================================
FILE: src/f5_tts/infer/SHARED.md
================================================
<!-- omit in toc -->
# Shared Model Cards

<!-- omit in toc -->
### **Prerequisites for use**
- This document serves as a quick lookup table for community training/finetuning results, covering various languages.
- The models in this repository are open source and are based on voluntary contributions from contributors.
- Use of these models must respect their respective creators; the convenience they bring comes from their efforts.

<!-- omit in toc -->
### **Welcome to share here**
- Have a pretrained/finetuned result: a model checkpoint (preferably pruned to facilitate inference, i.e. keeping only `ema_model_state_dict`) and the corresponding vocab file (for tokenization).
- Host a public [Hugging Face model repository](https://huggingface.co/new) and upload the model-related files.
- Make a pull request adding a model card to the current page, i.e. `src/f5_tts/infer/SHARED.md`.
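Pruning a checkpoint as suggested above could look like the sketch below. It assumes the training checkpoint is a dict saved with `torch.save` that contains an `ema_model_state_dict` entry; the helper names and file paths are hypothetical, and `torch` is imported lazily so the pure dict logic can be checked without it.

```python
def prune_checkpoint_dict(ckpt: dict, keep_key: str = "ema_model_state_dict") -> dict:
    """Keep only the EMA weights from a training checkpoint dict."""
    if keep_key not in ckpt:
        raise KeyError(f"{keep_key!r} not found; available keys: {sorted(ckpt)}")
    return {keep_key: ckpt[keep_key]}


def prune_checkpoint_file(src: str, dst: str, keep_key: str = "ema_model_state_dict") -> None:
    import torch  # imported lazily so the pure helper above works without torch

    # weights_only=True refuses arbitrary pickled objects; relax only for files you trust
    ckpt = torch.load(src, map_location="cpu", weights_only=True)
    torch.save(prune_checkpoint_dict(ckpt, keep_key), dst)


# Hypothetical usage:
# prune_checkpoint_file("ckpts/my_model/model_last.pt", "ckpts/my_model/model_last_pruned.pt")
```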

<!-- omit in toc -->
### Supported Languages
- [Multilingual](#multilingual)
    - [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
- [Arabic](#arabic)
    - [F5-TTS Small @ ar & en @ SILMA AI](#f5-tts-small--ar--en--silma-ai)
- [English](#english)
- [Finnish](#finnish)
    - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
- [French](#french)
    - [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
- [German](#german)
    - [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
- [Hindi](#hindi)
    - [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
- [Italian](#italian)
    - [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79)
- [Japanese](#japanese)
    - [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica)
- [Latvian](#latvian)
    - [F5-TTS Base @ lv @ RaivisDejus](#f5-tts-base--lv--raivisdejus)
- [Mandarin](#mandarin)
- [Russian](#russian)
    - [F5-TTS Base @ ru @ HotDro4illa](#f5-tts-base--ru--hotdro4illa)
- [Spanish](#spanish)
    - [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar)


## Multilingual

#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|

```bash
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
# A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
```

|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|

```bash
Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```

*Other info, e.g. author info, GitHub repo, links to sample results, usage instructions, tutorials (blog, video, etc.) ...*
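The `hf://` paths in these cards follow the pattern `hf://<owner>/<repo>/<path inside repo>`. One way to fetch such a file, assuming the `huggingface_hub` package is installed, is sketched below; `parse_hf_uri` and `download_hf_uri` are hypothetical helpers, and the F5-TTS scripts may resolve these paths differently.

```python
def parse_hf_uri(uri: str) -> tuple[str, str]:
    """Split 'hf://<owner>/<repo>/<path/in/repo>' into (repo_id, filename)."""
    prefix = "hf://"
    if not uri.startswith(prefix):
        raise ValueError(f"not an hf:// URI: {uri!r}")
    parts = uri[len(prefix):].split("/")
    if len(parts) < 3 or not all(parts):
        raise ValueError(f"expected hf://owner/repo/file, got {uri!r}")
    return "/".join(parts[:2]), "/".join(parts[2:])


def download_hf_uri(uri: str) -> str:
    from huggingface_hub import hf_hub_download  # lazy import: optional dependency

    repo_id, filename = parse_hf_uri(uri)
    return hf_hub_download(repo_id=repo_id, filename=filename)  # local cache path


print(parse_hf_uri("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt"))
```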


## Arabic

#### F5-TTS Small @ ar & en @ SILMA AI
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Small|[ckpt & vocab](https://huggingface.co/silma-ai/silma-tts)| Tens of thousands EN/AR |Apache-2.0|

- Pretrained by [SILMA.AI](https://silma.ai)
- [GitHub repo](https://github.com/SILMA-AI/silma-tts) with inference code


## English


## Finnish

#### F5-TTS Base @ fi @ AsmoKoskinen
|Model|🤗Hugging Face|Data|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|

```bash
Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```


## French

#### F5-TTS Base @ fr @ RASPIAUDIO
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|

```bash
Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```

- [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
- [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).


## German

#### F5-TTS Base @ de @ hvoss-techfak
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|

```bash
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```

- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)


## Hindi

#### F5-TTS Small @ hi @ SPRINGLab
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|

```bash
Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```

- Authors: SPRING Lab, Indian Institute of Technology, Madras
- Website: https://asr.iitm.ac.in/
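The `Config` lines in these cards are written as Python dict literals: `False` is not valid JSON, while the F5-TTS v1 Base line contains no booleans and parses as JSON. If you want to load these lines programmatically, a minimal sketch follows; `parse_model_config` is a hypothetical helper, not how the F5-TTS scripts parse them.

```python
import ast
import json


def parse_model_config(text: str) -> dict:
    """Parse a model-card Config line written as JSON or as a Python dict literal."""
    try:
        cfg = json.loads(text)
    except json.JSONDecodeError:
        # Python literals such as False/True are not valid JSON; literal_eval
        # accepts them without executing arbitrary code
        cfg = ast.literal_eval(text)
    if not isinstance(cfg, dict):
        raise ValueError("config must be a dict")
    return cfg


small = parse_model_config(
    '{"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, '
    '"text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}'
)
print(small["dim"], small["depth"], small["text_mask_padding"])
```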


## Italian

#### F5-TTS Base @ it @ alien79
|Model|🤗Hugging Face|Data|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|

```bash
Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
Vocab: hf://alien79/F5-TTS-italian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```

- Trained by [Mithril Man](https://github.com/MithrilMan)
- Model details on [hf project home](https://huggingface.co/alien79/F5-TTS-italian)
- Open to collaborations to further improve the model


## Japanese

#### F5-TTS Base @ ja @ Jmica
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_21999120)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|

```bash
Model: hf://Jmica/F5TTS/JA_21999120/model_21999120.pt
Vocab: hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```


## Latvian

#### F5-TTS Base @ lv @ RaivisDejus
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/RaivisDejus/F5-TTS-Latvian)|[Common Voice](https://datacollective.mozillafoundation.org/datasets/cmj8u3pec00flnxxbntvfb4as)|cc0-1.0|

```bash
Model: hf://RaivisDejus/F5-TTS-Latvian/model.safetensors
Vocab: hf://RaivisDejus/F5-TTS-Latvian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```


## Mandarin


## Russian

#### F5-TTS Base @ ru @ HotDro4illa
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hotstone228/F5-TTS-Russian)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0)|cc-by-nc-4.0|

```bash
Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
- Any improvements are welcome


## Spanish

#### F5-TTS Base @ es @ jpgallegoar
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|

- [GitHub repo](https://github.com/jpgallegoar/Spanish-F5) by @jpgallegoar, with a Jupyter notebook and Gradio usage for the Spanish model.


================================================
FILE: src/f5_tts/infer/examples/basic/basic.toml
================================================
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
# File with text to generate. Ignores the text above.
gen_file = ""
remove_silence = false
output_dir = "tests"
output_file = "infer_cli_basic.wav"


================================================
FILE: src/f5_tts/infer/examples/multi/story.toml
================================================
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
gen_text = ""
# File with text to generate. Ignores the text above.
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"
output_file = "infer_cli_story.wav"

[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""
speed = 0.8  # overrides the global speed for this voice

[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""


================================================
FILE: src/f5_tts/infer/examples/multi/story.txt
================================================
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] "My poor dear friend, you live here no better than the ants! Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land." [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] "Goodbye," [main] said he, [country] "I'm off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace."

================================================
FILE: src/f5_tts/infer/examples/vocab.txt
================================================
 
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
_
a
a1
ai1
ai2
ai3
ai4
an1
an3
an4
ang1
ang2
ang4
ao1
ao2
ao3
ao4
b
ba
ba1
ba2
ba3
ba4
bai1
bai2
bai3
bai4
ban1
ban2
ban3
ban4
bang1
bang2
bang3
bang4
bao1
bao2
bao3
bao4
bei
bei1
bei2
bei3
bei4
ben1
ben2
ben3
ben4
beng
beng1
beng2
beng3
beng4
bi1
bi2
bi3
bi4
bian1
bian2
bian3
bian4
biao1
biao2
biao3
bie1
bie2
bie3
bie4
bin1
bin4
bing1
bing2
bing3
bing4
bo
bo1
bo2
bo3
bo4
bu2
bu3
bu4
c
ca1
cai1
cai2
cai3
cai4
can1
can2
can3
can4
cang1
cang2
cao1
cao2
cao3
ce4
cen1
cen2
ceng1
ceng2
ceng4
cha1
cha2
cha3
cha4
chai1
chai2
chan1
chan2
chan3
chan4
chang1
chang2
chang3
chang4
chao1
chao2
chao3
che1
che2
che3
che4
chen1
chen2
chen3
chen4
cheng1
cheng2
cheng3
cheng4
chi1
chi2
chi3
chi4
chong1
chong2
chong3
chong4
chou1
chou2
chou3
chou4
chu1
chu2
chu3
chu4
chua1
chuai1
chuai2
chuai3
chuai4
chuan1
chuan2
chuan3
chuan4
chuang1
chuang2
chuang3
chuang4
chui1
chui2
chun1
chun2
chun3
chuo1
chuo4
ci1
ci2
ci3
ci4
cong1
cong2
cou4
cu1
cu4
cuan1
cuan2
cuan4
cui1
cui3
cui4
cun1
cun2
cun4
cuo1
cuo2
cuo4
d
da
da1
da2
da3
da4
dai1
dai2
dai3
dai4
dan1
dan2
dan3
dan4
dang1
dang2
dang3
dang4
dao1
dao2
dao3
dao4
de
de1
de2
dei3
den4
deng1
deng2
deng3
deng4
di1
di2
di3
di4
dia3
dian1
dian2
dian3
dian4
diao1
diao3
diao4
die1
die2
die4
ding1
ding2
ding3
ding4
diu1
dong1
dong3
dong4
dou1
dou2
dou3
dou4
du1
du2
du3
du4
duan1
duan2
duan3
duan4
dui1
dui4
dun1
dun3
dun4
duo1
duo2
duo3
duo4
e
e1
e2
e3
e4
ei2
en1
en4
er
er2
er3
er4
f
fa1
fa2
fa3
fa4
fan1
fan2
fan3
fan4
fang1
fang2
fang3
fang4
fei1
fei2
fei3
fei4
fen1
fen2
fen3
fen4
feng1
feng2
feng3
feng4
fo2
fou2
fou3
fu1
fu2
fu3
fu4
g
ga1
ga2
ga3
ga4
gai1
gai2
gai3
gai4
gan1
gan2
gan3
gan4
gang1
gang2
gang3
gang4
gao1
gao2
gao3
gao4
ge1
ge2
ge3
ge4
gei2
gei3
gen1
gen2
gen3
gen4
geng1
geng3
geng4
gong1
gong3
gong4
gou1
gou2
gou3
gou4
gu
gu1
gu2
gu3
gu4
gua1
gua2
gua3
gua4
guai1
guai2
guai3
guai4
guan1
guan2
guan3
guan4
guang1
guang2
guang3
guang4
gui1
gui2
gui3
gui4
gun3
gun4
guo1
guo2
guo3
guo4
h
ha1
ha2
ha3
hai1
hai2
hai3
hai4
han1
han2
han3
han4
hang1
hang2
hang4
hao1
hao2
hao3
hao4
he1
he2
he4
hei1
hen2
hen3
hen4
heng1
heng2
heng4
hong1
hong2
hong3
hong4
hou1
hou2
hou3
hou4
hu1
hu2
hu3
hu4
hua1
hua2
hua4
huai2
huai4
huan1
huan2
huan3
huan4
huang1
huang2
huang3
huang4
hui1
hui2
hui3
hui4
hun1
hun2
hun4
huo
huo1
huo2
huo3
huo4
i
j
ji1
ji2
ji3
ji4
jia
jia1
jia2
jia3
jia4
jian1
jian2
jian3
jian4
jiang1
jiang2
jiang3
jiang4
jiao1
jiao2
jiao3
jiao4
jie1
jie2
jie3
jie4
jin1
jin2
jin3
jin4
jing1
jing2
jing3
jing4
jiong3
jiu1
jiu2
jiu3
jiu4
ju1
ju2
ju3
ju4
juan1
juan2
juan3
juan4
jue1
jue2
jue4
jun1
jun4
k
ka1
ka2
ka3
kai1
kai2
kai3
kai4
kan1
kan2
kan3
kan4
kang1
kang2
kang4
kao1
kao2
kao3
kao4
ke1
ke2
ke3
ke4
ken3
keng1
kong1
kong3
kong4
kou1
kou2
kou3
kou4
ku1
ku2
ku3
ku4
kua1
kua3
kua4
kuai3
kuai4
kuan1
kuan2
kuan3
kuang1
kuang2
kuang4
kui1
kui2
kui3
kui4
kun1
kun3
kun4
kuo4
l
la
la1
la2
la3
la4
lai2
lai4
lan2
lan3
lan4
lang1
lang2
lang3
lang4
lao1
lao2
lao3
lao4
le
le1
le4
lei
lei1
lei2
lei3
lei4
leng1
leng2
leng3
leng4
li
li1
li2
li3
li4
lia3
lian2
lian3
lian4
liang2
liang3
liang4
liao1
liao2
liao3
liao4
lie1
lie2
lie3
lie4
lin1
lin2
lin3
lin4
ling2
ling3
ling4
liu1
liu2
liu3
liu4
long1
long2
long3
long4
lou1
lou2
lou3
lou4
lu1
lu2
lu3
lu4
luan2
luan3
luan4
lun1
lun2
lun4
luo1
luo2
luo3
luo4
lv2
lv3
lv4
lve3
lve4
m
ma
ma1
ma2
ma3
ma4
mai2
mai3
mai4
man1
man2
man3
man4
mang2
mang3
mao1
mao2
mao3
mao4
me
mei2
mei3
mei4
men
men1
men2
men4
meng
meng1
meng2
meng3
meng4
mi1
mi2
mi3
mi4
mian2
mian3
mian4
miao1
miao2
miao3
miao4
mie1
mie4
min2
min3
ming2
ming3
ming4
miu4
mo1
mo2
mo3
mo4
mou1
mou2
mou3
mu2
mu3
mu4
n
n2
na1
na2
na3
na4
nai2
nai3
nai4
nan1
nan2
nan3
nan4
nang1
nang2
nang3
nao1
nao2
nao3
nao4
ne
ne2
ne4
nei3
nei4
nen4
neng2
ni1
ni2
ni3
ni4
nian1
nian2
nian3
nian4
niang2
niang4
niao2
niao3
niao4
nie1
nie4
nin2
ning2
ning3
ning4
niu1
niu2
niu3
niu4
nong2
nong4
nou4
nu2
nu3
nu4
nuan3
nuo2
nuo4
nv2
nv3
nve4
o
o1
o2
ou1
ou2
ou3
ou4
p
pa1
pa2
pa4
pai1
pai2
pai3
pai4
pan1
pan2
pan4
pang1
pang2
pang4
pao1
pao2
pao3
pao4
pei1
pei2
pei4
pen1
pen2
pen4
peng1
peng2
peng3
peng4
pi1
pi2
pi3
pi4
pian1
pian2
pian4
piao1
piao2
piao3
piao4
pie1
pie2
pie3
pin1
pin2
pin3
pin4
ping1
ping2
po1
po2
po3
po4
pou1
pu1
pu2
pu3
pu4
q
qi1
qi2
qi3
qi4
qia1
qia3
qia4
qian1
qian2
qian3
qian4
qiang1
qiang2
qiang3
qiang4
qiao1
qiao2
qiao3
qiao4
qie1
qie2
qie3
qie4
qin1
qin2
qin3
qin4
qing1
qing2
qing3
qing4
qiong1
qiong2
qiu1
qiu2
qiu3
qu1
qu2
qu3
qu4
quan1
quan2
quan3
quan4
que1
que2
que4
qun2
r
ran2
ran3
rang1
rang2
rang3
rang4
rao2
rao3
rao4
re2
re3
re4
ren2
ren3
ren4
reng1
reng2
ri4
rong1
rong2
rong3
rou2
rou4
ru2
ru3
ru4
ruan2
ruan3
rui3
rui4
run4
ruo4
s
sa1
sa2
sa3
sa4
sai1
sai4
san1
san2
san3
san4
sang1
sang3
sang4
sao1
sao2
sao3
sao4
se4
sen1
seng1
sha1
sha2
sha3
sha4
shai1
shai2
shai3
shai4
shan1
shan3
shan4
shang
shang1
shang3
shang4
shao1
shao2
shao3
shao4
she1
she2
she3
she4
shei2
shen1
shen2
shen3
shen4
sheng1
sheng2
sheng3
sheng4
shi
shi1
shi2
shi3
shi4
shou1
shou2
shou3
shou4
shu1
shu2
shu3
shu4
shua1
shua2
shua3
shua4
shuai1
shuai3
shuai4
shuan1
shuan4
shuang1
shuang3
shui2
shui3
shui4
shun3
shun4
shuo1
shuo4
si1
si2
si3
si4
song1
song3
song4
sou1
sou3
sou4
su1
su2
su4
suan1
suan4
sui1
sui2
sui3
sui4
sun1
sun3
suo
suo1
suo2
suo3
t
ta1
ta2
ta3
ta4
tai1
tai2
tai4
tan1
tan2
tan3
tan4
tang1
tang2
tang3
tang4
tao1
tao2
tao3
tao4
te4
teng2
ti1
ti2
ti3
ti4
tian1
tian2
tian3
tiao1
tiao2
tiao3
tiao4
tie1
tie2
tie3
tie4
ting1
ting2
ting3
tong1
tong2
tong3
tong4
tou
tou1
tou2
tou4
tu1
tu2
tu3
tu4
tuan1
tuan2
tui1
tui2
tui3
tui4
tun1
tun2
tun4
tuo1
tuo2
tuo3
tuo4
u
v
w
wa
wa1
wa2
wa3
wa4
wai1
wai3
wai4
wan1
wan2
wan3
wan4
wang1
wang2
wang3
wang4
wei1
wei2
wei3
wei4
wen1
wen2
wen3
wen4
weng1
weng4
wo1
wo2
wo3
wo4
wu1
wu2
wu3
wu4
x
xi1
xi2
xi3
xi4
xia1
xia2
xia4
xian1
xian2
xian3
xian4
xiang1
xiang2
xiang3
xiang4
xiao1
xiao2
xiao3
xiao4
xie1
xie2
xie3
xie4
xin1
xin2
xin4
xing1
xing2
xing3
xing4
xiong1
xiong2
xiu1
xiu3
xiu4
xu
xu1
xu2
xu3
xu4
xuan1
xuan2
xuan3
xuan4
xue1
xue2
xue3
xue4
xun1
xun2
xun4
y
ya
ya1
ya2
ya3
ya4
yan1
yan2
yan3
yan4
yang1
yang2
yang3
yang4
yao1
yao2
yao3
yao4
ye1
ye2
ye3
ye4
yi
yi1
yi2
yi3
yi4
yin1
yin2
yin3
yin4
ying1
ying2
ying3
ying4
yo1
yong1
yong2
yong3
yong4
you1
you2
you3
you4
yu1
yu2
yu3
yu4
yuan1
yuan2
yuan3
yuan4
yue1
yue4
yun1
yun2
yun3
yun4
z
za1
za2
za3
zai1
zai3
zai4
zan1
zan2
zan3
zan4
zang1
zang4
zao1
zao2
zao3
zao4
ze2
ze4
zei2
zen3
zeng1
zeng4
zha1
zha2
zha3
zha4
zhai1
zhai2
zhai3
zhai4
zhan1
zhan2
zhan3
zhan4
zhang1
zhang2
zhang3
zhang4
zhao1
zhao2
zhao3
zhao4
zhe
zhe1
zhe2
zhe3
zhe4
zhen1
zhen2
zhen3
zhen4
zheng1
zheng2
zheng3
zheng4
zhi1
zhi2
zhi3
zhi4
zhong1
zhong2
zhong3
zhong4
zhou1
zhou2
zhou3
zhou4
zhu1
zhu2
zhu3
zhu4
zhua1
zhua2
zhua3
zhuai1
zhuai3
zhuai4
zhuan1
zhuan2
zhuan3
zhuan4
zhuang1
zhuang4
zhui1
zhui4
zhun1
zhun2
zhun3
zhuo1
zhuo2
zi
zi1
zi2
zi3
zi4
zong1
zong2
zong3
zong4
zou1
zou2
zou3
zou4
zu1
zu2
zu3
zuan1
zuan3
zuan4
zui2
zui3
zui4
zun1
zuo
zuo1
zuo2
zuo3
zuo4
{
~
¡
¢
£
¥
§
¨
©
«
®
¯
°
±
²
³
´
µ
·
¹
º
»
¼
½
¾
¿
À
Á
Â
Ã
Ä
Å
Æ
Ç
È
É
Ê
Í
Î
Ñ
Ó
Ö
×
Ø
Ú
Ü
Ý
Þ
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ð
ñ
ò
ó
ô
õ
ö
ø
ù
ú
û
ü
ý
Ā
ā
ă
ą
ć
Č
č
Đ
đ
ē
ė
ę
ě
ĝ
ğ
ħ
ī
į
İ
ı
Ł
ł
ń
ņ
ň
ŋ
Ō
ō
ő
œ
ř
Ś
ś
Ş
ş
Š
š
Ť
ť
ũ
ū
ź
Ż
ż
Ž
ž
ơ
ư
ǎ
ǐ
ǒ
ǔ
ǚ
ș
ț
ɑ
ɔ
ɕ
ə
ɛ
ɜ
ɡ
ɣ
ɪ
ɫ
ɴ
ɹ
ɾ
ʃ
ʊ
ʌ
ʒ
ʔ
ʰ
ʷ
ʻ
ʾ
ʿ
ˈ
ː
˙
˜
ˢ
́
̅
Α
Β
Δ
Ε
Θ
Κ
Λ
Μ
Ξ
Π
Σ
Τ
Φ
Χ
Ψ
Ω
ά
έ
ή
ί
α
β
γ
δ
ε
ζ
η
θ
ι
κ
λ
μ
ν
ξ
ο
π
ρ
ς
σ
τ
υ
φ
χ
ψ
ω
ϊ
ό
ύ
ώ
ϕ
ϵ
Ё
А
Б
В
Г
Д
Е
Ж
З
И
Й
К
Л
М
Н
О
П
Р
С
Т
У
Ф
Х
Ц
Ч
Ш
Щ
Ы
Ь
Э
Ю
Я
а
б
в
г
д
е
ж
з
и
й
к
л
м
н
о
п
р
с
т
у
ф
х
ц
ч
ш
щ
ъ
ы
ь
э
ю
я
ё
і
ְ
ִ
ֵ
ֶ
ַ
ָ
ֹ
ּ
־
ׁ
א
ב
ג
ד
ה
ו
ז
ח
ט
י
כ
ל
ם
מ
ן
נ
ס
ע
פ
ק
ר
ש
ת
أ
ب
ة
ت
ج
ح
د
ر
ز
س
ص
ط
ع
ق
ك
ل
م
ن
ه
و
ي
َ
ُ
ِ
ْ
ก
ข
ง
จ
ต
ท
น
ป
ย
ร
ว
ส
ห
อ
ฮ
ั
า
ี
ึ
โ
ใ
ไ
่
้
์
ḍ
Ḥ
ḥ
ṁ
ṃ
ṅ
ṇ
Ṛ
ṛ
Ṣ
ṣ
Ṭ
ṭ
ạ
ả
Ấ
ấ
ầ
ậ
ắ
ằ
ẻ
ẽ
ế
ề
ể
ễ
ệ
ị
ọ
ỏ
ố
ồ
ộ
ớ
ờ
ở
ụ
ủ
ứ
ữ
ἀ
ἁ
Ἀ
ἐ
ἔ
ἰ
ἱ
ὀ
ὁ
ὐ
ὲ
ὸ
ᾶ
᾽
ῆ
ῇ
ῶ
‎
‑
‒
–
—
―
‖
†
‡
•
…
‧
‬
′
″
⁄
⁡
⁰
⁴
⁵
⁶
⁷
⁸
⁹
₁
₂
₃
€
₱
₹
₽
℃
ℏ
ℓ
№
ℝ
™
⅓
⅔
⅛
→
∂
∈
∑
−
∗
√
∞
∫
≈
≠
≡
≤
≥
⋅
⋯
█
♪
⟨
⟩
、
。
《
》
「
」
【
】
あ
う
え
お
か
が
き
ぎ
く
ぐ
け
げ
こ
ご
さ
し
じ
す
ず
せ
ぜ
そ
ぞ
た
だ
ち
っ
つ
で
と
ど
な
に
ね
の
は
ば
ひ
ぶ
へ
べ
ま
み
む
め
も
ゃ
や
ゆ
ょ
よ
ら
り
る
れ
ろ
わ
を
ん
ァ
ア
ィ
イ
ウ
ェ
エ
オ
カ
ガ
キ
ク
ケ
ゲ
コ
ゴ
サ
ザ
シ
ジ
ス
ズ
セ
ゾ
タ
ダ
チ
ッ
ツ
テ
デ
ト
ド
ナ
ニ
ネ
ノ
バ
パ
ビ
ピ
フ
プ
ヘ
ベ
ペ
ホ
ボ
ポ
マ
ミ
ム
メ
モ
ャ
ヤ
ュ
ユ
ョ
ヨ
ラ
リ
ル
レ
ロ
ワ
ン
・
ー
ㄋ
ㄍ
ㄎ
ㄏ
ㄓ
ㄕ
ㄚ
ㄜ
ㄟ
ㄤ
ㄥ
ㄧ
ㄱ
ㄴ
ㄷ
ㄹ
ㅁ
ㅂ
ㅅ
ㅈ
ㅍ
ㅎ
ㅏ
ㅓ
ㅗ
ㅜ
ㅡ
ㅣ
㗎
가
각
간
갈
감
갑
갓
갔
강
같
개
거
건
걸
겁
것
겉
게
겠
겨
결
겼
경
계
고
곤
골
곱
공
과
관
광
교
구
국
굴
귀
귄
그
근
글
금
기
긴
길
까
깍
깔
깜
깨
께
꼬
꼭
꽃
꾸
꿔
끔
끗
끝
끼
나
난
날
남
납
내
냐
냥
너
넘
넣
네
녁
년
녕
노
녹
놀
누
눈
느
는
늘
니
님
닙
다
닥
단
달
닭
당
대
더
덕
던
덥
데
도
독
동
돼
됐
되
된
될
두
둑
둥
드
들
등
디
따
딱
딸
땅
때
떤
떨
떻
또
똑
뚱
뛰
뜻
띠
라
락
란
람
랍
랑
래
랜
러
런
럼
렇
레
려
력
렵
렸
로
록
롬
루
르
른
를
름
릉
리
릴
림
마
막
만
많
말
맑
맙
맛
매
머
먹
멍
메
면
명
몇
모
목
몸
못
무
문
물
뭐
뭘
미
민
밌
밑
바
박
밖
반
받
발
밤
밥
방
배
백
밸
뱀
버
번
벌
벚
베
벼
벽
별
병
보
복
본
볼
봐
봤
부
분
불
비
빔
빛
빠
빨
뼈
뽀
뿅
쁘
사
산
살
삼
샀
상
새
색
생
서
선
설
섭
섰
성
세
셔
션
셨
소
속
손
송
수
숙
순
술
숫
숭
숲
쉬
쉽
스
슨
습
슷
시
식
신
실
싫
심
십
싶
싸
써
쓰
쓴
씌
씨
씩
씬
아
악
안
않
알
야
약
얀
양
얘
어
언
얼
엄
업
없
었
엉
에
여
역
연
염
엽
영
옆
예
옛
오
온
올
옷
옹
와
왔
왜
요
욕
용
우
운
울
웃
워
원
월
웠
위
윙
유
육
윤
으
은
을
음
응
의
이
익
인
일
읽
임
입
있
자
작
잔
잖
잘
잡
잤
장
재
저
전
점
정
제
져
졌
조
족
좀
종
좋
죠
주
준
줄
중
줘
즈
즐
즘
지
진
집
짜
짝
쩌
쪼
쪽
쫌
쭈
쯔
찌
찍
차
착
찾
책
처
천
철
체
쳐
쳤
초
촌
추
출
춤
춥
춰
치
친
칠
침
칩
칼
커
켓
코
콩
쿠
퀴
크
큰
큽
키
킨
타
태
터
턴
털
테
토
통
투
트
특
튼
틀
티
팀
파
팔
패
페
펜
펭
평
포
폭
표
품
풍
프
플
피
필
하
학
한
할
함
합
항
해
햇
했
행
허
험
형
혜
호
혼
홀
화
회
획
후
휴
흐
흔
희
히
힘
ﷺ
ﷻ
!
,
?
�
𠮶


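The vocabulary above lists one token per line: single characters plus pinyin syllables with a tone digit (e.g. `zhong1`; entries without a digit, such as `de` or `le`, presumably mark the neutral tone). A minimal sketch of turning such a file into a token-to-index map, assuming the index is simply the 0-based line number and that unknown tokens fall back to index 0. `build_vocab_map` and `tokens_to_ids` are hypothetical helpers for illustration, not F5-TTS functions.

```python
from typing import Dict, Iterable, List


def build_vocab_map(lines: Iterable[str]) -> Dict[str, int]:
    """Map each token (one per line) to its 0-based line index (assumed convention)."""
    vocab: Dict[str, int] = {}
    for idx, line in enumerate(lines):
        # Strip only the newline: some entries (e.g. a bare space) are whitespace themselves.
        vocab[line.rstrip("\n")] = idx
    return vocab


def tokens_to_ids(tokens: List[str], vocab: Dict[str, int], unk_id: int = 0) -> List[int]:
    """Look up each token; anything missing falls back to unk_id (assumed to be 0)."""
    return [vocab.get(tok, unk_id) for tok in tokens]


# A real file handle, e.g. open("vocab.txt", encoding="utf-8"), can be passed as `lines`.
sample = [" \n", "a\n", "zhong1\n", "guo2\n", "。\n"]
vocab = build_vocab_map(sample)
print(tokens_to_ids(["zhong1", "guo2", "。", "xyz"], vocab))  # [2, 3, 4, 0]
```

Using `rstrip("\n")` rather than `strip()` matters here: a whitespace-only token would otherwise collapse to an empty string and collide with other entries.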
================================================
FILE: src/f5_tts/infer/infer_cli.py
================================================
import argparse
import codecs
import os
import re
from datetime import datetime
from importlib.resources import files
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from unidecode import unidecode

from f5_tts.infer.utils_infer import (
    cfg_strength,
    cross_fade_duration,
    device,
    fix_duration,
    infer_process,
    load_model,
    load_vocoder,
    mel_spec_type,
    nfe_step,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
    speed,
    sway_sampling_coef,
    target_rms,
)


parser = argparse.ArgumentParser(
    prog="python3 infer_cli.py",
    description="Command-line interface for E2/F5 TTS with Advanced Batch Processing.",
    epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
    "-c",
    "--config",
    type=str,
    default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
    help="The configuration file (default: infer/examples/basic/basic.toml)",
)


# Note: no default values are set below, so unset options fall back to the config file.

parser.add_argument(
    "-m",
    "--model",
    type=str,
    help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
)
parser.add_argument(
    "-mc",
    "--model_cfg",
    type=str,
    help="The path to F5-TTS model config file .yaml",
)
parser.add_argument(
    "-p",
    "--ckpt_file",
    type=str,
    help="The path to model checkpoint .pt, leave blank to use default",
)
parser.add_argument(
    "-v",
    "--vocab_file",
    type=str,
    help="The path to vocab file .txt, leave blank to use default",
)
parser.add_argument(
    "-r",
    "--ref_audio",
    type=str,
    help="The reference audio file.",
)
parser.add_argument(
    "-s",
    "--ref_text",
    type=str,
    help="The transcript/subtitle for the reference audio",
)
parser.add_argument(
    "-t",
    "--gen_text",
    type=str,
    help="The text for the model to synthesize",
)
parser.add_argument(
    "-f",
    "--gen_file",
    type=str,
    help="The file with text to generate; overrides --gen_text",
)
parser.add_argument(
    "-o",
    "--output_dir",
    type=str,
    help="The path to output folder",
)
parser.add_argument(
    "-w",
    "--output_file",
    type=str,
    help="The name of output file",
)
parser.add_argument(
    "--save_chunk",
    action="store_true",
    help="Save each audio chunk during inference",
)
parser.add_argument(
    "--no_legacy_text",
    action="store_false",
    help="Do not use lossy ASCII transliterations of Unicode text in saved file names.",
)
parser.add_argument(
    "--remove_silence",
    action="store_true",
    help="Remove long silences found in the output",
)
parser.add_argument(
    "--load_vocoder_from_local",
    action="store_true",
    help="Load the vocoder from a local directory (../checkpoints/vocos-mel-24khz for vocos)",
)
parser.add_argument(
    "--vocoder_name",
    type=str,
    choices=["vocos", "bigvgan"],
    help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
)
parser.add_argument(
    "--target_rms",
    type=float,
    help=f"Target output speech loudness normalization value, default {target_rms}",
)
parser.add_argument(
    "--cross_fade_duration",
    type=float,
    help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
)
parser.add_argument(
    "--nfe_step",
    type=int,
    help=f"The number of function evaluations (denoising steps), default {nfe_step}",
)
parser.add_argument(
    "--cfg_strength",
    type=float,
    help=f"Classifier-free guidance strength, default {cfg_strength}",
)
parser.add_argument(
    "--sway_sampling_coef",
    type=float,
    help=f"Sway Sampling coefficient, default {sway_sampling_coef}",
)
parser.add_argument(
    "--speed",
    type=float,
    help=f"The speed of the generated audio, default {speed}",
)
parser.add_argument(
    "--fix_duration",
    type=float,
    help=f"Fix the total duration (reference plus generated audio) in seconds, default {fix_duration}",
)
parser.add_argument(
    "--device",
    type=str,
    help="Specify the device to run on",
)
args = parser.parse_args()


# config file

with open(args.config, "rb") as f:
    config = tomli.load(f)


# command-line interface parameters

model = args.model or config.get("model", "F5TTS_v1_Base")
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
vocab_file = args.vocab_file or config.get("vocab_file", "")

ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
ref_text = (
    args.ref_text
    if args.ref_text is not None
    else config.get("ref_text", "Some call me nature, others call me mother nature.")
)
gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
gen_file = args.gen_file or config.get("gen_file", "")

output_dir = args.output_dir or config.get("output_dir", "tests")
output_file = args.output_file or config.get(
    "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
)

save_chunk = args.save_chunk or config.get("save_chunk", False)
use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False)  # no_legacy_text is a store_false arg
if save_chunk and use_legacy_text:
    print(
        "\nWarning (--save_chunk): chunk file names use lossy ASCII transliterations of Unicode text; pass --no_legacy_text to disable.\n"
    )

remove_silence = args.remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)

vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)


def _override(cli_value, key, default):
    # CLI value wins whenever it was given (even 0 or 0.0), then the config file, then the built-in default
    return cli_value if cli_value is not None else config.get(key, default)


target_rms = _override(args.target_rms, "target_rms", target_rms)
cross_fade_duration = _override(args.cross_fade_duration, "cross_fade_duration", cross_fade_duration)
nfe_step = _override(args.nfe_step, "nfe_step", nfe_step)
cfg_strength = _override(args.cfg_strength, "cfg_strength", cfg_strength)
sway_sampling_coef = _override(args.sway_sampling_coef, "sway_sampling_coef", sway_sampling_coef)
speed = _override(args.speed, "speed", speed)
fix_duration = _override(args.fix_duration, "fix_duration", fix_duration)
device = _override(args.device, "device", device)


# patches for pip package users: resolve bundled example paths inside the installed package
if "infer/examples/" in ref_audio:
    ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
if "infer/examples/" in gen_file:
    gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
if "voices" in config:
    for voice in config["voices"]:
        voice_ref_audio = config["voices"][voice]["ref_audio"]
        if "infer/examples/" in voice_ref_audio:
            config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))


# ignore gen_text if gen_file provided

if gen_file:
    with codecs.open(gen_file, "r", "utf-8") as f:
        gen_text = f.read()


# output path

wave_path = Path(output_dir) / output_file
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
if save_chunk:
    output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
    if not os.path.exists(output_chunk_dir):
        os.makedirs(output_chunk_dir)


# load vocoder

if vocoder_name == "vocos":
    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

vocoder = load_vocoder(
    vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
)


# load TTS model

model_cfg = OmegaConf.load(
    args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

if model != "F5TTS_Base":
    assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type

# override for previous models
if model == "F5TTS_Base":
    if vocoder_name == "vocos":
        ckpt_step = 1200000
    elif vocoder_name == "bigvgan":
        model = "F5TTS_Base_bigvgan"
        ckpt_type = "pt"
elif model == "E2TTS_Base":
    repo_name = "E2-TTS"
    ckpt_step = 1200000

if not ckpt_file:
    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
elif ckpt_file.startswith("hf://"):
    ckpt_file = str(cached_path(ckpt_file))

if vocab_file.startswith("hf://"):
    vocab_file = str(cached_path(vocab_file))

print(f"Using {model}...")
ema_model = load_model(
    model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
)


# inference process


def main():
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice
    for voice in voices:
        print("Voice:", voice)
        print("ref_audio ", voices[voice]["ref_audio"])
        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
            voices[voice]["ref_audio"], voices[voice]["ref_text"]
        )
        print("ref_audio_", voices[voice]["ref_audio"], "\n\n")

    generated_audio_segments = []
    reg1 = r"(?=\[\w+\])"
    chunks = re.split(reg1, gen_text)
    reg2 = r"\[(\w+)\]"
    for text in chunks:
        if not text.strip():
            continue
        match = re.match(reg2, text)
        if match:
            voice = match[1]
        else:
            print("No voice tag found, using main.")
            voice = "main"
        if voice not in voices:
            print(f"Voice {voice} not found, using main.")
            voice = "main"
        text = re.sub(reg2, "", text)
        ref_audio_ = voices[voice]["ref_audio"]
        ref_text_ = voices[voice]["ref_text"]
        local_speed = voices[voice].get("speed", speed)
        gen_text_ = text.strip()
        print(f"Voice: {voice}")
        audio_segment, final_sample_rate, spectrogram = infer_process(
            ref_audio_,
            ref_text_,
            gen_text_,
            ema_model,
            vocoder,
            mel_spec_type=vocoder_name,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=local_speed,
            fix_duration=fix_duration,
            device=device,
        )
        generated_audio_segments.append(audio_segment)

        if save_chunk:
            if len(gen_text_) > 200:
                gen_text_ = gen_text_[:200] + " ... "
            if use_legacy_text:
                gen_text_ = unidecode(gen_text_)
            sf.write(
                os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
                audio_segment,
                final_sample_rate,
            )

    if generated_audio_segments:
        final_wave = np.concatenate(generated_audio_segments)

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        with open(wave_path, "wb") as f:
            sf.write(f.name, final_wave, final_sample_rate)
            # Remove silence
            if remove_silence:
                remove_silence_for_generated_wav(f.name)
            print(f.name)


if __name__ == "__main__":
    main()

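`main()` above splits `gen_text` on `[voice]` tags before synthesizing each piece with the matching reference audio. The sketch below isolates just that parsing step, reusing the same two regular expressions, so it can be tried without loading a model. `split_by_voice` is a hypothetical helper that returns `(voice, text)` pairs instead of calling `infer_process`; the fallback to `main` for untagged or unknown voices mirrors the script.

```python
import re
from typing import List, Set, Tuple

SPLIT_BEFORE_TAG = r"(?=\[\w+\])"  # zero-width: split just before each [tag]
TAG = r"\[(\w+)\]"  # captures the voice name inside [...]


def split_by_voice(gen_text: str, known_voices: Set[str], default: str = "main") -> List[Tuple[str, str]]:
    """Return (voice, text) pairs in order, falling back to `default` like main() does."""
    segments: List[Tuple[str, str]] = []
    for chunk in re.split(SPLIT_BEFORE_TAG, gen_text):  # zero-width splits need Python 3.7+
        if not chunk.strip():
            continue  # skip empty pieces, e.g. the '' produced when text starts with a tag
        match = re.match(TAG, chunk)
        voice = match[1] if match else default
        if voice not in known_voices:
            voice = default
        segments.append((voice, re.sub(TAG, "", chunk).strip()))
    return segments


print(split_by_voice("Hello. [town] Hi there! [narrator] Once upon a time.", {"main", "town"}))
# [('main', 'Hello.'), ('town', 'Hi there!'), ('main', 'Once upon a time.')]
```

Splitting on a lookahead keeps each tag attached to the text that follows it, which is why the tag can then be read with `re.match` and removed with `re.sub`.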

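When `--ckpt_file` is left blank, the script derives an `hf://` checkpoint path from the model name and vocoder. A sketch of that rule as a pure function, assuming exactly the repository names, step counts, and file types hard-coded above; `default_ckpt_url` is a hypothetical name, and the `cached_path` download is deliberately left out.

```python
def default_ckpt_url(model: str, vocoder_name: str) -> str:
    """Rebuild the hf:// checkpoint path infer_cli.py uses when no --ckpt_file is given."""
    repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
    if model == "F5TTS_Base":
        # The original F5TTS_Base has separate vocos / bigvgan checkpoints
        if vocoder_name == "vocos":
            ckpt_step = 1200000
        elif vocoder_name == "bigvgan":
            model = "F5TTS_Base_bigvgan"
            ckpt_type = "pt"
    elif model == "E2TTS_Base":
        repo_name = "E2-TTS"
        ckpt_step = 1200000
    return f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"


for name, voc in [("F5TTS_v1_Base", "vocos"), ("F5TTS_Base", "bigvgan"), ("E2TTS_Base", "vocos")]:
    print(default_ckpt_url(name, voc))
```

The F5TTS_v1_Base and E2TTS_Base results match the paths hard-coded in `infer_gradio.py` (`DEFAULT_TTS_MODEL_CFG[0]` and `load_e2tts`), which is a quick consistency check between the two entry points.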
================================================
FILE: src/f5_tts/infer/infer_gradio.py
================================================
# ruff: noqa: E402
# Above allows ruff to ignore E402: module level import not at top of file

import gc
import json
import os
import re
import tempfile
from collections import OrderedDict
from functools import lru_cache
from importlib.resources import files

import click
import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer


try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    if USING_SPACES:
        return spaces.GPU(func)
    else:
        return func


from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
    save_spectrogram,
    tempfile_kwargs,
)
from f5_tts.model import DiT, UNetT


DEFAULT_TTS_MODEL = "F5-TTS_v1"
tts_model_choice = DEFAULT_TTS_MODEL

DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]


# load models

vocoder = load_vocoder()


def load_f5tts():
    ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
    F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
    return load_model(DiT, F5TTS_model_cfg, ckpt_path)


def load_e2tts():
    ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
    E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
    return load_model(UNetT, E2TTS_model_cfg, ckpt_path)


def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
    ckpt_path, vocab_path = ckpt_path.strip(), vocab_path.strip()
    if ckpt_path.startswith("hf://"):
        ckpt_path = str(cached_path(ckpt_path))
    if vocab_path.startswith("hf://"):
        vocab_path = str(cached_path(vocab_path))
    if model_cfg is None:
        model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
    elif isinstance(model_cfg, str):
        model_cfg = json.loads(model_cfg)
    return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)


F5TTS_ema_model = load_f5tts()
E2TTS_ema_model = load_e2tts() if USING_SPACES else None
custom_ema_model, pre_custom_path = None, ""

chat_model_state = None
chat_tokenizer_state = None


@gpu_decorator
def chat_model_inference(messages, model, tokenizer):
    """Generate response using Qwen"""
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
    )

    generated_ids = [
        output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


@gpu_decorator
def load_text_from_file(file):
    if file:
        with open(file, "r", encoding="utf-8") as f:
            text = f.read().strip()
    else:
        text = ""
    return gr.update(value=text)


@lru_cache(maxsize=1000)  # NOTE: every argument passed to infer() must be hashable
@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    remove_silence,
    seed,
    cross_fade_duration=0.15,
    nfe_step=32,
    speed=1,
    show_info=gr.Info,
):
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        # callers unpack four values: (audio, spectrogram, ref_text, seed)
        return gr.update(), gr.update(), ref_text, seed

    # Set inference seed
    if seed < 0 or seed > 2**31 - 1:
        gr.Warning("Seed must be in range 0 ~ 2147483647. Using random seed instead.")
        seed = np.random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)
    used_seed = seed

    if not gen_text.strip():
        gr.Warning("Please enter text to generate or upload a text file.")
        return gr.update(), gr.update(), ref_text, used_seed

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    if model == DEFAULT_TTS_MODEL:
        ema_model = F5TTS_ema_model
    elif model == "E2-TTS":
        global E2TTS_ema_model
        if E2TTS_ema_model is None:
            show_info("Loading E2-TTS model...")
            E2TTS_ema_model = load_e2tts()
        ema_model = E2TTS_ema_model
    elif isinstance(model, tuple) and model[0] == "Custom":
        assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
        global custom_ema_model, pre_custom_path
        if pre_custom_path != model[1]:
            show_info("Loading Custom TTS model...")
            custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3])
            pre_custom_path = model[1]
        ema_model = custom_ema_model

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    # Remove silence
    if remove_silence:
        with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
            temp_path = f.name
        try:
            sf.write(temp_path, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(temp_path)
            final_wave, _ = torchaudio.load(temp_path)
        finally:
            os.unlink(temp_path)
        final_wave = final_wave.squeeze().cpu().numpy()

    # Save the spectrogram
    with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed


with gr.Blocks() as app_tts:
    gr.Markdown("# Batched TTS")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    with gr.Row():
        gen_text_input = gr.Textbox(
            label="Text to Generate",
            lines=10,
            max_lines=40,
            scale=4,
        )
        gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
    generate_btn = gr.Button("Synthesize", variant="primary")
    with gr.Accordion("Advanced Settings", open=True) as adv_settn:
        with gr.Row():
            ref_text_input = gr.Textbox(
                label="Reference Text",
                info="Leave blank to automatically transcribe the reference audio. If you enter text or upload a file, it will override automatic transcription.",
                lines=2,
                scale=4,
            )
            ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
        with gr.Row():
            randomize_seed = gr.Checkbox(
                label="Randomize Seed",
                info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
                value=True,
                scale=3,
            )
            seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)
            with gr.Column(scale=4):
                remove_silence = gr.Checkbox(
                    label="Remove Silences",
                    info="If undesired long silences are produced, enable this to detect and crop them automatically.",
                    value=False,
                )
        speed_slider = gr.Slider(
            label="Speed",
            minimum=0.3,
            maximum=2.0,
            value=1.0,
            step=0.1,
            info="Adjust the speed of the audio.",
        )
        nfe_slider = gr.Slider(
            label="NFE Steps",
            minimum=4,
            maximum=64,
            value=32,
            step=2,
            info="Set the number of denoising steps.",
        )
        cross_fade_duration_slider = gr.Slider(
            label="Cross-Fade Duration (s)",
            minimum=0.0,
            maximum=1.0,
            value=0.15,
            step=0.01,
            info="Set the duration of the cross-fade between audio clips.",
        )

    def collapse_accordion():
        return gr.Accordion(open=False)

    # Workaround for https://github.com/SWivid/F5-TTS/issues/1239#issuecomment-3677987413
    # i.e. set gr.Accordion(open=True) by default, then collapse it manually once the Blocks have loaded
    app_tts.load(
        fn=collapse_accordion,
        inputs=None,
        outputs=adv_settn,
    )

    audio_output = gr.Audio(label="Synthesized Audio")
    spectrogram_output = gr.Image(label="Spectrogram")

    @gpu_decorator
    def basic_tts(
        ref_audio_input,
        ref_text_input,
        gen_text_input,
        remove_silence,
        randomize_seed,
        seed_input,
        cross_fade_duration_slider,
        nfe_slider,
        speed_slider,
    ):
        if randomize_seed:
            seed_input = np.random.randint(0, 2**31 - 1)

        audio_out, spectrogram_path, ref_text_out, used_seed = infer(
            ref_audio_input,
            ref_text_input,
            gen_text_input,
            tts_model_choice,
            remove_silence,
            seed=seed_input,
            cross_fade_duration=cross_fade_duration_slider,
            nfe_step=nfe_slider,
            speed=speed_slider,
        )
        return audio_out, spectrogram_path, ref_text_out, used_seed

    gen_text_file.upload(
        load_text_from_file,
        inputs=[gen_text_file],
        outputs=[gen_text_input],
    )

    ref_text_file.upload(
        load_text_from_file,
        inputs=[ref_text_file],
        outputs=[ref_text_input],
    )

    ref_audio_input.clear(
        lambda: [None, None],
        None,
        [ref_text_input, ref_text_file],
    )

    generate_btn.click(
        basic_tts,
        inputs=[
            ref_audio_input,
            ref_text_input,
            gen_text_input,
            remove_silence,
            randomize_seed,
            seed_input,
            cross_fade_duration_slider,
            nfe_slider,
            speed_slider,
        ],
        outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
    )


def parse_speechtypes_text(gen_text):
    # Pattern to find {str} or {"name": str, "seed": int, "speed": float}
    pattern = r"(\{.*?\})"

    # Split the text by the pattern
    tokens = re.split(pattern, gen_text)

    segments = []

    current_type_dict = {
        "name": "Regular",
        "seed": -1,
        "speed": 1.0,
    }

    for i in range(len(tokens)):
        if i % 2 == 0:
            # This is text
            text = tokens[i].strip()
            if text:
                current_type_dict["text"] = text
                segments.append(current_type_dict)
        else:
            # This is type
            type_str = tokens[i].strip()
            try:  # if type dict
                current_type_dict = json.loads(type_str)
            except json.decoder.JSONDecodeError:
                type_str = type_str[1:-1]  # remove brace {}
                current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}

    return segments


with gr.Blocks() as app_multistyle:
    # New section for multistyle generation
    gr.Markdown(
        """
    # Multiple Speech-Type Generation

    This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, or upload a .txt file with the same format. The system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
    """
    )

    with gr.Row():
        gr.Markdown(
            """
            **Example Input:** <br>
            {Regular} Hello, I'd like to order a sandwich please. <br>
            {Surprised} What do you mean you're out of bread? <br>
            {Sad} I really wanted a sandwich though... <br>
            {Angry} You know what, darn you and your little shop! <br>
            {Whisper} I'll just go back home and cry now. <br>
            {Shouting} Why me?!
            """
        )

        gr.Markdown(
            """
            **Example Input 2:** <br>
            {"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please. <br>
            {"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread. <br>
            {"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though... <br>
            {"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
            """
        )

    gr.Markdown(
        'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
    )

    # Regular speech type (mandatory)
    with gr.Row(variant="compact") as regular_row:
        with gr.Column(scale=1, min_width=160):
            regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
            regular_insert = gr.Button("Insert Label", variant="secondary")
        with gr.Column(scale=3):
            regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
        with gr.Column(scale=3):
            regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
            with gr.Row():
                regular_seed_slider = gr.Slider(
                    show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
                )
                regular_speed_slider = gr.Slider(
                    show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
                )
        with gr.Column(scale=1, min_width=160):
            regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])

    # All speech types (max 100), starting with the mandatory Regular row
    max_speech_types = 100
    speech_type_rows = [regular_row]
    speech_type_names = [regular_name]
    speech_type_audios = [regular_audio]
    speech_type_ref_texts = [regular_ref_text]
    speech_type_ref_text_files = [regular_ref_text_file]
    speech_type_seeds = [regular_seed_slider]
    speech_type_speeds = [regular_speed_slider]
    speech_type_delete_btns = [None]
    speech_type_insert_btns = [regular_insert]

    # Additional speech types (99 more)
    for i in range(max_speech_types - 1):
        with gr.Row(variant="compact", visible=False) as row:
            with gr.Column(scale=1, min_width=160):
                name_input = gr.Textbox(label="Speech Type Name")
                insert_btn = gr.Button("Insert Label", variant="secondary")
                delete_btn = gr.Button("Delete Type", variant="stop")
            with gr.Column(scale=3):
                audio_input = gr.Audio(label="Reference Audio", type="filepath")
            with gr.Column(scale=3):
                ref_text_input = gr.Textbox(label="Reference Text", lines=4)
                with gr.Row():
                    seed_input = gr.Slider(
                        show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
                    )
                    speed_input = gr.Slider(
                        show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
                    )
            with gr.Column(scale=1, min_width=160):
                ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
        speech_type_rows.append(row)
        speech_type_names.append(name_input)
        speech_type_audios.append(audio_input)
        speech_type_ref_texts.append(ref_text_input)
        speech_type_ref_text_files.append(ref_text_file_input)
        speech_type_seeds.append(seed_input)
        speech_type_speeds.append(speed_input)
        speech_type_delete_btns.append(delete_btn)
        speech_type_insert_btns.append(insert_btn)

    # Global logic for all speech types
    for i in range(max_speech_types):
        speech_type_audios[i].clear(
            lambda: [None, None],
            None,
            [speech_type_ref_texts[i], speech_type_ref_text_files[i]],
        )
        speech_type_ref_text_files[i].upload(
            load_text_from_file,
            inputs=[speech_type_ref_text_files[i]],
            outputs=[speech_type_ref_texts[i]],
        )

    # Button to add speech type
    add_speech_type_btn = gr.Button("Add Speech Type")

    # Auto-incrementing count of revealed speech-type rows; deleted rows are never reused
    speech_type_count = 1

    # Function to add a speech type
    def add_speech_type_fn():
        row_updates = [gr.update() for _ in range(max_speech_types)]
        global speech_type_count
        if speech_type_count < max_speech_types:
            row_updates[speech_type_count] = gr.update(visible=True)
            speech_type_count += 1
        else:
            gr.Warning("Maximum number of speech types reached. Consider restarting the app.")
        return row_updates

    add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)

    # Function to delete a speech type
    def delete_speech_type_fn():
        return gr.update(visible=False), None, None, None, None

    # Wire each delete button to hide and clear its row (the Regular row has no delete button)
    for i in range(1, len(speech_type_delete_btns)):
        speech_type_delete_btns[i].click(
            delete_speech_type_fn,
            outputs=[
                speech_type_rows[i],
                speech_type_names[i],
                speech_type_audios[i],
                speech_type_ref_texts[i],
                speech_type_ref_text_files[i],
            ],
        )

    # Text input for the prompt
    with gr.Row():
        gen_text_input_multistyle = gr.Textbox(
            label="Text to Generate",
            lines=10,
            max_lines=40,
            scale=4,
            placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
        )
        gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)

    def make_insert_speech_type_fn(index):
        def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
            current_text = current_text or ""
            if not speech_type_name:
                gr.Warning("Please enter a speech type name before inserting a label.")
                return current_text
            speech_type_dict = {
                "name": speech_type_name,
                "seed": speech_type_seed,
                "speed": speech_type_speed,
            }
            updated_text = current_text + json.dumps(speech_type_dict) + " "
            return updated_text

        return insert_speech_type_fn

    for i, insert_btn in enumerate(speech_type_insert_btns):
        insert_fn = make_insert_speech_type_fn(i)
        insert_btn.click(
            insert_fn,
            inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
            outputs=gen_text_input_multistyle,
        )

    with gr.Accordion("Advanced Settings", open=True):
        with gr.Row():
            with gr.Column():
                show_cherrypick_multistyle = gr.Checkbox(
                    label="Show Cherry-pick Interface",
                    info="Turn on to show the interface for picking seeds from previous generations.",
                    value=False,
                )
            with gr.Column():
                remove_silence_multistyle = gr.Checkbox(
                    label="Remove Silences",
                    info="Turn on to automatically detect and crop long silences.",
                    value=True,
                )

    # Generate button
    generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")

    # Output audio
    audio_output_multistyle = gr.Audio(label="Synthesized Audio")

    # Used seed gallery
    cherrypick_interface_multistyle = gr.Textbox(
        label="Cherry-pick Interface",
        lines=10,
        max_lines=40,
        buttons=["copy"],  # show_copy_button=True if gradio<6.0
        interactive=False,
        visible=False,
    )

    # Logic control to show/hide the cherrypick interface
    show_cherrypick_multistyle.change(
        lambda is_visible: gr.update(visible=is_visible),
        show_cherrypick_multistyle,
        cherrypick_interface_multistyle,
    )

    # Function to load text to generate from file
    gen_text_file_multistyle.upload(
        load_text_from_file,
        inputs=[gen_text_file_multistyle],
        outputs=[gen_text_input_multistyle],
    )

    @gpu_decorator
    def generate_multistyle_speech(
        gen_text,
        *args,
    ):
        speech_type_names_list = args[:max_speech_types]
        speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
        speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]
        remove_silence = args[3 * max_speech_types]
        # Collect the speech types and their audios into a dict
        speech_types = OrderedDict()

        ref_text_idx = 0
        for name_input, audio_input, ref_text_input in zip(
            speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
        ):
            if name_input and audio_input:
                speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
            else:
                speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
            ref_text_idx += 1

        # Parse the gen_text into segments
        segments = parse_speechtypes_text(gen_text)

        # For each segment, generate speech
        generated_audio_segments = []
        current_type_name = "Regular"
        inference_meta_data = ""

        for segment in segments:
            name = segment["name"]
            seed_input = segment["seed"]
            speed = segment["speed"]
            text = segment["text"]

            if name in speech_types:
                current_type_name = name
            else:
                gr.Warning(f"Type {name} is not available, will use Regular as default.")
                current_type_name = "Regular"

            try:
                ref_audio = speech_types[current_type_name]["audio"]
            except KeyError:
                gr.Warning(f"Please provide reference audio for type {current_type_name}.")
                return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
            ref_text = speech_types[current_type_name].get("ref_text", "")

            if seed_input == -1:
                seed_input = np.random.randint(0, 2**31 - 1)

            # Generate or retrieve speech for this segment
            audio_out, _, ref_text_out, used_seed = infer(
                ref_audio,
                ref_text,
                text,
                tts_model_choice,
                remove_silence,
                seed=seed_input,
                cross_fade_duration=0,
                speed=speed,
                show_info=print,  # no pull to top when generating
            )
            sr, audio_data = audio_out

            generated_audio_segments.append(audio_data)
            speech_types[current_type_name]["ref_text"] = ref_text_out
            inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"

        # Concatenate all audio segments
        if generated_audio_segments:
            final_audio_data = np.concatenate(generated_audio_segments)
            return (
                [(sr, final_audio_data)]
                + [speech_types[name]["ref_text"] for name in speech_types]
                + [inference_meta_data]
            )
        else:
            gr.Warning("No audio generated.")
            return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]

    generate_multistyle_btn.click(
        generate_multistyle_speech,
        inputs=[
            gen_text_input_multistyle,
        ]
        + speech_type_names
        + speech_type_audios
        + speech_type_ref_texts
        + [
            remove_silence_multistyle,
        ],
        outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
    )

    # Validation function to disable Generate button if speech types are missing
    def validate_speech_types(gen_text, regular_name, *args):
        speech_type_names_list = args

        # Collect the speech types names
        speech_types_available = set()
        if regular_name:
            speech_types_available.add(regular_name)
        for name_input in speech_type_names_list:
            if name_input:
                speech_types_available.add(name_input)

        # Parse the gen_text to get the speech types used
        segments = parse_speechtypes_text(gen_text)
        speech_types_in_text = set(segment["name"] for segment in segments)

        # Check if all speech types in text are available
        missing_speech_types = speech_types_in_text - speech_types_available

        if missing_speech_types:
            # Disable the generate button
            return gr.update(interactive=False)
        else:
            # Enable the generate button
            return gr.update(interactive=True)

    gen_text_input_multistyle.change(
        validate_speech_types,
        inputs=[gen_text_input_multistyle, regular_name] + speech_type_names,
        outputs=generate_multistyle_btn,
    )
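

# --- Illustrative sketch (hypothetical helper, not called by the app) ---
# generate_multistyle_speech receives every wired component value after gen_text
# as one flat *args tuple, laid out in the order of its inputs list:
#   1. all names,
#   2. then all audios,
#   3. then all reference texts,
#   4. then the remove-silence checkbox.
# Slicing by multiples of the row count recovers each group. This helper shows the
# slicing in isolation for an arbitrary row count n_rows, which plays the role of
# max_speech_types in the app.


def _example_split_flat_args(args, n_rows):
    """Split a flat tuple into (names, audios, ref_texts, trailing_flag)."""
    names = args[:n_rows]
    audios = args[n_rows : 2 * n_rows]
    ref_texts = args[2 * n_rows : 3 * n_rows]
    trailing_flag = args[3 * n_rows]  # e.g. the remove_silence checkbox value
    return names, audios, ref_texts, trailing_flag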

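# --- Illustrative sketch (hypothetical helper, not called by the app) ---
# validate_speech_types keeps the Generate button enabled only when every speech
# type named in the script has a matching, non-empty name in the UI. That check
# reduces to a set difference, shown here on plain lists of names.


def _example_missing_speech_types(names_in_text, names_in_ui):
    """Return the speech types used in the script but not defined in the UI."""
    available = {name for name in names_in_ui if name}  # ignore empty name boxes
    return set(names_in_text) - available


# An empty result means the button can be enabled.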

with gr.Blocks() as app_chat:
    gr.Markdown(
        """
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript (via text or .txt file).
2. Load the chat model.
3. Record your message through your microphone or type it.
4. The AI will respond using the reference voice.
"""
    )

    chat_model_name_list = [
        "Qwen/Qwen2.5-3B-Instruct",
        "microsoft/Phi-4-mini-instruct",
    ]

    @gpu_decorator
    def load_chat_model(chat_model_name):
        show_info = gr.Info
        global chat_model_state, chat_tokenizer_state
        if chat_model_state is not None:
            chat_model_state = None
            chat_tokenizer_state = None
            gc.collect()
            torch.cuda.empty_cache()

        show_info(f"Loading chat model: {chat_model_name}")
        chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto")
        chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
        show_info(f"Chat model {chat_model_name} loaded successfully!")

        return gr.update(visible=False), gr.update(visible=True)

    if USING_SPACES:
        load_chat_model(chat_model_name_list[0])

    chat_model_name_input = gr.Dropdown(
        choices=chat_model_name_list,
        value=chat_model_name_list[0],
        label="Chat Model Name",
        info="Enter the name of a HuggingFace chat model",
        allow_custom_value=not USING_SPACES,
    )
    load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES)
    chat_interface_container = gr.Column(visible=USING_SPACES)

    chat_model_name_input.change(
        lambda: gr.update(visible=True),
        None,
        load_chat_model_btn,
        show_progress="hidden",
    )
    load_chat_model_btn.click(
        load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
    )

    with chat_interface_container:
        with gr.Row():
            with gr.Column():
                ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
            with gr.Column():
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        ref_text_chat = gr.Textbox(
                            label="Reference Text",
                            info="Optional: Leave blank to auto-transcribe",
                            lines=2,
                            scale=3,
                        )
                        ref_text_file_chat = gr.File(
                            label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
                        )
                    with gr.Row():
                        randomize_seed_chat = gr.Checkbox(
                            label="Randomize Seed",
                            value=True,
                            info="Uncheck to use the seed specified.",
                            scale=3,
                        )
                        seed_input_chat = gr.Number(show_label=False, value=0, precision=0, scale=1)
                    remove_silence_chat = gr.Checkbox(
                        label="Remove Silences",
                        value=True,
                    )
                    system_prompt_chat = gr.Textbox(
                        label="System Prompt",
                        value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
                        lines=2,
                    )

        chatbot_interface = gr.Chatbot(
            label="Conversation"
        )  # type="messages" hard-coded and no need to pass in since gradio 6.0

        with gr.Row():
            with gr.Column():
                audio_input_chat = gr.Microphone(
                    label="Speak your message",
                    type="filepath",
                )
                audio_output_chat = gr.Audio(autoplay=True)
            with gr.Column():
                text_input_chat = gr.Textbox(
                    label="Type your message",
                    lines=1,
                )
                send_btn_chat = gr.Button("Send Message")
                clear_btn_chat = gr.Button("Clear Conversation")

        # Turn the user's recorded audio (transcribed if no text is given) or typed text into a user chat turn
        @gpu_decorator
        def process_audio_input(conv_state, audio_path, text):
            """Handle audio or text input from user"""

            if not audio_path and not text.strip():
                return conv_state

            if audio_path:
                text = preprocess_ref_audio_text(audio_path, text)[1]
            if not text.strip():
                return conv_state

            conv_state.append({"role": "user", "content": text})
            return conv_state

        # Use model and tokenizer from state to get text response
        @gpu_decorator
        def generate_text_response(conv_state, system_prompt):
            """Generate text response from AI"""
            for single_state in conv_state:
                if isinstance(single_state["content"], list):
                    assert len(single_state["content"]) == 1 and single_state["content"][0]["type"] == "text"
                    single_state["content"] = single_state["content"][0]["text"]

            system_prompt_state = [{"role": "system", "content": system_prompt}]
            response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)

            conv_state.append({"role": "assistant", "content": response})
            return conv_state

        @gpu_decorator
        def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input):
            """Generate TTS audio for AI response"""
            if not conv_state or not ref_audio:
                return None, ref_text, seed_input

            last_ai_response = conv_state[-1]["content"][0]["text"]
            if not last_ai_response or conv_state[-1]["role"] != "assistant":
                return None, ref_text, seed_input

            if randomize_seed:
                seed_input = np.random.randint(0, 2**31 - 1)

            audio_result, _, ref_text_out, used_seed = infer(
                ref_audio,
                ref_text,
                last_ai_response,
                tts_model_choice,
                remove_silence,
                seed=seed_input,
                cross_fade_duration=0.15,
                speed=1.0,
                show_info=print,  # show_info=print no pull to top when generating
            )
            return audio_result, ref_text_out, used_seed

        def clear_conversation():
            """Reset the conversation"""
            return [], None

        ref_text_file_chat.upload(
            load_text_from_file,
            inputs=[ref_text_file_chat],
            outputs=[ref_text_chat],
        )

        for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]:
            user_operation(
                process_audio_input,
                inputs=[chatbot_interface, audio_input_chat, text_input_chat],
                outputs=[chatbot_interface],
            ).then(
                generate_text_response,
                inputs=[chatbot_interface, system_prompt_chat],
                outputs=[chatbot_interface],
            ).then(
                generate_audio_response,
                inputs=[
                    chatbot_interface,
                    ref_audio_chat,
                    ref_text_chat,
                    remove_silence_chat,
                    randomize_seed_chat,
                    seed_input_chat,
                ],
                outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
            ).then(
                lambda: [None, None],
                None,
                [audio_input_chat, text_input_chat],
            )

        # Handle clear button or system prompt change and reset conversation
        for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]:
            user_operation(
                clear_conversation,
                outputs=[chatbot_interface, audio_output_chat],
            )
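

# --- Illustrative sketch (hypothetical helper, not called by the app) ---
# generate_text_response handles messages whose content is a list holding a single
# text part, e.g. [{"type": "text", "text": "..."}]. It flattens each one into a
# plain string before passing the history to chat_model_inference.
# The helper below does the same, with two differences from the app:
#   - it works on a copy instead of mutating the Chatbot state in place;
#   - it raises ValueError instead of using assert, which Python skips under -O.
# Assumption: only single-part text messages are expected.


def _example_flatten_messages(messages):
    """Return a copy of messages with single text-part content turned into strings."""
    flattened = []
    for message in messages:
        content = message["content"]
        if isinstance(content, list):
            if len(content) != 1 or content[0].get("type") != "text":
                raise ValueError("expected exactly one text part per message")
            content = content[0]["text"]
        flattened.append({**message, "content": content})  # shallow copy per message
    return flattened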


with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits

* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
""")


with gr.Blocks() as app:
    gr.Markdown(
        f"""
# F5-TTS Demo Space

This is {"a local web UI for [F5-TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:

* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)

The checkpoints currently support English and Chinese.

If you're having issues, try converting your reference audio to WAV or MP3 and clipping it to 12 s or less with the ✂ tool in the bottom-right corner (otherwise the automatic trimming may give suboptimal results).

**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<12s). Ensure the audio is fully uploaded before generating.**
"""
    )

    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")

    def load_last_used_custom():
        try:
            custom = []
            with open(last_used_custom, "r", encoding="utf-8") as f:
                for line in f:
                    custom.append(line.strip())
            return custom
        except FileNotFoundError:
            last_used_custom.parent.mkdir(parents=True, exist_ok=True)
            return DEFAULT_TTS_MODEL_CFG

    def switch_tts_model(new_choice):
        global tts_model_choice
        if new_choice == "Custom":  # override in case webpage is refreshed
            custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
            tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
            return (
                gr.update(visible=True, value=custom_ckpt_path),
                gr.update(visible=True, value=custom_vocab_path),
                gr.update(visible=True, value=custom_model_cfg),
            )
        else:
            tts_model_choice = new_choice
            return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)

    def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
        global tts_model_choice
        tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
        with open(last_used_custom, "w", encoding="utf-8") as f:
            f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")

    with gr.Row():
        if not USING_SPACES:
            choose_tts_model = gr.Radio(
                choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
            )
        else:
            choose_tts_model = gr.Radio(
                choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
            )
        custom_ckpt_path = gr.Dropdown(
            choices=[DEFAULT_TTS_MODEL_CFG[0]],
            value=load_last_used_custom()[0],
            allow_custom_value=True,
            label="Model: local_path | hf://user_id/repo_id/model_ckpt",
            visible=False,
        )
        custom_vocab_path = gr.Dropdown(
            choices=[DEFAULT_TTS_MODEL_CFG[1]],
            value=load_last_used_custom()[1],
            allow_custom_value=True,
            label="Vocab: local_path | hf://user_id/repo_id/vocab_file",
            visible=False,
        )
        custom_model_cfg = gr.Dropdown(
            choices=[
                DEFAULT_TTS_MODEL_CFG[2],
                json.dumps(
                    dict(
                        dim=1024,
                        depth=22,
                        heads=16,
                        ff_mult=2,
                        text_dim=512,
                        text_mask_padding=False,
                        conv_layers=4,
                        pe_attn_head=1,
                    )
                ),
                json.dumps(
                    dict(
                        dim=768,
                        depth=18,
                        heads=12,
                        ff_mult=2,
                        text_dim=512,
                        text_mask_padding=False,
                        conv_layers=4,
                        pe_attn_head=1,
                    )
                ),
            ],
            value=load_last_used_custom()[2],
            allow_custom_value=True,
            label="Config: in a dictionary form",
            visible=False,
        )

    choose_tts_model.change(
        switch_tts_model,
        inputs=[choose_tts_model],
        outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
        show_progress="hidden",
    )
    custom_ckpt_path.change(
        set_custom_model,
        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
        show_progress="hidden",
    )
    custom_vocab_path.change(
        set_custom_model,
        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
        show_progress="hidden",
    )
    custom_model_cfg.change(
        set_custom_model,
        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
        show_progress="hidden",
    )

    gr.TabbedInterface(
        [app_tts, app_multistyle, app_chat, app_credits],
        ["Basic-TTS", "Multi-Speech", "Voice-Chat", "Credits"],
    )
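

# --- Illustrative sketch (hypothetical helpers, not called by the app) ---
# load_last_used_custom and set_custom_model above persist three settings as
# newline-separated lines: the custom checkpoint path, the vocab path and the
# model config. When the cache file does not exist yet, they fall back to
# DEFAULT_TTS_MODEL_CFG.
# The pathlib round trip below restates that idea in isolation. The helper names,
# and creating the parent directory on save, are assumptions. Storing the config
# on one line works because json.dumps emits no newlines unless indent is set.
from pathlib import Path


def _example_save_custom(path, ckpt_path, vocab_path, model_cfg):
    """Write the three settings, one per line, creating the parent directory."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(f"{ckpt_path}\n{vocab_path}\n{model_cfg}\n", encoding="utf-8")


def _example_load_custom(path, defaults):
    """Read the three settings back, or return a copy of the defaults if unsaved."""
    try:
        return Path(path).read_text(encoding="utf-8").splitlines()
    except FileNotFoundError:
        return list(defaults)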


@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option(
    "--share",
    "-s",
    default=False,
    is_flag=True,
    help="Share the app via Gradio share link",
)
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
@click.option(
    "--root_path",
    "-r",
    default=None,
    type=str,
    help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
)
@click.option(
    "--inbrowser",
    "-i",
    is_flag=True,
    default=False,
    help="Automatically launch the interface in the default web browser",
)
def main(port, host, share, api, root_path, inbrowser):
    global app
    print("Starting app...")
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        root_path=root_path,
        inbrowser=inbrowser,
    )


if __name__ == "__main__":
    if not USING_SPACES:
        main()
    else:
        app.queue().launch()


================================================
FILE: src/f5_tts/infer/speech_edit.py
================================================
import os


os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility

from importlib.resources import files

import torch
import torch.nn.functional as F
import torchaudio
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer


device = (
    "cuda"
    if torch.cuda.is_available()
    else "xpu"
    if torch.xpu.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
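

# --- Illustrative sketch (hypothetical helper, not used by this script) ---
# The nested conditional above picks the first available backend in the order
# CUDA, XPU, MPS, and otherwise falls back to CPU.
# The helper below expresses the same priority list as a loop over a dict of
# availability flags, which stand in for torch.cuda.is_available() and friends.
# This keeps the order in one place and can be tested without a GPU.


def _example_pick_device(is_available):
    """Return the first backend, in priority order, whose availability flag is True."""
    for name in ("cuda", "xpu", "mps"):
        if is_available.get(name, False):
            return name
    return "cpu"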


# ---------------------- infer setting ---------------------- #

seed = None  # int | None

exp_name = "F5TTS_v1_Base"  # F5TTS_v1_Base | E2TTS_Base
ckpt_step = 1250000

nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"  # euler | midpoint
sway_sampling_coef = -1.0
speed = 1.0
target_rms = 0.1


model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer

mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft


# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
output_dir = "tests"


# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
# [write the origin_text into a file, e.g. tests/test_edit.txt]
# ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
# [the result is saved alongside the audio file]
# [--language "zho" for Chinese, "eng" for English]
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]

audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
parts_to_edit = [
    [1.42, 2.44],
    [4.04, 4.9],
]  # start/end times of "nature" & "mother nature", in seconds
fix_duration = [
    1.2,
    1,
]  # fixed durations for "optimist" & "realist", in seconds

# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
# target_text = "对,那就是你,万人敬仰的太白金星。"
# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
# fix_duration = None  # use origin text duration

# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
# target_text = "对,这就是你,万人敬仰的李白金星。"
# parts_to_edit = [[1.500, 2.784], [4.083, 6.760]]
# fix_duration = [1.284, 2.677]


# -------------------------------------------------#

use_ema = True

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Vocoder model
local = False
if mel_spec_type == "vocos":
    vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
elif mel_spec_type == "bigvgan":
    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
else:
    raise ValueError(f"Unsupported mel_spec_type: {mel_spec_type}")
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)

# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model
model = CFM(
    transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
    mel_spec_kwargs=dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    ),
    odeint_kwargs=dict(
        method=ode_method,
    ),
    vocab_char_map=vocab_char_map,
).to(device)

dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

# Audio
audio, sr = torchaudio.load(audio_to_edit)
if audio.shape[0] > 1:
    audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
    audio = audio * target_rms / rms
if sr != target_sample_rate:
    resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
    audio = resampler(audio)

# Convert to mel spectrogram FIRST (on clean original audio)
# This avoids boundary artifacts from mel windows straddling zeros and real audio
audio = audio.to(device)
with torch.inference_mode():
    original_mel = model.mel_spec(audio)  # (batch, n_mel, n_frames)
    original_mel = original_mel.permute(0, 2, 1)  # (batch, n_frames, n_mel)

# Build mel_cond and edit_mask at FRAME level
# Insert zero frames in mel domain instead of zero samples in wav domain
offset_frame = 0
mel_cond = torch.zeros(1, 0, n_mel_channels, device=device)
edit_mask = torch.zeros(1, 0, dtype=torch.bool, device=device)
fix_dur_list = fix_duration.copy() if fix_duration is not None else None

for part in parts_to_edit:
    start, end = part
    part_dur_sec = end - start if fix_dur_list is None else fix_dur_list.pop(0)

    # Convert to frames (this is the authoritative unit)
    start_frame = round(start * target_sample_rate / hop_length)
    end_frame = round(end * target_sample_rate / hop_length)
    part_dur_frames = round(part_dur_sec * target_sample_rate / hop_length)

    # Number of frames for the kept (non-edited) region
    keep_frames = start_frame - offset_frame

    # Build mel_cond: original mel frames + zero frames for edit region
    mel_cond = torch.cat(
        (
            mel_cond,
            original_mel[:, offset_frame:start_frame, :],
            torch.zeros(1, part_dur_frames, n_mel_channels, device=device),
        ),
        dim=1,
    )
    edit_mask = torch.cat(
        (
            edit_mask,
            torch.ones(1, keep_frames, dtype=torch.bool, device=device),
            torch.zeros(1, part_dur_frames, dtype=torch.bool, device=device),
        ),
        dim=-1,
    )
    offset_frame = end_frame

# Append remaining mel frames after last edit
mel_cond = torch.cat((mel_cond, original_mel[:, offset_frame:, :]), dim=1)
edit_mask = F.pad(edit_mask, (0, mel_cond.shape[1] - edit_mask.shape[-1]), value=True)

# Text
text_list = [target_text]
if tokenizer == "pinyin":
    final_text_list = convert_char_to_pinyin(text_list)
else:
    final_text_list = text_list  # list[str]; the model tokenizes each string by char (or byte)
print(f"text  : {text_list}")
print(f"pinyin: {final_text_list}")

# Duration - use mel_cond length (not raw audio length)
duration = mel_cond.shape[1]

# Inference - pass mel_cond directly (not wav)
with torch.inference_mode():
    generated, trajectory = model.sample(
        cond=mel_cond,  # mel frames (batch, n_frames, n_mel), not a raw waveform
        text=final_text_list,
        duration=duration,
        steps=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        seed=seed,
        edit_mask=edit_mask,
    )
    print(f"Generated mel: {generated.shape}")

    # Final result
    generated = generated.to(torch.float32)
    gen_mel_spec = generated.permute(0, 2, 1)
    if mel_spec_type == "vocos":
        generated_wave = vocoder.decode(gen_mel_spec).cpu()
    elif mel_spec_type == "bigvgan":
        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

    if rms < target_rms:
        generated_wave = generated_wave * rms / target_rms

    save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
    print(f"Generated wav: {generated_wave.shape}")
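The frame-level bookkeeping in the `parts_to_edit` loop of `speech_edit.py` above can be sketched without tensors. This is a minimal illustration, not code from the repository: `seconds_to_frames` and `build_edit_mask` are hypothetical helper names, and plain Python lists stand in for the `mel_cond`/`edit_mask` tensors. It assumes the same conversion the script uses (frames = seconds × sample rate / hop length, rounded with Python's `round`), with `True` marking original frames to keep and `False` marking placeholder frames the model regenerates.

```python
from __future__ import annotations


def seconds_to_frames(t_sec: float, sample_rate: int = 24000, hop_length: int = 256) -> int:
    # Each mel frame advances by hop_length samples, so frames = seconds * sr / hop
    return round(t_sec * sample_rate / hop_length)


def build_edit_mask(
    num_frames: int,
    parts_to_edit: list[list[float]],
    fix_duration: list[float] | None = None,
    sample_rate: int = 24000,
    hop_length: int = 256,
) -> list[bool]:
    """Hypothetical helper: True = keep original mel frame, False = frame to regenerate."""
    # Copy so the caller's list is not consumed (mirrors fix_duration.copy() in the script)
    durations = list(fix_duration) if fix_duration is not None else None
    mask: list[bool] = []
    offset = 0  # first original frame not yet copied into the output
    for start, end in parts_to_edit:
        # New segment length: the original span, or the fixed duration if one is given
        dur_sec = end - start if durations is None else durations.pop(0)
        start_f = seconds_to_frames(start, sample_rate, hop_length)
        end_f = seconds_to_frames(end, sample_rate, hop_length)
        mask += [True] * (start_f - offset)  # original frames before this edit
        mask += [False] * seconds_to_frames(dur_sec, sample_rate, hop_length)  # placeholders
        offset = end_f  # skip the original frames being replaced
    mask += [True] * max(num_frames - offset, 0)  # original frames after the last edit
    return mask


if __name__ == "__main__":
    # 24 kHz with hop 240 gives exactly 100 frames per second (illustrative values only)
    m = build_edit_mask(500, [[1.0, 2.0]], fix_duration=[0.5], sample_rate=24000, hop_length=240)
    print(len(m), m.count(False))  # 450 50
```

With the illustrative 100 frames per second, editing `[[1.0, 2.0]]` in a 500-frame clip with `fix_duration=[0.5]` keeps frames 0 to 99, inserts 50 placeholder frames in place of the original 100, and keeps the remaining 300 original frames. As in the script, the output length (and hence `duration`) changes whenever a fixed duration differs from the span it replaces.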


================================================
FILE: src/f5_tts/infer/utils_infer.py
================================================
# Shared utilities for the inference process
# Make adjustments inside functions; if you change a function's output format, update both the Gradio and CLI scripts
import os
import sys
from concurrent.futures import ThreadPoolExecutor


os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")

import hashlib
import re
import tempfile
from importlib.resources import files

import matplotlib


matplotlib.use("Agg")

import matplotlib.pylab as plt
import numpy as np
import torch
import torchaudio
import tqdm
from huggingface_hub import hf_hub_download
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos

from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer


_ref_audio_cache = {}
_ref_text_cache = {}

device = (
    "cuda"
    if torch.cuda.is_available()
    else "xpu"
    if torch.xpu.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}

# -----------------------------------------

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32  # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None

# -----------------------------------------


# chunk text into smaller pieces


def chunk_text(text, max_chars=135):
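    """
    Splits the input text into chunks whose UTF-8 encoded size does not exceed a limit.

    Args:
        text (str): The text to be split.
        max_chars (int): The maximum chunk size, measured in UTF-8 bytes (a CJK character counts as 3).
            A single sentence longer than this limit is kept whole as its own chunk.

    Returns:
        List[str]: A list of text chunks.
    """
    chunks = []
    current_chunk = ""
    # Split into sentences after ASCII punctuation followed by whitespace,
    # or directly after full-width (CJK) punctuation, which is usually not followed by a space
    sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[；：，。！？])", text)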

    for sentence in sentences:
        if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
            current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks


# load vocoder
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
    if vocoder_name == "vocos":
        # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
        if is_local:
            print(f"Load vocos from local path {local_path}")
            config_path = f"{local_path}/config.yaml"
            model_path = f"{local_path}/pytorch_model.bin"
        else:
            print("Download Vocos from huggingface charactr/vocos-mel-24khz")
            repo_id = "charactr/vocos-mel-24khz"
            config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
            model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
        vocoder = Vocos.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        from vocos.feature_extractors import EncodecFeatures

        if isinstance(vocoder.feature_extractor, EncodecFeatures):
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in vocoder.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        vocoder.load_state_dict(state_dict)
        vocoder = vocoder.eval().to(device)
    elif vocoder_name == "bigvgan":
        try:
            from third_party.BigVGAN import bigvgan
        except ImportError as e:
            raise ImportError(
                "You need to follow the README to init submodule and change the BigVGAN source code."
            ) from e
        if is_local:
            # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
            vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
        else:
            vocoder = bigvgan.BigVGAN.from_pretrained(
                "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir
            )

        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(device)
    return vocoder


# load asr pipeline

asr_pipe = None


def initialize_asr_pipeline(device: str = device, dtype=None):
    if dtype is None:
        dtype = (
            torch.float16
            if "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 7
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
            else torch.float32
        )
    global asr_pipe
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=dtype,
        device=device,
    )


# transcribe


def transcribe(ref_audio, language=None):
    global asr_pipe
    if asr_pipe is None:
        initialize_asr_pipeline(device=device)
    return asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
        return_timestamps=False,
    )["text"].strip()


# load model checkpoint for inference


def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
    if dtype is None:
        dtype = (
            torch.float16
            if "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 7
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
            else torch.float32
        )
    model = model.to(dtype)

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path, device=device)
    else:
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

    if use_ema:
        if ckpt_type == "safetensors":
            checkpoint = {"ema_model_state_dict": checkpoint}
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }

        # patch for backward compati
SYMBOL INDEX (390 symbols across 40 files)

FILE: src/f5_tts/api.py
  class F5TTS (line 23) | class F5TTS:
    method __init__ (line 24) | def __init__(
    method transcribe (line 86) | def transcribe(self, ref_audio, language=None):
    method export_wav (line 89) | def export_wav(self, wav, file_wave, remove_silence=False):
    method export_spectrogram (line 95) | def export_spectrogram(self, spec, file_spec):
    method infer (line 98) | def infer(

FILE: src/f5_tts/eval/ecapa_tdnn.py
  class Res2Conv1dReluBn (line 17) | class Res2Conv1dReluBn(nn.Module):
    method __init__ (line 22) | def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilat...
    method forward (line 37) | def forward(self, x):
  class Conv1dReluBn (line 60) | class Conv1dReluBn(nn.Module):
    method __init__ (line 61) | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,...
    method forward (line 66) | def forward(self, x):
  class SE_Connect (line 74) | class SE_Connect(nn.Module):
    method __init__ (line 75) | def __init__(self, channels, se_bottleneck_dim=128):
    method forward (line 80) | def forward(self, x):
  class SE_Res2Block (line 101) | class SE_Res2Block(nn.Module):
    method __init__ (line 102) | def __init__(self, in_channels, out_channels, kernel_size, stride, pad...
    method forward (line 117) | def forward(self, x):
  class AttentiveStatsPool (line 134) | class AttentiveStatsPool(nn.Module):
    method __init__ (line 135) | def __init__(self, in_dim, attention_channels=128, global_context_att=...
    method forward (line 146) | def forward(self, x):
  class ECAPA_TDNN (line 164) | class ECAPA_TDNN(nn.Module):
    method __init__ (line 165) | def __init__(
    method get_feat_num (line 260) | def get_feat_num(self):
    method get_feat (line 271) | def get_feat(self, x):
    method forward (line 297) | def forward(self, x):
  function ECAPA_TDNN_SMALL (line 313) | def ECAPA_TDNN_SMALL(

FILE: src/f5_tts/eval/eval_infer_batch.py
  function main (line 39) | def main():

FILE: src/f5_tts/eval/eval_librispeech_test_clean.py
  function get_args (line 23) | def get_args():
  function parse_gpu_nums (line 36) | def parse_gpu_nums(gpu_nums_str):
  function main (line 49) | def main():

FILE: src/f5_tts/eval/eval_seedtts_testset.py
  function get_args (line 23) | def get_args():
  function parse_gpu_nums (line 35) | def parse_gpu_nums(gpu_nums_str):
  function main (line 48) | def main():

FILE: src/f5_tts/eval/eval_utmos.py
  function main (line 10) | def main():

FILE: src/f5_tts/eval/utils_eval.py
  function get_seedtts_testset_metainfo (line 18) | def get_seedtts_testset_metainfo(metalst):
  function get_librispeech_test_clean_metainfo (line 36) | def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_...
  function padded_mel_batch (line 58) | def padded_mel_batch(ref_mels):
  function get_inference_prompt (line 72) | def get_inference_prompt(
  function get_seed_tts_test (line 212) | def get_seed_tts_test(metalst, gen_wav_dir, gpus):
  function get_librispeech_test (line 247) | def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_cl...
  function load_asr_model (line 284) | def load_asr_model(lang, ckpt_dir=""):
  function run_asr_wer (line 306) | def run_asr_wer(args):
  function run_sim (line 380) | def run_sim(args):

FILE: src/f5_tts/infer/infer_cli.py
  function main (line 307) | def main():

FILE: src/f5_tts/infer/infer_gradio.py
  function gpu_decorator (line 31) | def gpu_decorator(func):
  function load_f5tts (line 65) | def load_f5tts():
  function load_e2tts (line 71) | def load_e2tts():
  function load_custom (line 77) | def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
  function chat_model_inference (line 99) | def chat_model_inference(messages, model, tokenizer):
  function load_text_from_file (line 122) | def load_text_from_file(file):
  function infer (line 133) | def infer(
  function collapse_accordion (line 272) | def collapse_accordion():
  function basic_tts (line 287) | def basic_tts(
  function parse_speechtypes_text (line 349) | def parse_speechtypes_text(gen_text):
  function add_speech_type_fn (line 501) | def add_speech_type_fn():
  function delete_speech_type_fn (line 514) | def delete_speech_type_fn():
  function make_insert_speech_type_fn (line 541) | def make_insert_speech_type_fn(index):
  function generate_multistyle_speech (line 611) | def generate_multistyle_speech(
  function validate_speech_types (line 707) | def validate_speech_types(gen_text, regular_name, *args):
  function load_chat_model (line 757) | def load_chat_model(chat_model_name):
  function process_audio_input (line 851) | def process_audio_input(conv_state, audio_path, text):
  function generate_text_response (line 867) | def generate_text_response(conv_state, system_prompt):
  function generate_audio_response (line 881) | def generate_audio_response(conv_state, ref_audio, ref_text, remove_sile...
  function clear_conversation (line 906) | def clear_conversation():
  function load_last_used_custom (line 980) | def load_last_used_custom():
  function switch_tts_model (line 991) | def switch_tts_model(new_choice):
  function set_custom_model (line 1005) | def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_c...
  function main (line 1121) | def main(port, host, share, api, root_path, inbrowser):

FILE: src/f5_tts/infer/utils_infer.py
  function chunk_text (line 73) | def chunk_text(text, max_chars=135):
  function load_vocoder (line 104) | def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", de...
  function initialize_asr_pipeline (line 151) | def initialize_asr_pipeline(device: str = device, dtype=None):
  function transcribe (line 172) | def transcribe(ref_audio, language=None):
  function load_checkpoint (line 188) | def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=T...
  function load_model (line 236) | def load_model(
  function remove_silence_edges (line 277) | def remove_silence_edges(audio, silence_threshold=-42):
  function preprocess_ref_audio_text (line 296) | def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
  function infer_process (line 382) | def infer_process(
  function infer_batch_process (line 433) | def infer_batch_process(
  function remove_silence_for_generated_wav (line 585) | def remove_silence_for_generated_wav(filename):
  function save_spectrogram (line 600) | def save_spectrogram(spectrogram, path):

FILE: src/f5_tts/model/backbones/dit.py
  class TextEmbedding (line 31) | class TextEmbedding(nn.Module):
    method __init__ (line 32) | def __init__(
    method average_upsample_text_by_mask (line 53) | def average_upsample_text_by_mask(self, text, text_mask, target_lens):
    method forward (line 84) | def forward(self, text: int["b nt"], seq_len, drop_text=False):
  class InputEmbedding (line 143) | class InputEmbedding(nn.Module):
    method __init__ (line 144) | def __init__(self, mel_dim, text_dim, out_dim):
    method forward (line 149) | def forward(
  class DiT (line 168) | class DiT(nn.Module):
    method __init__ (line 169) | def __init__(
    method initialize_weights (line 236) | def initialize_weights(self):
    method ckpt_wrapper (line 248) | def ckpt_wrapper(self, module):
    method get_input_embed (line 256) | def get_input_embed(
    method clear_cache (line 288) | def clear_cache(self):
    method forward (line 291) | def forward(

FILE: src/f5_tts/model/backbones/mmdit.py
  class TextEmbedding (line 30) | class TextEmbedding(nn.Module):
    method __init__ (line 31) | def __init__(self, out_dim, text_num_embeds, mask_padding=True):
    method forward (line 40) | def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:
  class AudioEmbedding (line 67) | class AudioEmbedding(nn.Module):
    method __init__ (line 68) | def __init__(self, in_dim, out_dim):
    method forward (line 73) | def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_...
  class MMDiT (line 85) | class MMDiT(nn.Module):
    method __init__ (line 86) | def __init__(
    method initialize_weights (line 138) | def initialize_weights(self):
    method ckpt_wrapper (line 152) | def ckpt_wrapper(self, module):
    method get_input_embed (line 159) | def get_input_embed(
    method clear_cache (line 183) | def clear_cache(self):
    method forward (line 186) | def forward(

FILE: src/f5_tts/model/backbones/unett.py
  class TextEmbedding (line 36) | class TextEmbedding(nn.Module):
    method __init__ (line 37) | def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_...
    method forward (line 53) | def forward(self, text: int["b nt"], seq_len, drop_text=False):
  class InputEmbedding (line 89) | class InputEmbedding(nn.Module):
    method __init__ (line 90) | def __init__(self, mel_dim, text_dim, out_dim):
    method forward (line 95) | def forward(self, x: float["b n d"], cond: float["b n d"], text_embed:...
  class UNetT (line 107) | class UNetT(nn.Module):
    method __init__ (line 108) | def __init__(
    method get_input_embed (line 188) | def get_input_embed(
    method clear_cache (line 214) | def clear_cache(self):
    method forward (line 217) | def forward(

FILE: src/f5_tts/model/cfm.py
  class CFM (line 34) | class CFM(nn.Module):
    method __init__ (line 35) | def __init__(
    method device (line 80) | def device(self):
    method sample (line 84) | def sample(
    method forward (line 231) | def forward(

FILE: src/f5_tts/model/dataset.py
  class HFDataset (line 17) | class HFDataset(Dataset):
    method __init__ (line 18) | def __init__(
    method get_frame_len (line 41) | def get_frame_len(self, index):
    method __len__ (line 47) | def __len__(self):
    method __getitem__ (line 50) | def __getitem__(self, index):
  class CustomDataset (line 82) | class CustomDataset(Dataset):
    method __init__ (line 83) | def __init__(
    method get_frame_len (line 118) | def get_frame_len(self, index):
    method __len__ (line 125) | def __len__(self):
    method __getitem__ (line 128) | def __getitem__(self, index):
  class DynamicBatchSampler (line 166) | class DynamicBatchSampler(Sampler[list[int]]):
    method __init__ (line 175) | def __init__(
    method set_epoch (line 220) | def set_epoch(self, epoch: int) -> None:
    method __iter__ (line 224) | def __iter__(self):
    method __len__ (line 236) | def __len__(self):
  function load_dataset (line 243) | def load_dataset(
  function collate_fn (line 309) | def collate_fn(batch):

FILE: src/f5_tts/model/modules.py
  function get_bigvgan_mel_spectrogram (line 34) | def get_bigvgan_mel_spectrogram(
  function get_vocos_mel_spectrogram (line 79) | def get_vocos_mel_spectrogram(
  class MelSpec (line 108) | class MelSpec(nn.Module):
    method __init__ (line 109) | def __init__(
    method forward (line 134) | def forward(self, wav):
  class SinusPositionEmbedding (line 153) | class SinusPositionEmbedding(nn.Module):
    method __init__ (line 154) | def __init__(self, dim):
    method forward (line 158) | def forward(self, x, scale=1000):
  class ConvPositionEmbedding (line 171) | class ConvPositionEmbedding(nn.Module):
    method __init__ (line 172) | def __init__(self, dim, kernel_size=31, groups=16):
    method forward (line 183) | def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
  function precompute_freqs_cis (line 203) | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, the...
  function get_pos_embed_indices (line 217) | def get_pos_embed_indices(start, length, max_pos, scale=1.0):
  class GRN (line 232) | class GRN(nn.Module):
    method __init__ (line 233) | def __init__(self, dim):
    method forward (line 238) | def forward(self, x):
  class ConvNeXtV2Block (line 248) | class ConvNeXtV2Block(nn.Module):
    method __init__ (line 249) | def __init__(
    method forward (line 266) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class RMSNorm (line 282) | class RMSNorm(nn.Module):
    method __init__ (line 283) | def __init__(self, dim: int, eps: float):
    method forward (line 289) | def forward(self, x):
  class AdaLayerNorm (line 308) | class AdaLayerNorm(nn.Module):
    method __init__ (line 309) | def __init__(self, dim):
    method forward (line 317) | def forward(self, x, emb=None):
  class AdaLayerNorm_Final (line 329) | class AdaLayerNorm_Final(nn.Module):
    method __init__ (line 330) | def __init__(self, dim):
    method forward (line 338) | def forward(self, x, emb):
  class FeedForward (line 349) | class FeedForward(nn.Module):
    method __init__ (line 350) | def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate...
    method forward (line 359) | def forward(self, x):
  class Attention (line 367) | class Attention(nn.Module):
    method __init__ (line 368) | def __init__(
    method forward (line 425) | def forward(
  class AttnProcessor (line 447) | class AttnProcessor:
    method __init__ (line 448) | def __init__(
    method __call__ (line 467) | def __call__(
  class JointAttnProcessor (line 559) | class JointAttnProcessor:
    method __init__ (line 560) | def __init__(
    method __call__ (line 577) | def __call__(
  class DiTBlock (line 707) | class DiTBlock(nn.Module):
    method __init__ (line 708) | def __init__(
    method forward (line 739) | def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: ...
  class MMDiTBlock (line 759) | class MMDiTBlock(nn.Module):
    method __init__ (line 769) | def __init__(
    method forward (line 812) | def forward(
  class TimestepEmbedding (line 848) | class TimestepEmbedding(nn.Module):
    method __init__ (line 849) | def __init__(self, dim, freq_embed_dim=256):
    method forward (line 854) | def forward(self, timestep: float["b"]):

FILE: src/f5_tts/model/trainer.py
  class Trainer (line 26) | class Trainer:
    method __init__ (line 27) | def __init__(
    method is_main (line 147) | def is_main(self):
    method save_checkpoint (line 150) | def save_checkpoint(self, update, last=False):
    method load_checkpoint (line 185) | def load_checkpoint(self):
    method train (line 265) | def train(self, train_dataset: Dataset, num_workers=16, resumable_with...

FILE: src/f5_tts/model/utils.py
  function seed_everything (line 19) | def seed_everything(seed=0):
  function exists (line 32) | def exists(v):
  function default (line 36) | def default(v, d):
  function is_package_available (line 40) | def is_package_available(package_name: str) -> bool:
  function lens_to_mask (line 53) | def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:
  function mask_from_start_end_indices (line 61) | def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end:...
  function mask_from_frac_lengths (line 69) | def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):
  function maybe_masked_mean (line 80) | def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> fl...
  function list_str_to_tensor (line 92) | def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:
  function list_str_to_idx (line 99) | def list_str_to_idx(
  function get_tokenizer (line 112) | def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
  function convert_char_to_pinyin (line 148) | def convert_char_to_pinyin(text_list, polyphone=True):
  function repetition_found (line 191) | def repetition_found(text, length=2, tolerance=10):
  function get_epss_timesteps (line 205) | def get_epss_timesteps(n, device, dtype):

FILE: src/f5_tts/runtime/triton_trtllm/benchmark.py
  function get_args (line 64) | def get_args():
  function data_collator (line 120) | def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
  function init_distributed (line 199) | def init_distributed():
  function load_vocoder (line 215) | def load_vocoder(
  class VocosTensorRT (line 249) | class VocosTensorRT:
    method __init__ (line 250) | def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
    method decode (line 260) | def decode(self, mels):
  function main (line 275) | def main():

FILE: src/f5_tts/runtime/triton_trtllm/client_grpc.py
  function write_triton_stats (line 50) | def write_triton_stats(stats, summary_file):
  function get_args (line 106) | def get_args():
  function load_audio (line 213) | def load_audio(wav_path, target_sample_rate=24000):
  function send (line 227) | async def send(
  function load_manifests (line 309) | def load_manifests(manifest_path):
  function split_data (line 330) | def split_data(data, k):
  function main (line 353) | async def main():
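
`client_grpc.py` indexes `split_data(data, k)` next to an async `send` and `main`. This suggests the loaded manifest is divided among k concurrent request tasks. Below is a minimal sketch of such a splitter. It assumes contiguous chunks whose sizes differ by at most one; the actual policy is not visible in this index.

```python
def split_data(data, k):
    """Split `data` into at most k contiguous chunks whose sizes differ by at most one."""
    n = len(data)
    if n == 0:
        return []
    k = min(k, n)  # never produce empty chunks
    quotient, remainder = divmod(n, k)
    chunks, start = [], 0
    for i in range(k):
        # The first `remainder` chunks each take one extra item.
        end = start + quotient + (1 if i < remainder else 0)
        chunks.append(data[start:end])
        start = end
    return chunks


print(split_data(list(range(7)), 3))  # [[0, 1, 2], [3, 4], [5, 6]]
```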

FILE: src/f5_tts/runtime/triton_trtllm/client_http.py
  function get_args (line 34) | def get_args():
  function prepare_request (line 81) | def prepare_request(
  function load_audio (line 109) | def load_audio(wav_path, target_sample_rate=24000):

FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py
  function remove_tensor_padding (line 18) | def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
  class TextEmbedding (line 35) | class TextEmbedding(nn.Module):
    method __init__ (line 36) | def __init__(
    method forward (line 45) | def forward(self, text, seq_len, drop_text=False):
  class GRN (line 68) | class GRN(nn.Module):
    method __init__ (line 69) | def __init__(self, dim):
    method forward (line 74) | def forward(self, x):
  class ConvNeXtV2Block (line 80) | class ConvNeXtV2Block(nn.Module):
    method __init__ (line 81) | def __init__(
    method forward (line 98) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function precompute_freqs_cis (line 111) | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, the...
  function get_text_embed_dict (line 125) | def get_text_embed_dict(ckpt_path, use_ema=True):
  class F5TTS (line 155) | class F5TTS(object):
    method __init__ (line 156) | def __init__(
    method _tensor_dtype (line 263) | def _tensor_dtype(self, name):
    method _setup (line 268) | def _setup(self, batch_size, seq_len):
    method cuda_stream_guard (line 279) | def cuda_stream_guard(func):
    method forward (line 297) | def forward(
    method sample (line 374) | def sample(
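
In `f5_tts_trtllm.py`, `remove_tensor_padding(input_tensor, input_tensor_lengths=None)` reads like a packing step for padding-free (packed) inputs. Each padded row is trimmed to its true length, and the valid parts are concatenated into one sequence. The list-based sketch below illustrates that idea and assumes lengths are always supplied. The real function works on tensors, and its behavior when `input_tensor_lengths` is None is not shown here.

```python
def remove_tensor_padding(batch, lengths):
    """Drop per-row padding and concatenate the valid prefixes into one packed sequence."""
    if len(batch) != len(lengths):
        raise ValueError("need exactly one length per batch row")
    packed = []
    for row, n in zip(batch, lengths):
        packed.extend(row[:n])  # keep only the first n (valid) frames of this row
    return packed


print(remove_tensor_padding([[1, 2, 0], [3, 0, 0]], [2, 1]))  # [1, 2, 3]
```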

FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py
  function get_tokenizer (line 39) | def get_tokenizer(vocab_file_path: str):
  function convert_char_to_pinyin (line 57) | def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
  function list_str_to_idx (line 95) | def list_str_to_idx(
  class TritonPythonModel (line 105) | class TritonPythonModel:
    method initialize (line 106) | def initialize(self, args):
    method get_vocos_mel_spectrogram (line 155) | def get_vocos_mel_spectrogram(self, waveform):
    method forward_vocoder (line 160) | def forward_vocoder(self, mel):
    method execute (line 176) | def execute(self, requests):
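
The Triton `model.py` index pairs `get_tokenizer(vocab_file_path)` with `list_str_to_idx(...)`, and the `vocab.txt` preview shows one symbol per line, starting with a space. The sketch below shows that character-level lookup under four assumptions:
- each line's index is its token id;
- unknown characters fall back to id 0 (the space entry in the example vocab);
- batches are right-padded with -1, the default `padding_value` shown on `list_str_to_tensor` in utils.py;
- `load_vocab` and `texts_to_padded_ids` are hypothetical names.

```python
import io


def load_vocab(lines):
    """Map each vocab entry (one per line) to its 0-based line index."""
    # Strip only the trailing newline: the first entry in vocab.txt is a space character.
    return {line.rstrip("\n"): i for i, line in enumerate(lines)}


def texts_to_padded_ids(texts, vocab_char_map, padding_value=-1):
    """Convert each string to ids (unknown chars -> 0) and right-pad rows to equal length."""
    ids = [[vocab_char_map.get(ch, 0) for ch in text] for text in texts]
    width = max((len(row) for row in ids), default=0)
    return [row + [padding_value] * (width - len(row)) for row in ids]


vocab = load_vocab(io.StringIO(" \n!\na\nb\n"))
print(vocab)  # {' ': 0, '!': 1, 'a': 2, 'b': 3}
print(texts_to_padded_ids(["ab a", "b!"], vocab))  # [[2, 3, 0, 2], [3, 1, -1, -1]]
```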

FILE: src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
  class InputEmbedding (line 33) | class InputEmbedding(Module):
    method __init__ (line 34) | def __init__(self, mel_dim, text_dim, out_dim):
    method forward (line 39) | def forward(self, x, cond, mask=None):
  class F5TTS (line 44) | class F5TTS(PretrainedModel):
    method __init__ (line 45) | def __init__(self, config: PretrainedConfig):
    method forward (line 71) | def forward(
    method prepare_inputs (line 105) | def prepare_inputs(self, **kwargs):

FILE: src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
  class FeedForward (line 37) | class FeedForward(Module):
    method __init__ (line 38) | def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
    method forward (line 46) | def forward(self, x):
  class AdaLayerNormZero (line 50) | class AdaLayerNormZero(Module):
    method __init__ (line 51) | def __init__(self, dim):
    method forward (line 57) | def forward(self, x, emb=None):
  class AdaLayerNormZero_Final (line 69) | class AdaLayerNormZero_Final(Module):
    method __init__ (line 70) | def __init__(self, dim):
    method forward (line 77) | def forward(self, x, emb):
  class ConvPositionEmbedding (line 89) | class ConvPositionEmbedding(Module):
    method __init__ (line 90) | def __init__(self, dim, kernel_size=31, groups=16):
    method forward (line 97) | def forward(self, x, mask=None):
  class Attention (line 117) | class Attention(Module):
    method __init__ (line 118) | def __init__(
    method forward (line 190) | def forward(
  function rotate_every_two_3dim (line 210) | def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
  function apply_rotary_pos_emb_3dim (line 239) | def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
  class AttnProcessor (line 279) | class AttnProcessor:
    method __init__ (line 280) | def __init__(
    method __call__ (line 286) | def __call__(
  class DiTBlock (line 377) | class DiTBlock(Module):
    method __init__ (line 378) | def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_at...
    method forward (line 393) | def forward(
  class TimestepEmbedding (line 423) | class TimestepEmbedding(Module):
    method __init__ (line 424) | def __init__(self, dim, freq_embed_dim=256, dtype=None):
    method forward (line 430) | def forward(self, timestep):
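
`patch/f5tts/modules.py` lists `rotate_every_two_3dim` and `apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head)`. The first name points to the interleaved ("rotate every two", GPT-J style) rotary-embedding convention. Assuming that convention, the per-vector math on a flat list looks like the sketch below. The real functions operate on 3-D TensorRT-LLM tensors and also use `pe_attn_head`, which this sketch ignores.

```python
def rotate_every_two(x):
    """Interleaved RoPE helper: map each pair (x0, x1) to (-x1, x0)."""
    if len(x) % 2:
        raise ValueError("feature dimension must be even")
    out = []
    for i in range(0, len(x), 2):
        out.extend([-x[i + 1], x[i]])
    return out


def apply_rotary(x, cos, sin):
    """Rotate each pair by its angle: x * cos + rotate_every_two(x) * sin, elementwise."""
    rot = rotate_every_two(x)
    return [a * c + r * s for a, r, c, s in zip(x, rot, cos, sin)]


print(rotate_every_two([1, 2, 3, 4]))  # [-2, 1, -4, 3]
# First pair rotated by 90 degrees, second pair by 0 degrees (cos/sin repeated per pair):
print(apply_rotary([1, 2, 3, 4], [0, 0, 1, 1], [1, 1, 0, 0]))  # [-2, 1, 3, 4]
```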

FILE: src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py
  class STFT (line 53) | class STFT(th.nn.Module):
    method __init__ (line 54) | def __init__(
    method __init_kernel__ (line 101) | def __init_kernel__(self):
    method is_perfect (line 146) | def is_perfect(self):
    method transform (line 156) | def transform(self, inputs, return_type="complex"):
    method inverse (line 193) | def inverse(self, input1, input2=None, input_type="magphase"):
    method forward (line 236) | def forward(self, inputs):

FILE: src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py
  function split_q_tp (line 16) | def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank):
  function split_q_bias_tp (line 21) | def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
  function parse_arguments (line 26) | def parse_arguments():
  function convert_pytorch_dit_to_trtllm_weight (line 113) | def convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32",...
  function save_config (line 201) | def save_config(args):
  function covert_and_save (line 236) | def covert_and_save(args, rank):
  function execute (line 253) | def execute(workers, func, args):
  function main (line 270) | def main():

FILE: src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py
  function get_args (line 27) | def get_args():
  class ISTFTHead (line 45) | class ISTFTHead(nn.Module):
    method __init__ (line 46) | def __init__(self, n_fft: int, hop_length: int):
    method forward (line 51) | def forward(self, x: torch.Tensor):
  class VocosVocoder (line 62) | class VocosVocoder(nn.Module):
    method __init__ (line 63) | def __init__(self, vocos_vocoder):
    method forward (line 73) | def forward(self, mel):
  function export_VocosVocoder (line 78) | def export_VocosVocoder(vocos_vocoder, output_path, verbose):
  function load_vocoder (line 111) | def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", de...

FILE: src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py
  function main (line 6) | def main(file_path, substitutions, in_place, participant_ids):
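
`fill_template.py` imports `string.Template` (per its preview). Its `main(file_path, substitutions, in_place, participant_ids)` signature suggests that placeholders in a `config.pbtxt` are filled from a substitutions string. The sketch below assumes a comma-separated `key:value` format and `safe_substitute` semantics. The function name and placeholder names are hypothetical.

```python
from string import Template


def fill_template_text(template_text, substitutions):
    """Apply "key:value,key2:value2" substitutions to ${key} placeholders."""
    mapping = {}
    for item in substitutions.split(","):
        key, value = item.split(":", 1)  # split once so values may contain ':'
        mapping[key.strip()] = value.strip()
    # safe_substitute leaves unknown ${placeholders} untouched instead of raising KeyError.
    return Template(template_text).safe_substitute(mapping)


pbtxt = 'max_batch_size: ${max_batch}\nparameters { key: "model_dir" value: "${model_dir}" }\n'
print(fill_template_text(pbtxt, "max_batch:4,model_dir:/models/f5"))
```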

FILE: src/f5_tts/socket_client.py
  function listen_to_F5TTS (line 14) | async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):

FILE: src/f5_tts/socket_server.py
  class AudioFileWriterThread (line 32) | class AudioFileWriterThread(threading.Thread):
    method __init__ (line 35) | def __init__(self, output_file, sampling_rate):
    method run (line 43) | def run(self):
    method add_chunk (line 61) | def add_chunk(self, chunk):
    method stop (line 65) | def stop(self):
  class TTSStreamingProcessor (line 72) | class TTSStreamingProcessor:
    method __init__ (line 73) | def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, ...
    method load_ema_model (line 97) | def load_ema_model(self, ckpt_file, vocab_file, dtype):
    method load_vocoder_model (line 109) | def load_vocoder_model(self):
    method update_reference (line 112) | def update_reference(self, ref_audio, ref_text):
    method _warm_up (line 122) | def _warm_up(self):
    method generate_stream (line 138) | def generate_stream(self, text, conn):
  function handle_client (line 180) | def handle_client(conn, processor):
  function start_server (line 203) | def start_server(host, port, processor):
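
`socket_server.py` pairs a streaming processor with an `AudioFileWriterThread(output_file, sampling_rate)` that exposes `run`, `add_chunk`, and `stop`. This suggests generated chunks are written to disk off the main thread. The stdlib-only sketch below shows that producer/consumer pattern, assuming float samples in [-1, 1], mono output, and a queue with a stop sentinel. The class name and WAV details are hypothetical, since the original's internals are not shown.

```python
import os
import queue
import struct
import tempfile
import threading
import wave


class WavWriterThread(threading.Thread):
    """Hypothetical background writer: drains float chunks from a queue into a mono 16-bit WAV."""

    def __init__(self, output_file, sampling_rate):
        super().__init__(daemon=True)
        self.output_file = output_file
        self.sampling_rate = sampling_rate
        self.queue = queue.Queue()
        self._sentinel = object()  # unique marker that tells run() to finish

    def add_chunk(self, chunk):
        self.queue.put(chunk)

    def stop(self):
        # Queue the sentinel behind any pending chunks, then wait until the file is closed.
        self.queue.put(self._sentinel)
        self.join()

    def run(self):
        with wave.open(self.output_file, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
            wf.setframerate(self.sampling_rate)
            while True:
                chunk = self.queue.get()
                if chunk is self._sentinel:
                    break
                # Clip to [-1, 1], scale to int16, and pack little-endian as WAV requires.
                pcm = [int(max(-1.0, min(1.0, s)) * 32767) for s in chunk]
                wf.writeframes(struct.pack(f"<{len(pcm)}h", *pcm))


path = os.path.join(tempfile.mkdtemp(), "stream_demo.wav")
writer = WavWriterThread(path, sampling_rate=24000)
writer.start()
writer.add_chunk([0.0, 0.5, -0.5])
writer.add_chunk([1.0, 2.0])  # 2.0 is clipped to 1.0
writer.stop()
with wave.open(path, "rb") as wf:
    print(wf.getnframes())  # 5
```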

FILE: src/f5_tts/train/datasets/prepare_csv_wavs.py
  function is_csv_wavs_format (line 50) | def is_csv_wavs_format(input_path):
  function graceful_exit (line 56) | def graceful_exit():
  function process_audio_file (line 77) | def process_audio_file(audio_path, text, polyphone):
  function batch_convert_texts (line 92) | def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
  function prepare_csv_wavs_dir (line 106) | def prepare_csv_wavs_dir(input_path, num_workers=None):
  function get_audio_duration (line 172) | def get_audio_duration(audio_path, timeout=5):
  function read_audio_text_pairs (line 209) | def read_audio_text_pairs(csv_file_path):
  function save_prepped_dataset (line 235) | def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set,...
  function prepare_and_save_set (line 267) | def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num...
  function get_args (line 274) | def get_args():
  function cli (line 287) | def cli():
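
`prepare_csv_wavs.py` indexes `batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE)`. Together with `convert_char_to_pinyin(text_list, polyphone=True)` in utils.py, this suggests transcripts are converted in fixed-size batches rather than in one call over the whole dataset. The sketch below shows that batching pattern under stated assumptions: `convert_in_batches`, the stand-in converter, and the batch size of 100 are hypothetical, and the real `BATCH_SIZE` value is not shown here.

```python
def convert_in_batches(texts, convert_fn, batch_size=100):
    """Run a list-in/list-out converter over fixed-size slices, preserving order."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    converted = []
    for start in range(0, len(texts), batch_size):
        converted.extend(convert_fn(texts[start : start + batch_size]))
    return converted


calls = []  # records the size of every batch the converter receives


def fake_convert(batch):
    # Stand-in for a real list-level text converter (e.g. a pinyin conversion step).
    calls.append(len(batch))
    return [t.upper() for t in batch]


print(convert_in_batches(["a", "b", "c", "d", "e"], fake_convert, batch_size=2))  # ['A', 'B', 'C', 'D', 'E']
print(calls)  # [2, 2, 1]
```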

FILE: src/f5_tts/train/datasets/prepare_emilia.py
  function deal_with_audio_dir (line 111) | def deal_with_audio_dir(audio_dir):
  function main (line 147) | def main():

FILE: src/f5_tts/train/datasets/prepare_emilia_v2.py
  function process_audio_directory (line 21) | def process_audio_directory(audio_dir):
  function main (line 44) | def main():

FILE: src/f5_tts/train/datasets/prepare_libritts.py
  function deal_with_audio_dir (line 17) | def deal_with_audio_dir(audio_dir):
  function main (line 34) | def main():

FILE: src/f5_tts/train/datasets/prepare_ljspeech.py
  function main (line 16) | def main():

FILE: src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
  function deal_with_sub_path_files (line 21) | def deal_with_sub_path_files(dataset_path, sub_path):
  function main (line 49) | def main():

FILE: src/f5_tts/train/finetune_cli.py
  function parse_args (line 23) | def parse_args():
  function main (line 81) | def main():

FILE: src/f5_tts/train/finetune_gradio.py
  function save_settings (line 61) | def save_settings(
  function load_settings (line 114) | def load_settings(project_name):
  function get_audio_duration (line 175) | def get_audio_duration(audio_path):
  class Slicer (line 181) | class Slicer:  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/...
    method __init__ (line 182) | def __init__(
    method _apply_slice (line 203) | def _apply_slice(self, waveform, begin, end):
    method slice (line 210) | def slice(self, waveform):
  function terminate_process_tree (line 298) | def terminate_process_tree(pid, including_parent=True):
  function terminate_process (line 318) | def terminate_process(pid):
  function start_training (line 326) | def start_training(
  function stop_training (line 584) | def stop_training():
  function get_list_projects (line 595) | def get_list_projects():
  function create_data_project (line 611) | def create_data_project(name, tokenizer_type):
  function transcribe_all (line 619) | def transcribe_all(name_project, audio_files, language, user=False, prog...
  function format_seconds_to_hms (line 690) | def format_seconds_to_hms(seconds):
  function get_correct_audio_path (line 697) | def get_correct_audio_path(
  function create_metadata (line 728) | def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
  function check_user (line 839) | def check_user(value):
  function calculate_train (line 843) | def calculate_train(
  function prune_checkpoint (line 932) | def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, sav...
  function expand_model_embeddings (line 957) | def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
  function vocab_count (line 997) | def vocab_count(text):
  function vocab_extend (line 1001) | def vocab_extend(project_name, symbols, model_type):
  function vocab_check (line 1064) | def vocab_check(project_name, tokenizer_type):
  function get_random_sample_prepare (line 1111) | def get_random_sample_prepare(project_name):
  function get_random_sample_transcribe (line 1124) | def get_random_sample_transcribe(project_name):
  function get_random_sample_infer (line 1153) | def get_random_sample_infer(project_name):
  function infer (line 1162) | def infer(
  function check_finetune (line 1210) | def check_finetune(finetune):
  function get_checkpoints_project (line 1214) | def get_checkpoints_project(project_name, is_gradio=True):
  function get_audio_project (line 1248) | def get_audio_project(project_name, is_gradio=True):
  function get_gpu_stats (line 1269) | def get_gpu_stats():
  function get_cpu_stats (line 1323) | def get_cpu_stats():
  function get_combined_stats (line 1343) | def get_combined_stats():
  function get_audio_select (line 1350) | def get_audio_select(file_sample):
  function setup_load_settings (line 1716) | def setup_load_settings():
  function update_stats (line 1836) | def update_stats():
  function auto_update (line 1842) | def auto_update():
  function main (line 1859) | def main(port, host, share, api):
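
`finetune_gradio.py` indexes `format_seconds_to_hms(seconds)`, a helper for human-readable durations. Below is a minimal sketch of such a formatter. It assumes fractional seconds are truncated and hours are not wrapped into days; the repository's exact formatting may differ.

```python
def format_seconds_to_hms(seconds):
    """Format a duration in seconds as HH:MM:SS (sketch; fractional seconds are truncated)."""
    total = int(seconds)
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


print(format_seconds_to_hms(3725.9))  # 01:02:05
```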

FILE: src/f5_tts/train/train.py
  function main (line 18) | def main(model_cfg):
Condensed preview: 90 files, each showing path, character count, and a content snippet (the full structured content is 617K chars).
[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.yml",
    "chars": 2053,
    "preview": "name: \"Bug Report\"\ndescription: |\n  Please provide as much details to help address the issue more efficiently, including"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "chars": 28,
    "preview": "blank_issues_enabled: false\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.yml",
    "chars": 2479,
    "preview": "name: \"Feature Request\"\ndescription: |\n  Some constructive suggestions and new ideas regarding current repo.\nlabels:\n  -"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/help_wanted.yml",
    "chars": 2421,
    "preview": "name: \"Help Wanted\"\ndescription: |\n  Please provide as much details to help address the issue more efficiently, includin"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/question.yml",
    "chars": 1006,
    "preview": "name: \"Question\"\ndescription: |\n  Research question or pure inquiry about the project, usage issue goes with \"help wante"
  },
  {
    "path": ".github/workflows/pre-commit.yaml",
    "chars": 237,
    "preview": "name: pre-commit\n\non:\n  pull_request:\n  push:\n    branches: [main]\n\njobs:\n  pre-commit:\n    runs-on: ubuntu-latest\n    s"
  },
  {
    "path": ".github/workflows/publish-docker-image.yaml",
    "chars": 3268,
    "preview": "name: Create and publish a Docker image\r\n\r\n# Configures this workflow to run every time a change is pushed to the branch"
  },
  {
    "path": ".github/workflows/publish-pypi.yaml",
    "chars": 1725,
    "preview": "# This workflow uses actions that are not certified by GitHub.\n# They are provided by a third-party and are governed by\n"
  },
  {
    "path": ".gitignore",
    "chars": 3202,
    "preview": "# Customed\n.vscode/\ntests/\nruns/\ndata/\nckpts/\nwandb/\nresults/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py"
  },
  {
    "path": ".gitmodules",
    "chars": 115,
    "preview": "[submodule \"src/third_party/BigVGAN\"]\n\tpath = src/third_party/BigVGAN\n\turl = https://github.com/NVIDIA/BigVGAN.git\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 413,
    "preview": "repos:\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    # Ruff version.\n    rev: v0.11.2\n    hooks:\n      - id"
  },
  {
    "path": "Dockerfile",
    "chars": 865,
    "preview": "FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel\n\nUSER root\n\nARG DEBIAN_FRONTEND=noninteractive\n\nLABEL github_repo=\"http"
  },
  {
    "path": "LICENSE",
    "chars": 1068,
    "preview": "MIT License\n\nCopyright (c) 2024 Yushen CHEN\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
  },
  {
    "path": "README.md",
    "chars": 9727,
    "preview": "# F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching\n\n[![python](https://img.shields.io/badge"
  },
  {
    "path": "pyproject.toml",
    "chars": 1561,
    "preview": "[build-system]\nrequires = [\"setuptools >= 61.0\", \"setuptools-scm>=8.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[projec"
  },
  {
    "path": "ruff.toml",
    "chars": 198,
    "preview": "line-length = 120\ntarget-version = \"py310\"\n\n[lint]\n# Only ignore variables with names starting with \"_\".\ndummy-variable-"
  },
  {
    "path": "src/f5_tts/api.py",
    "chars": 4986,
    "preview": "import random\nimport sys\nfrom importlib.resources import files\n\nimport soundfile as sf\nimport tqdm\nfrom cached_path impo"
  },
  {
    "path": "src/f5_tts/configs/E2TTS_Base.yaml",
    "chars": 2001,
    "preview": "hydra:\n  run:\n    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-"
  },
  {
    "path": "src/f5_tts/configs/E2TTS_Small.yaml",
    "chars": 1927,
    "preview": "hydra:\n  run:\n    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-"
  },
  {
    "path": "src/f5_tts/configs/F5TTS_Base.yaml",
    "chars": 2236,
    "preview": "hydra:\n  run:\n    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-"
  },
  {
    "path": "src/f5_tts/configs/F5TTS_Small.yaml",
    "chars": 2269,
    "preview": "hydra:\n  run:\n    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-"
  },
  {
    "path": "src/f5_tts/configs/F5TTS_v1_Base.yaml",
    "chars": 2278,
    "preview": "hydra:\n  run:\n    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-"
  },
  {
    "path": "src/f5_tts/eval/README.md",
    "chars": 2698,
    "preview": "\n# Evaluation\n\nInstall packages for evaluation:\n\n```bash\npip install -e .[eval]\n```\n\n> [!IMPORTANT]\n> For [faster-whispe"
  },
  {
    "path": "src/f5_tts/eval/ecapa_tdnn.py",
    "chars": 11366,
    "preview": "# just for speaker similarity evaluation, third-party code\n\n# From https://github.com/microsoft/UniSpeech/blob/main/down"
  },
  {
    "path": "src/f5_tts/eval/eval_infer_batch.py",
    "chars": 8044,
    "preview": "import os\nimport sys\n\n\nsys.path.append(os.getcwd())\n\nimport argparse\nimport time\nfrom importlib.resources import files\n\n"
  },
  {
    "path": "src/f5_tts/eval/eval_infer_batch.sh",
    "chars": 4274,
    "preview": "#!/bin/bash\nset -e\nexport PYTHONWARNINGS=\"ignore::UserWarning,ignore::FutureWarning\"\n\n# Configuration parameters\nMODEL_N"
  },
  {
    "path": "src/f5_tts/eval/eval_infer_batch_example.sh",
    "chars": 1394,
    "preview": "#!/bin/bash\n\n# e.g. F5-TTS, 16 NFE\naccelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n \"F5TTS_v1_Base\" -t \"see"
  },
  {
    "path": "src/f5_tts/eval/eval_librispeech_test_clean.py",
    "chars": 3928,
    "preview": "# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)\n\nimpor"
  },
  {
    "path": "src/f5_tts/eval/eval_seedtts_testset.py",
    "chars": 3663,
    "preview": "# Evaluate with Seed-TTS testset\n\nimport argparse\nimport ast\nimport json\nimport os\nimport sys\n\n\nsys.path.append(os.getcw"
  },
  {
    "path": "src/f5_tts/eval/eval_utmos.py",
    "chars": 1559,
    "preview": "import argparse\nimport json\nfrom pathlib import Path\n\nimport librosa\nimport torch\nfrom tqdm import tqdm\n\n\ndef main():\n  "
  },
  {
    "path": "src/f5_tts/eval/utils_eval.py",
    "chars": 14289,
    "preview": "import math\nimport os\nimport random\nimport string\nfrom pathlib import Path\n\nimport torch\nimport torch.nn.functional as F"
  },
  {
    "path": "src/f5_tts/infer/README.md",
    "chars": 6897,
    "preview": "# Inference\n\nThe pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) a"
  },
  {
    "path": "src/f5_tts/infer/SHARED.md",
    "chars": 10011,
    "preview": "<!-- omit in toc -->\n# Shared Model Cards\n\n<!-- omit in toc -->\n### **Prerequisites of using**\n- This document is servin"
  },
  {
    "path": "src/f5_tts/infer/examples/basic/basic.toml",
    "chars": 558,
    "preview": "# F5TTS_v1_Base | E2TTS_Base\nmodel = \"F5TTS_v1_Base\"\nref_audio = \"infer/examples/basic/basic_ref_en.wav\"\n# If an empty \""
  },
  {
    "path": "src/f5_tts/infer/examples/multi/story.toml",
    "chars": 562,
    "preview": "# F5TTS_v1_Base | E2TTS_Base\nmodel = \"F5TTS_v1_Base\"\nref_audio = \"infer/examples/multi/main.flac\"\n# If an empty \"\", tran"
  },
  {
    "path": "src/f5_tts/infer/examples/multi/story.txt",
    "chars": 1390,
    "preview": "A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see hi"
  },
  {
    "path": "src/f5_tts/infer/examples/vocab.txt",
    "chars": 9330,
    "preview": " \n!\n\"\n#\n$\n%\n&\n'\n(\n)\n*\n+\n,\n-\n.\n/\n0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n:\n;\n=\n>\n?\n@\nA\nB\nC\nD\nE\nF\nG\nH\nI\nJ\nK\nL\nM\nN\nO\nP\nQ\nR\nS\nT\nU\nV\nW\nX\nY\nZ\n[\n\\\n"
  },
  {
    "path": "src/f5_tts/infer/infer_cli.py",
    "chars": 11581,
    "preview": "import argparse\nimport codecs\nimport os\nimport re\nfrom datetime import datetime\nfrom importlib.resources import files\nfr"
  },
  {
    "path": "src/f5_tts/infer/infer_gradio.py",
    "chars": 41898,
    "preview": "# ruff: noqa: E402\n# Above allows ruff to ignore E402: module level import not at top of file\n\nimport gc\nimport json\nimp"
  },
  {
    "path": "src/f5_tts/infer/speech_edit.py",
    "chars": 8098,
    "preview": "import os\n\n\nos.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"  # for MPS device compatibility\n\nfrom importlib.resources im"
  },
  {
    "path": "src/f5_tts/infer/utils_infer.py",
    "chars": 20069,
    "preview": "# A unified script for inference process\n# Make adjustments inside functions, and consider both gradio and cli scripts i"
  },
  {
    "path": "src/f5_tts/model/__init__.py",
    "chars": 267,
    "preview": "from f5_tts.model.backbones.dit import DiT\nfrom f5_tts.model.backbones.mmdit import MMDiT\nfrom f5_tts.model.backbones.un"
  },
  {
    "path": "src/f5_tts/model/backbones/README.md",
    "chars": 708,
    "preview": "## Backbones quick introduction\n\n\n### unett.py\n- flat unet transformer\n- structure same as in e2-tts & voicebox paper ex"
  },
  {
    "path": "src/f5_tts/model/backbones/dit.py",
    "chars": 12043,
    "preview": "\"\"\"\nein notation:\nb - batch\nn - sequence\nnt - text sequence\nnw - raw wave length\nd - dimension\n\"\"\"\n# ruff: noqa: F722 F8"
  },
  {
    "path": "src/f5_tts/model/backbones/mmdit.py",
    "chars": 7775,
    "preview": "\"\"\"\nein notation:\nb - batch\nn - sequence\nnt - text sequence\nnw - raw wave length\nd - dimension\n\"\"\"\n# ruff: noqa: F722 F8"
  },
  {
    "path": "src/f5_tts/model/backbones/unett.py",
    "chars": 9390,
    "preview": "\"\"\"\nein notation:\nb - batch\nn - sequence\nnt - text sequence\nnw - raw wave length\nd - dimension\n\"\"\"\n# ruff: noqa: F722 F8"
  },
  {
    "path": "src/f5_tts/model/cfm.py",
    "chars": 9756,
    "preview": "\"\"\"\nein notation:\nb - batch\nn - sequence\nnt - text sequence\nnw - raw wave length\nd - dimension\n\"\"\"\n# ruff: noqa: F722 F8"
  },
  {
    "path": "src/f5_tts/model/dataset.py",
    "chars": 11046,
    "preview": "import json\nfrom importlib.resources import files\n\nimport torch\nimport torch.nn.functional as F\nimport torchaudio\nfrom d"
  },
  {
    "path": "src/f5_tts/model/modules.py",
    "chars": 29669,
    "preview": "\"\"\"\nein notation:\nb - batch\nn - sequence\nnt - text sequence\nnw - raw wave length\nd - dimension\n\"\"\"\n# ruff: noqa: F722 F8"
  },
  {
    "path": "src/f5_tts/model/trainer.py",
    "chars": 20585,
    "preview": "from __future__ import annotations\n\nimport gc\nimport math\nimport os\n\nimport torch\nimport torchaudio\nimport wandb\nfrom ac"
  },
  {
    "path": "src/f5_tts/model/utils.py",
    "chars": 7063,
    "preview": "# ruff: noqa: F722 F821\n\nfrom __future__ import annotations\n\nimport os\nimport random\nfrom collections import defaultdict"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/.gitignore",
    "chars": 56,
    "preview": "# runtime/triton_trtllm related\nmodel.cache\nmodel_repo/\n"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/Dockerfile.server",
    "chars": 165,
    "preview": "FROM nvcr.io/nvidia/tritonserver:24.12-py3\nRUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 rji"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/README.md",
    "chars": 2746,
    "preview": "## Triton Inference Serving Best Practice for F5-TTS\n\n### Setup\n#### Option 1: Quick Start\n```sh\n# Directly launch the s"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/benchmark.py",
    "chars": 18027,
    "preview": "# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)\n#               2025                (authors: Yuekai Zhang)"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/client_grpc.py",
    "chars": 16352,
    "preview": "#!/usr/bin/env python3\n# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)\n#                2023  Nvidia"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/client_http.py",
    "chars": 5184,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/docker-compose.yml",
    "chars": 582,
    "preview": "services:\n  tts:\n    image: soar97/triton-f5-tts:24.12\n    shm_size: '1gb'\n    ports:\n      - \"8000:8000\"\n      - \"8001:"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py",
    "chars": 20029,
    "preview": "import math\nimport os\nimport time\nfrom functools import wraps\nfrom typing import List, Optional\n\nimport tensorrt as trt\n"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py",
    "chars": 11778,
    "preview": "# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Redistribution and use in source and binary "
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt",
    "chars": 1590,
    "preview": "# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt",
    "chars": 412,
    "preview": "name: \"vocoder\"\nbackend: \"tensorrt\"\ndefault_model_filename: \"vocoder.plan\"\nmax_batch_size: 4\n\ninput [\n  {\n    name: \"mel"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/patch/__init__.py",
    "chars": 7142,
    "preview": "# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-I"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py",
    "chars": 8661,
    "preview": "from __future__ import annotations\n\nimport os\nimport sys\nfrom collections import OrderedDict\n\nimport numpy as np\nimport "
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py",
    "chars": 15363,
    "preview": "from __future__ import annotations\n\nimport math\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nimport torc"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/run.sh",
    "chars": 4437,
    "preview": "stage=$1\nstop_stage=$2\nmodel=$3  # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small\nif [ -z \"$model\" ]; then\n  "
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py",
    "chars": 10136,
    "preview": "# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py\n\n# Copyright (c) 2024, NVIDIA "
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py",
    "chars": 11227,
    "preview": "import argparse\nimport json\nimport os\nimport re\nimport time\nimport traceback\nfrom concurrent.futures import ThreadPoolEx"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py",
    "chars": 4713,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the "
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh",
    "chars": 1443,
    "preview": "#!/bin/bash\n# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Versi"
  },
  {
    "path": "src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py",
    "chars": 1120,
    "preview": "#! /usr/bin/env python3\nfrom argparse import ArgumentParser\nfrom string import Template\n\n\ndef main(file_path, substituti"
  },
  {
    "path": "src/f5_tts/scripts/count_max_epoch.py",
    "chars": 1105,
    "preview": "\"\"\"ADAPTIVE BATCH SIZE\"\"\"\n\nprint(\"Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in\")\nprint"
  },
  {
    "path": "src/f5_tts/scripts/count_max_epoch_precise.py",
    "chars": 832,
    "preview": "import math\n\nfrom torch.utils.data import SequentialSampler\n\nfrom f5_tts.model.dataset import DynamicBatchSampler, load_"
  },
  {
    "path": "src/f5_tts/scripts/count_params_gflops.py",
    "chars": 1359,
    "preview": "import os\nimport sys\n\n\nsys.path.append(os.getcwd())\n\nimport thop\nimport torch\n\nfrom f5_tts.model import CFM, DiT\n\n\n\"\"\" ~"
  },
  {
    "path": "src/f5_tts/socket_client.py",
    "chars": 2019,
    "preview": "import asyncio\nimport logging\nimport socket\nimport time\n\nimport numpy as np\nimport pyaudio\n\n\nlogging.basicConfig(level=l"
  },
  {
    "path": "src/f5_tts/socket_server.py",
    "chars": 9070,
    "preview": "import argparse\nimport gc\nimport logging\nimport queue\nimport socket\nimport struct\nimport threading\nimport traceback\nimpo"
  },
  {
    "path": "src/f5_tts/train/README.md",
    "chars": 3390,
    "preview": "# Training\n\nCheck your FFmpeg installation:\n```bash\nffmpeg -version\n```\nIf not found, install it first (or skip assuming"
  },
  {
    "path": "src/f5_tts/train/datasets/prepare_csv_wavs.py",
    "chars": 10987,
    "preview": "\"\"\"\nUsage:\n    python prepare_csv_wavs.py /path/to/metadata.csv /output/dataset/path [--pretrain] [--workers N]\n\nCSV for"
  },
  {
    "path": "src/f5_tts/train/datasets/prepare_emilia.py",
    "chars": 7414,
    "preview": "# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07\n# if use updated new version, i.e."
  },
  {
    "path": "src/f5_tts/train/datasets/prepare_emilia_v2.py",
    "chars": 3240,
    "preview": "# put in src/f5_tts/train/datasets/prepare_emilia_v2.py\n# prepares Emilia dataset with the new format w/ Emilia-YODAS\n\ni"
  },
  {
    "path": "src/f5_tts/train/datasets/prepare_libritts.py",
    "chars": 3181,
    "preview": "import os\nimport sys\n\n\nsys.path.append(os.getcwd())\n\nimport json\nfrom concurrent.futures import ProcessPoolExecutor\nfrom"
  },
  {
    "path": "src/f5_tts/train/datasets/prepare_ljspeech.py",
    "chars": 2276,
    "preview": "import os\nimport sys\n\n\nsys.path.append(os.getcwd())\n\nimport json\nfrom importlib.resources import files\nfrom pathlib impo"
  },
  {
    "path": "src/f5_tts/train/datasets/prepare_wenetspeech4tts.py",
    "chars": 4581,
    "preview": "# generate audio text map for WenetSpeech4TTS\n# evaluate for vocab size\n\nimport os\nimport sys\n\n\nsys.path.append(os.getcw"
  },
  {
    "path": "src/f5_tts/train/finetune_cli.py",
    "chars": 7593,
    "preview": "import argparse\nimport os\nimport shutil\nfrom importlib.resources import files\n\nfrom cached_path import cached_path\n\nfrom"
  },
  {
    "path": "src/f5_tts/train/finetune_gradio.py",
    "chars": 67656,
    "preview": "import gc\nimport json\nimport os\nimport platform\nimport queue\nimport random\nimport re\nimport shutil\nimport signal\nimport "
  },
  {
    "path": "src/f5_tts/train/train.py",
    "chars": 3126,
    "preview": "# training script.\n\nimport os\nfrom importlib.resources import files\n\nimport hydra\nfrom omegaconf import OmegaConf\n\nfrom "
  }
]

// ... and 3 more files (not shown in this preview)

About this extraction

This page contains the full source code of the SWivid/F5-TTS GitHub repository, extracted and formatted as plain text. The extraction includes 90 files (571.3 KB), approximately 152.7k tokens, and a symbol index with 390 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract. Built by Nikandr Surkov.