Repository: SWivid/F5-TTS Branch: main Commit: 623c96c29496 Files: 90 Total size: 571.3 KB Directory structure: gitextract_4jvakfwh/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.yml │ │ ├── config.yml │ │ ├── feature_request.yml │ │ ├── help_wanted.yml │ │ └── question.yml │ └── workflows/ │ ├── pre-commit.yaml │ ├── publish-docker-image.yaml │ └── publish-pypi.yaml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── README.md ├── pyproject.toml ├── ruff.toml └── src/ └── f5_tts/ ├── api.py ├── configs/ │ ├── E2TTS_Base.yaml │ ├── E2TTS_Small.yaml │ ├── F5TTS_Base.yaml │ ├── F5TTS_Small.yaml │ └── F5TTS_v1_Base.yaml ├── eval/ │ ├── README.md │ ├── ecapa_tdnn.py │ ├── eval_infer_batch.py │ ├── eval_infer_batch.sh │ ├── eval_infer_batch_example.sh │ ├── eval_librispeech_test_clean.py │ ├── eval_seedtts_testset.py │ ├── eval_utmos.py │ └── utils_eval.py ├── infer/ │ ├── README.md │ ├── SHARED.md │ ├── examples/ │ │ ├── basic/ │ │ │ └── basic.toml │ │ ├── multi/ │ │ │ ├── country.flac │ │ │ ├── main.flac │ │ │ ├── story.toml │ │ │ ├── story.txt │ │ │ └── town.flac │ │ └── vocab.txt │ ├── infer_cli.py │ ├── infer_gradio.py │ ├── speech_edit.py │ └── utils_infer.py ├── model/ │ ├── __init__.py │ ├── backbones/ │ │ ├── README.md │ │ ├── dit.py │ │ ├── mmdit.py │ │ └── unett.py │ ├── cfm.py │ ├── dataset.py │ ├── modules.py │ ├── trainer.py │ └── utils.py ├── runtime/ │ └── triton_trtllm/ │ ├── .gitignore │ ├── Dockerfile.server │ ├── README.md │ ├── benchmark.py │ ├── client_grpc.py │ ├── client_http.py │ ├── docker-compose.yml │ ├── model_repo_f5_tts/ │ │ ├── f5_tts/ │ │ │ ├── 1/ │ │ │ │ ├── f5_tts_trtllm.py │ │ │ │ └── model.py │ │ │ └── config.pbtxt │ │ └── vocoder/ │ │ ├── 1/ │ │ │ └── .gitkeep │ │ └── config.pbtxt │ ├── patch/ │ │ ├── __init__.py │ │ └── f5tts/ │ │ ├── model.py │ │ └── modules.py │ ├── run.sh │ └── scripts/ │ ├── conv_stft.py │ ├── convert_checkpoint.py │ ├── export_vocoder_to_onnx.py │ ├── 
export_vocos_trt.sh │ └── fill_template.py ├── scripts/ │ ├── count_max_epoch.py │ ├── count_max_epoch_precise.py │ └── count_params_gflops.py ├── socket_client.py ├── socket_server.py └── train/ ├── README.md ├── datasets/ │ ├── prepare_csv_wavs.py │ ├── prepare_emilia.py │ ├── prepare_emilia_v2.py │ ├── prepare_libritts.py │ ├── prepare_ljspeech.py │ └── prepare_wenetspeech4tts.py ├── finetune_cli.py ├── finetune_gradio.py └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.yml ================================================ name: "Bug Report" description: | Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots. labels: - bug body: - type: checkboxes attributes: label: Checks description: "To ensure timely help, please confirm the following:" options: - label: This template is only for bug reports, usage problems go with 'Help Wanted'. required: true - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem. required: true - label: I have searched for existing issues, including closed ones, and couldn't find a solution. required: true - label: I am using English to submit this issue to facilitate community communication. required: true - type: textarea attributes: label: Environment Details description: "Provide details including OS, GPU info, Python version, any relevant software or dependencies, and trainer setting." placeholder: e.g., CentOS Linux 7, 4 * RTX 3090, Python 3.10, torch==2.3.0+cu118, cuda 11.8, config yaml is ... validations: required: true - type: textarea attributes: label: Steps to Reproduce description: | Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks. placeholder: | 1. Create a new conda environment. 2. 
Clone the repository, install as local editable and properly set up. 3. Run the command: `accelerate launch src/f5_tts/train/train.py`. 4. Have following error message... (attach logs). validations: required: true - type: textarea attributes: label: ✔️ Expected Behavior placeholder: Describe in detail what you expected to happen. validations: required: false - type: textarea attributes: label: ❌ Actual Behavior placeholder: Describe in detail what actually happened. validations: required: false ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: false ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.yml ================================================ name: "Feature Request" description: | Some constructive suggestions and new ideas regarding current repo. labels: - enhancement body: - type: checkboxes attributes: label: Checks description: "To help us grasp quickly, please confirm the following:" options: - label: This template is only for feature request. required: true - label: I have thoroughly reviewed the project documentation but couldn't find any relevant information that meets my needs. required: true - label: I have searched for existing issues, including closed ones, and found not discussion yet. required: true - label: I am using English to submit this issue to facilitate community communication. required: true - type: textarea attributes: label: 1. Is this request related to a challenge you're experiencing? Tell us your story. description: | Describe the specific problem or scenario you're facing in detail. For example: *"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because...."* placeholder: Please describe the situation in as much detail as possible. validations: required: true - type: textarea attributes: label: 2. What is your suggested solution? 
description: | Provide a clear description of the feature or enhancement you'd like to propose. How would this feature solve your issue or improve the project? placeholder: Describe your idea or proposed solution here. validations: required: true - type: textarea attributes: label: 3. Additional context or comments description: | Any other relevant information, links, documents, or screenshots that provide clarity. Use this section for anything not covered above. placeholder: Add any extra details here. validations: required: false - type: checkboxes attributes: label: 4. Can you help us with this feature? description: | Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration. options: - label: I am interested in contributing to this feature. required: false - type: markdown attributes: value: | **Note:** Please submit only one request per issue to keep discussions focused and manageable. ================================================ FILE: .github/ISSUE_TEMPLATE/help_wanted.yml ================================================ name: "Help Wanted" description: | Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots. labels: - help wanted body: - type: checkboxes attributes: label: Checks description: "To ensure timely help, please confirm the following:" options: - label: This template is only for usage issues encountered. required: true - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem. required: true - label: I have searched for existing issues, including closed ones, and couldn't find a solution. required: true - label: I am using English to submit this issue to facilitate community communication. required: true - type: textarea attributes: label: Environment Details description: "Provide details such as OS, Python version, and any relevant software or dependencies." 
placeholder: | e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1 If training or finetuning related, provide detailed configuration including GPU info and training setup. validations: required: true - type: textarea attributes: label: Steps to Reproduce description: | Include detailed steps, screenshots, and logs. Provide used prompt wav and text. Use the correct markdown syntax for code blocks. placeholder: | 1. Create a new conda environment. 2. Clone the repository and install as pip package. 3. Run the command: `f5-tts_infer-gradio` with no ref_text provided. 4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c). 5. Prompt & generated wavs are [change suffix to .mp4 to enable direct upload or pack all to .zip]. 6. Reference audio's transcription or provided ref_text is `xxx`, and text to generate is `xxx`. validations: required: true - type: textarea attributes: label: ✔️ Expected Behavior placeholder: Describe what you expected to happen in detail, e.g. output a generated audio. validations: required: false - type: textarea attributes: label: ❌ Actual Behavior placeholder: Describe what actually happened in detail, failure messages, etc. validations: required: false ================================================ FILE: .github/ISSUE_TEMPLATE/question.yml ================================================ name: "Question" description: | Research question or pure inquiry about the project, usage issue goes with "help wanted". labels: - question body: - type: checkboxes attributes: label: Checks description: "To help us grasp quickly, please confirm the following:" options: - label: This template is only for research question, not usage problems, feature requests or bug reports. required: true - label: I have thoroughly reviewed the project documentation and read the related paper(s). required: true - label: I have searched for existing issues, including closed ones, no similar questions. 
required: true - label: I am using English to submit this issue to facilitate community communication. required: true - type: textarea attributes: label: Question details description: | Question details, clearly stated using proper markdown syntax. validations: required: true ================================================ FILE: .github/workflows/pre-commit.yaml ================================================ name: pre-commit on: pull_request: push: branches: [main] jobs: pre-commit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 - uses: pre-commit/action@v3.0.1 ================================================ FILE: .github/workflows/publish-docker-image.yaml ================================================ name: Create and publish a Docker image # Configures this workflow to run every time a change is pushed to the branch called `release`. on: push: branches: ['main'] # Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds. env: REGISTRY: ghcr.io IMAGE_NAME: ${{ github.repository }} # There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu. jobs: build-and-push-image: runs-on: ubuntu-latest # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. 
permissions: contents: read packages: write # steps: - name: Checkout repository uses: actions/checkout@v4 - name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧 uses: jlumbroso/free-disk-space@main with: # This might remove tools that are actually needed, if set to "true" but frees about 6 GB tool-cache: false # All of these default to true, but feel free to set to "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: false swap-storage: false docker-images: false # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here. - name: Log in to the Container registry uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels. - name: Extract metadata (tags, labels) for Docker id: meta uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages. # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository. # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step. 
- name: Build and push Docker image uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 with: context: . push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} ================================================ FILE: .github/workflows/publish-pypi.yaml ================================================ # This workflow uses actions that are not certified by GitHub. # They are provided by a third-party and are governed by # separate terms of service, privacy policy, and support # documentation. # GitHub recommends pinning actions to a commit SHA. # To get a newer version, you will need to update the SHA. # You can also reference a tag or branch, but the action may change without warning. name: Upload Python Package on: release: types: [published] permissions: contents: read jobs: release-build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: "3.x" - name: Build release distributions run: | # NOTE: put your own distribution build steps here. python -m pip install build python -m build - name: Upload distributions uses: actions/upload-artifact@v4 with: name: release-dists path: dist/ pypi-publish: runs-on: ubuntu-latest needs: - release-build permissions: # IMPORTANT: this permission is mandatory for trusted publishing id-token: write # Dedicated environments with protections for publishing are strongly recommended. 
environment: name: pypi # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status: # url: https://pypi.org/p/YOURPROJECT steps: - name: Retrieve release distributions uses: actions/download-artifact@v4 with: name: release-dists path: dist/ - name: Publish release distributions to PyPI uses: pypa/gh-action-pypi-publish@release/v1 ================================================ FILE: .gitignore ================================================ # Customed .vscode/ tests/ runs/ data/ ckpts/ wandb/ results/ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
# However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml .pdm-python .pdm-build/ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/ ================================================ FILE: .gitmodules ================================================ [submodule "src/third_party/BigVGAN"] path = src/third_party/BigVGAN url = https://github.com/NVIDIA/BigVGAN.git ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.11.2 hooks: - id: ruff name: ruff linter args: [--fix] - id: ruff-format name: ruff formatter - id: ruff name: ruff sorter args: [--select, I, --fix] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: - id: check-yaml ================================================ FILE: Dockerfile ================================================ FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel USER root ARG DEBIAN_FRONTEND=noninteractive LABEL github_repo="https://github.com/SWivid/F5-TTS" RUN set -x \ && apt-get update \ && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \ && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \ && apt-get install -y librdmacm1 libibumad3 librdmacm-dev libibverbs1 libibverbs-dev ibverbs-utils ibverbs-providers \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean WORKDIR /workspace RUN git clone https://github.com/SWivid/F5-TTS.git \ && cd F5-TTS \ && git submodule update --init --recursive \ && pip install -e . 
--no-cache-dir ENV SHELL=/bin/bash VOLUME /root/.cache/huggingface/hub/ EXPOSE 7860 WORKDIR /workspace/F5-TTS ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Yushen CHEN Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS) [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885) [![demo](https://img.shields.io/badge/GitHub-Demo-orange.svg)](https://swivid.github.io/F5-TTS/) [![hfspace](https://img.shields.io/badge/🤗-HF%20Space-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS) [![msspace](https://img.shields.io/badge/🤖-MS%20Space-blue)](https://modelscope.cn/studios/AI-ModelScope/E2-F5-TTS) [![lab](https://img.shields.io/badge/🏫-X--LANCE-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/) [![lab](https://img.shields.io/badge/🏫-SII-grey?labelColor=lightgrey)](https://www.sii.edu.cn/) [![lab](https://img.shields.io/badge/🏫-PCL-grey?labelColor=lightgrey)](https://www.pcl.ac.cn) **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference. **E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009). **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance ### Thanks to all the contributors ! ## News - **2025/03/12**: 🔥 F5-TTS v1 base model with better training and inference performance. [Few demo](https://swivid.github.io/F5-TTS_updates). - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN). 
## Installation ### Create a separate environment if needed ```bash # Create a conda env with python_version>=3.10 (you could also use virtualenv) conda create -n f5-tts python=3.11 conda activate f5-tts # Install FFmpeg if you haven't yet conda install ffmpeg ``` ### Install PyTorch with matched device
NVIDIA GPU > ```bash > # Install pytorch with your CUDA version, e.g. > pip install torch==2.8.0+cu128 torchaudio==2.8.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128 > > # And also possible previous versions, e.g. > pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124 > # etc. > ```
AMD GPU > ```bash > # Install pytorch with your ROCm version (Linux only), e.g. > pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2 > ```
Intel GPU > ```bash > # Install pytorch with your XPU version, e.g. > # Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed > pip install torch torchaudio --index-url https://download.pytorch.org/whl/test/xpu > > # Intel GPU support is also available through IPEX (Intel® Extension for PyTorch) > # IPEX does not require the Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit > # See: https://pytorch-extension.intel.com/installation?request=platform > ```
Apple Silicon > ```bash > # Install the stable pytorch, e.g. > pip install torch torchaudio > ```
### Then you can choose one from below: > ### 1. As a pip package (if just for inference) > > ```bash > pip install f5-tts > ``` > > ### 2. Local editable (if also do training, finetuning) > > ```bash > git clone https://github.com/SWivid/F5-TTS.git > cd F5-TTS > # git submodule update --init --recursive # (optional, if use bigvgan as vocoder) > pip install -e . > ``` ### Docker usage also available ```bash # Build from Dockerfile docker build -t f5tts:v1 . # Run from GitHub Container Registry docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main # Quickstart if you want to just run the web interface (not CLI) docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0 ``` ### Runtime Deployment solution with Triton and TensorRT-LLM. #### Benchmark Results Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE. | Model | Concurrency | Avg Latency | RTF | Mode | |---------------------|----------------|-------------|--------|-----------------| | F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server | | F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM | | F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch | See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information. ## Inference - In order to achieve desired performance, take a moment to read [detailed guidance](src/f5_tts/infer). - By properly searching the keywords of problem encountered, [issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very helpful. ### 1. 
Gradio App Currently supported features: - Basic TTS with Chunk Inference - Multi-Style / Multi-Speaker Generation - Voice Chat powered by Qwen2.5-3B-Instruct - [Custom inference with more language support](src/f5_tts/infer/SHARED.md) ```bash # Launch a Gradio app (web interface) f5-tts_infer-gradio # Specify the port/host f5-tts_infer-gradio --port 7860 --host 0.0.0.0 # Launch a share link f5-tts_infer-gradio --share ```
NVIDIA device docker compose file example ```yaml services: f5-tts: image: ghcr.io/swivid/f5-tts:main ports: - "7860:7860" environment: GRADIO_SERVER_PORT: 7860 entrypoint: ["f5-tts_infer-gradio", "--port", "7860", "--host", "0.0.0.0"] deploy: resources: reservations: devices: - driver: nvidia count: 1 capabilities: [gpu] volumes: f5-tts: driver: local ```
### 2. CLI Inference ```bash # Run with flags # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage) f5-tts_infer-cli --model F5TTS_v1_Base \ --ref_audio "provide_prompt_wav_path_here.wav" \ --ref_text "The content, subtitle or transcription of reference audio." \ --gen_text "Some text you want TTS model generate for you." # Run with default setting. src/f5_tts/infer/examples/basic/basic.toml f5-tts_infer-cli # Or with your own .toml file f5-tts_infer-cli -c custom.toml # Multi voice. See src/f5_tts/infer/README.md f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml ``` ## Training ### 1. With Hugging Face Accelerate Refer to [training & finetuning guidance](src/f5_tts/train) for best practice. ### 2. With Gradio App ```bash # Quick start with Gradio web interface f5-tts_finetune-gradio ``` Read [training & finetuning guidance](src/f5_tts/train) for more instructions. ## [Evaluation](src/f5_tts/eval) ## Development Use pre-commit to ensure code quality (will run linters and formatters automatically): ```bash pip install pre-commit pre-commit install ``` When making a pull request, before each commit, run: ```bash pre-commit run --all-files ``` Note: Some model components have linting exceptions for E722 to accommodate tensor notation. 
## Acknowledgements - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~ - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman) - [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ) - [Yuekai Zhang](https://github.com/yuekaizhang) Triton and TensorRT-LLM support ~ ## Citation If our work and codebase is useful for you, please cite as: ``` @article{chen-etal-2024-f5tts, title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching}, author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen}, journal={arXiv preprint arXiv:2410.06885}, year={2024}, } ``` ## License Our code is released under MIT License. 
The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause. ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"] build-backend = "setuptools.build_meta" [project] name = "f5-tts" version = "1.1.17" description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching" readme = "README.md" license = {text = "MIT License"} classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", ] dependencies = [ "accelerate>=0.33.0", "bitsandbytes>0.37.0; platform_machine!='arm64' and platform_system!='Darwin'", "cached_path", "click", "datasets", "ema_pytorch>=0.5.2", "gradio>=6.0.0", "hydra-core>=1.3.0", "librosa", "matplotlib", "numpy<=1.26.4; python_version<='3.10'", "pydub", "pypinyin", "rjieba", "safetensors", "soundfile", "tomli", "torch>=2.0.0", "torchaudio>=2.0.0", "torchcodec", "torchdiffeq", "tqdm>=4.65.0", "transformers", "transformers_stream_generator", "unidecode", "vocos", "wandb", "x_transformers>=1.31.14", ] [project.optional-dependencies] eval = [ "faster_whisper==0.10.1", "funasr", "jiwer", "modelscope", "zhconv", "zhon", ] [project.urls] Homepage = "https://github.com/SWivid/F5-TTS" [project.scripts] "f5-tts_infer-cli" = "f5_tts.infer.infer_cli:main" "f5-tts_infer-gradio" = "f5_tts.infer.infer_gradio:main" "f5-tts_finetune-cli" = "f5_tts.train.finetune_cli:main" "f5-tts_finetune-gradio" = "f5_tts.train.finetune_gradio:main" ================================================ FILE: ruff.toml ================================================ line-length = 120 target-version = "py310" [lint] # Only ignore variables with names starting with "_". 
dummy-variable-rgx = "^_.*$" [lint.isort] force-single-line = false lines-after-imports = 2 ================================================ FILE: src/f5_tts/api.py ================================================ import random import sys from importlib.resources import files import soundfile as sf import tqdm from cached_path import cached_path from hydra.utils import get_class from omegaconf import OmegaConf from f5_tts.infer.utils_infer import ( infer_process, load_model, load_vocoder, preprocess_ref_audio_text, remove_silence_for_generated_wav, save_spectrogram, transcribe, ) from f5_tts.model.utils import seed_everything class F5TTS: def __init__( self, model="F5TTS_v1_Base", ckpt_file="", vocab_file="", ode_method="euler", use_ema=True, vocoder_local_path=None, device=None, hf_cache_dir=None, ): model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate self.ode_method = ode_method self.use_ema = use_ema if device is not None: self.device = device else: import torch self.device = ( "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) # Load models self.vocoder = load_vocoder( self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir ) repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors" # override for previous models if model == "F5TTS_Base": if self.mel_spec_type == "vocos": ckpt_step = 1200000 elif self.mel_spec_type == "bigvgan": model = "F5TTS_Base_bigvgan" ckpt_type = "pt" elif model == "E2TTS_Base": repo_name = "E2-TTS" ckpt_step = 1200000 if not ckpt_file: ckpt_file = str( cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", 
cache_dir=hf_cache_dir) ) self.ema_model = load_model( model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device ) def transcribe(self, ref_audio, language=None): return transcribe(ref_audio, language) def export_wav(self, wav, file_wave, remove_silence=False): sf.write(file_wave, wav, self.target_sample_rate) if remove_silence: remove_silence_for_generated_wav(file_wave) def export_spectrogram(self, spec, file_spec): save_spectrogram(spec, file_spec) def infer( self, ref_file, ref_text, gen_text, show_info=print, progress=tqdm, target_rms=0.1, cross_fade_duration=0.15, sway_sampling_coef=-1, cfg_strength=2, nfe_step=32, speed=1.0, fix_duration=None, remove_silence=False, file_wave=None, file_spec=None, seed=None, ): if seed is None: seed = random.randint(0, sys.maxsize) seed_everything(seed) self.seed = seed ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, show_info=show_info) wav, sr, spec = infer_process( ref_file, ref_text, gen_text, self.ema_model, self.vocoder, self.mel_spec_type, show_info=show_info, progress=progress, target_rms=target_rms, cross_fade_duration=cross_fade_duration, nfe_step=nfe_step, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, speed=speed, fix_duration=fix_duration, device=self.device, ) if file_wave is not None: self.export_wav(wav, file_wave, remove_silence) if file_spec is not None: self.export_spectrogram(spec, file_spec) return wav, sr, spec if __name__ == "__main__": f5tts = F5TTS() wav, sr, spec = f5tts.infer( ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")), ref_text="Some call me nature, others call me mother nature.", gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. 
But always remember, I am mighty and enduring.", file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")), file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")), seed=None, ) print("seed :", f5tts.seed) ================================================ FILE: src/f5_tts/configs/E2TTS_Base.yaml ================================================ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} datasets: name: Emilia_ZH_EN # dataset name batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 batch_size_type: frame # frame | sample max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: epochs: 11 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps max_grad_norm: 1.0 # gradient clipping bnb_optimizer: False # use bnb 8bit AdamW optimizer or not model: name: E2TTS_Base tokenizer: pinyin tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) backbone: UNetT arch: dim: 1024 depth: 24 heads: 16 ff_mult: 4 text_mask_padding: False pe_attn_head: 1 mel_spec: target_sample_rate: 24000 n_mel_channels: 100 hop_length: 256 win_length: 1024 n_fft: 1024 mel_spec_type: vocos # vocos | bigvgan vocoder: is_local: False # use local offline ckpt or not local_path: null # local vocoder path ckpts: logger: wandb # wandb | tensorboard | null wandb_project: CFM-TTS # wandb project name wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. 
wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} ================================================ FILE: src/f5_tts/configs/E2TTS_Small.yaml ================================================ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} datasets: name: Emilia_ZH_EN batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 batch_size_type: frame # frame | sample max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: epochs: 11 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps max_grad_norm: 1.0 bnb_optimizer: False model: name: E2TTS_Small tokenizer: pinyin tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) backbone: UNetT arch: dim: 768 depth: 20 heads: 12 ff_mult: 4 text_mask_padding: False pe_attn_head: 1 mel_spec: target_sample_rate: 24000 n_mel_channels: 100 hop_length: 256 win_length: 1024 n_fft: 1024 mel_spec_type: vocos # vocos | bigvgan vocoder: is_local: False # use local offline ckpt or not local_path: null # local vocoder path ckpts: logger: wandb # wandb | tensorboard | null wandb_project: CFM-TTS # wandb project name wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. 
wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} ================================================ FILE: src/f5_tts/configs/F5TTS_Base.yaml ================================================ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} datasets: name: Emilia_ZH_EN # dataset name batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 batch_size_type: frame # frame | sample max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: epochs: 11 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps max_grad_norm: 1.0 # gradient clipping bnb_optimizer: False # use bnb 8bit AdamW optimizer or not model: name: F5TTS_Base # model name tokenizer: pinyin # tokenizer type tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) backbone: DiT arch: dim: 1024 depth: 22 heads: 16 ff_mult: 2 text_dim: 512 text_mask_padding: False conv_layers: 4 pe_attn_head: 1 attn_backend: torch # torch | flash_attn attn_mask_enabled: False checkpoint_activations: False # recompute activations and save memory for extra compute mel_spec: target_sample_rate: 24000 n_mel_channels: 100 hop_length: 256 win_length: 1024 n_fft: 1024 mel_spec_type: vocos # vocos | bigvgan vocoder: is_local: False # use local offline ckpt or not local_path: null # local vocoder path ckpts: logger: wandb # wandb | tensorboard | null wandb_project: CFM-TTS # wandb project name wandb_run_name: 
${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} ================================================ FILE: src/f5_tts/configs/F5TTS_Small.yaml ================================================ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} datasets: name: Emilia_ZH_EN batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 batch_size_type: frame # frame | sample max_samples: 64 # max sequences per batch if use frame-wise batch_size. 
we set 32 for small models, 64 for base models num_workers: 16 optim: epochs: 11 # only suitable for Emilia, if you want to train it on LibriTTS, set epoch 686 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps max_grad_norm: 1.0 # gradient clipping bnb_optimizer: False # use bnb 8bit AdamW optimizer or not model: name: F5TTS_Small tokenizer: pinyin tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) backbone: DiT arch: dim: 768 depth: 18 heads: 12 ff_mult: 2 text_dim: 512 text_mask_padding: False conv_layers: 4 pe_attn_head: 1 attn_backend: torch # torch | flash_attn attn_mask_enabled: False checkpoint_activations: False # recompute activations and save memory for extra compute mel_spec: target_sample_rate: 24000 n_mel_channels: 100 hop_length: 256 win_length: 1024 n_fft: 1024 mel_spec_type: vocos # vocos | bigvgan vocoder: is_local: False # use local offline ckpt or not local_path: null # local vocoder path ckpts: logger: wandb # wandb | tensorboard | null wandb_project: CFM-TTS # wandb project name wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. 
wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} ================================================ FILE: src/f5_tts/configs/F5TTS_v1_Base.yaml ================================================ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} datasets: name: Emilia_ZH_EN # dataset name batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 batch_size_type: frame # frame | sample max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: epochs: 11 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps max_grad_norm: 1.0 # gradient clipping bnb_optimizer: False # use bnb 8bit AdamW optimizer or not model: name: F5TTS_v1_Base # model name tokenizer: pinyin # tokenizer type tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) backbone: DiT arch: dim: 1024 depth: 22 heads: 16 ff_mult: 2 text_dim: 512 text_mask_padding: True qk_norm: null # null | rms_norm conv_layers: 4 pe_attn_head: null attn_backend: torch # torch | flash_attn attn_mask_enabled: False checkpoint_activations: False # recompute activations and save memory for extra compute mel_spec: target_sample_rate: 24000 n_mel_channels: 100 hop_length: 256 win_length: 1024 n_fft: 1024 mel_spec_type: vocos # vocos | bigvgan vocoder: is_local: False # use local offline ckpt or not local_path: null # local vocoder path ckpts: logger: wandb # wandb | tensorboard | null wandb_project: CFM-TTS # wandb project name 
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} ================================================ FILE: src/f5_tts/eval/README.md ================================================ # Evaluation Install packages for evaluation: ```bash pip install -e .[eval] ``` > [!IMPORTANT] > For [faster-whisper](https://github.com/SYSTRAN/faster-whisper), for various compatibilities: > `pip install ctranslate2==4.5.0` if CUDA 12 and cuDNN 9; > `pip install ctranslate2==4.4.0` if CUDA 12 and cuDNN 8; > `pip install ctranslate2==3.24.0` if CUDA 11 and cuDNN 8. ## Generating Samples for Evaluation ### Prepare Test Datasets 1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval). 2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/). 3. Unzip the downloaded datasets and place them in the `data/` directory. 4. 
Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst` ### Batch Inference for Test Set To run batch inference for evaluations, execute the following commands: ```bash # if accelerate has not been configured yet accelerate config # to perform inference only bash src/f5_tts/eval/eval_infer_batch.sh --infer-only # to perform inference followed by the corresponding evaluation (set up the evaluation tools below first) bash src/f5_tts/eval/eval_infer_batch.sh ``` ## Objective Evaluation on Generated Results ### Download Evaluation Model Checkpoints 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh) 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3) 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view). > [!NOTE] > The ASR models will be downloaded automatically if `--local` is not set for the evaluation scripts. > Otherwise, update the `asr_ckpt_dir` path values in `eval_librispeech_test_clean.py` or `eval_seedtts_testset.py`. > > The WavLM model must be downloaded, and the `wavlm_ckpt_dir` path updated in `eval_librispeech_test_clean.py` and `eval_seedtts_testset.py`. ### Objective Evaluation Examples Update the paths to point to your batch inference results, then run the WER / SIM / UTMOS evaluations: ```bash # Evaluation [WER] for Seed-TTS test [ZH] set python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8 # Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence) python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH> # Evaluation [UTMOS]. --ext: Audio extension python src/f5_tts/eval/eval_utmos.py --audio_dir <AUDIO_DIR> --ext wav ``` > [!NOTE] > Evaluation results can also be found in `_*_results.jsonl` files saved in `<GEN_WAV_DIR>`/`<AUDIO_DIR>`.
================================================ FILE: src/f5_tts/eval/ecapa_tdnn.py ================================================ # just for speaker similarity evaluation, third-party code # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN import os import torch import torch.nn as nn import torch.nn.functional as F """ Res2Conv1d + BatchNorm1d + ReLU """ class Res2Conv1dReluBn(nn.Module): """ in_channels == out_channels == channels """ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4): super().__init__() assert channels % scale == 0, "{} % {} != 0".format(channels, scale) self.scale = scale self.width = channels // scale self.nums = scale if scale == 1 else scale - 1 self.convs = [] self.bns = [] for i in range(self.nums): self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias)) self.bns.append(nn.BatchNorm1d(self.width)) self.convs = nn.ModuleList(self.convs) self.bns = nn.ModuleList(self.bns) def forward(self, x): out = [] spx = torch.split(x, self.width, 1) for i in range(self.nums): if i == 0: sp = spx[i] else: sp = sp + spx[i] # Order: conv -> relu -> bn sp = self.convs[i](sp) sp = self.bns[i](F.relu(sp)) out.append(sp) if self.scale != 1: out.append(spx[self.nums]) out = torch.cat(out, dim=1) return out """ Conv1d + BatchNorm1d + ReLU """ class Conv1dReluBn(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True): super().__init__() self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) self.bn = nn.BatchNorm1d(out_channels) def forward(self, x): return self.bn(F.relu(self.conv(x))) """ The SE connection of 1D case. 
""" class SE_Connect(nn.Module): def __init__(self, channels, se_bottleneck_dim=128): super().__init__() self.linear1 = nn.Linear(channels, se_bottleneck_dim) self.linear2 = nn.Linear(se_bottleneck_dim, channels) def forward(self, x): out = x.mean(dim=2) out = F.relu(self.linear1(out)) out = torch.sigmoid(self.linear2(out)) out = x * out.unsqueeze(2) return out """ SE-Res2Block of the ECAPA-TDNN architecture. """ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale): # return nn.Sequential( # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0), # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale), # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0), # SE_Connect(channels) # ) class SE_Res2Block(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim): super().__init__() self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale) self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0) self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) self.shortcut = None if in_channels != out_channels: self.shortcut = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, ) def forward(self, x): residual = x if self.shortcut: residual = self.shortcut(x) x = self.Conv1dReluBn1(x) x = self.Res2Conv1dReluBn(x) x = self.Conv1dReluBn2(x) x = self.SE_Connect(x) return x + residual """ Attentive weighted mean and standard deviation pooling. """ class AttentiveStatsPool(nn.Module): def __init__(self, in_dim, attention_channels=128, global_context_att=False): super().__init__() self.global_context_att = global_context_att # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs. 
if global_context_att: self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper else: self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper def forward(self, x): if self.global_context_att: context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) x_in = torch.cat((x, context_mean, context_std), dim=1) else: x_in = x # DON'T use ReLU here! In experiments, I find ReLU hard to converge. alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) alpha = torch.softmax(self.linear2(alpha), dim=2) mean = torch.sum(alpha * x, dim=2) residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 std = torch.sqrt(residuals.clamp(min=1e-9)) return torch.cat([mean, std], dim=1) class ECAPA_TDNN(nn.Module): def __init__( self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False, feat_type="wavlm_large", sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None, ): super().__init__() self.feat_type = feat_type self.feature_selection = feature_selection self.update_extract = update_extract self.sr = sr torch.hub._validate_not_a_forked_repo = lambda a, b, c: True try: local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main") self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path) except: # noqa: E722 self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type) if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention" ): self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( 
self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention" ): self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False self.feat_num = self.get_feat_num() self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) if feat_type != "fbank" and feat_type != "mfcc": freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"] for name, param in self.feature_extract.named_parameters(): for freeze_val in freeze_list: if freeze_val in name: param.requires_grad = False break if not self.update_extract: for param in self.feature_extract.parameters(): param.requires_grad = False self.instance_norm = nn.InstanceNorm1d(feat_dim) # self.channels = [channels] * 4 + [channels * 3] self.channels = [channels] * 4 + [1536] self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2) self.layer2 = SE_Res2Block( self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128, ) self.layer3 = SE_Res2Block( self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128, ) self.layer4 = SE_Res2Block( self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128, ) # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1) cat_channels = channels * 3 self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) self.pooling = AttentiveStatsPool( self.channels[-1], attention_channels=128, global_context_att=global_context_att ) self.bn = nn.BatchNorm1d(self.channels[-1] * 2) self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) def get_feat_num(self): self.feature_extract.eval() wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)] with torch.no_grad(): features = self.feature_extract(wav) select_feature = features[self.feature_selection] if isinstance(select_feature, (list, tuple)): return 
len(select_feature) else: return 1 def get_feat(self, x): if self.update_extract: x = self.feature_extract([sample for sample in x]) else: with torch.no_grad(): if self.feat_type == "fbank" or self.feat_type == "mfcc": x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len else: x = self.feature_extract([sample for sample in x]) if self.feat_type == "fbank": x = x.log() if self.feat_type != "fbank" and self.feat_type != "mfcc": x = x[self.feature_selection] if isinstance(x, (list, tuple)): x = torch.stack(x, dim=0) else: x = x.unsqueeze(0) norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) x = (norm_weights * x).sum(dim=0) x = torch.transpose(x, 1, 2) + 1e-6 x = self.instance_norm(x) return x def forward(self, x): x = self.get_feat(x) out1 = self.layer1(x) out2 = self.layer2(out1) out3 = self.layer3(out2) out4 = self.layer4(out3) out = torch.cat([out2, out3, out4], dim=1) out = F.relu(self.conv(out)) out = self.bn(self.pooling(out)) out = self.linear(out) return out def ECAPA_TDNN_SMALL( feat_dim, emb_dim=256, feat_type="wavlm_large", sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None, ): return ECAPA_TDNN( feat_dim=feat_dim, channels=512, emb_dim=emb_dim, feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path, ) ================================================ FILE: src/f5_tts/eval/eval_infer_batch.py ================================================ import os import sys sys.path.append(os.getcwd()) import argparse import time from importlib.resources import files import torch import torchaudio from accelerate import Accelerator from hydra.utils import get_class from omegaconf import OmegaConf from tqdm import tqdm from f5_tts.eval.utils_eval import ( get_inference_prompt, get_librispeech_test_clean_metainfo, get_seedtts_testset_metainfo, ) from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder from 
f5_tts.model import CFM from f5_tts.model.utils import get_tokenizer accelerator = Accelerator() device = f"cuda:{accelerator.process_index}" use_ema = True target_rms = 0.1 rel_path = str(files("f5_tts").joinpath("../../")) def main(): parser = argparse.ArgumentParser(description="batch inference") parser.add_argument("-s", "--seed", default=None, type=int) parser.add_argument("-n", "--expname", required=True) parser.add_argument("-c", "--ckptstep", default=1250000, type=int) parser.add_argument("-nfe", "--nfestep", default=32, type=int) parser.add_argument("-o", "--odemethod", default="euler") parser.add_argument("-ss", "--swaysampling", default=-1, type=float) parser.add_argument("-t", "--testset", required=True) parser.add_argument( "-p", "--librispeech_test_clean_path", default=f"{rel_path}/data/LibriSpeech/test-clean", type=str ) parser.add_argument("--local", action="store_true", help="Use local vocoder checkpoint directory") args = parser.parse_args() seed = args.seed exp_name = args.expname ckpt_step = args.ckptstep nfe_step = args.nfestep ode_method = args.odemethod sway_sampling_coef = args.swaysampling testset = args.testset infer_batch_size = 1 # max frames. 
1 for ddp single inference (recommended) cfg_strength = 2.0 speed = 1.0 use_truth_duration = False no_ref_audio = False model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))) model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch dataset_name = model_cfg.datasets.name tokenizer = model_cfg.model.tokenizer mel_spec_type = model_cfg.model.mel_spec.mel_spec_type target_sample_rate = model_cfg.model.mel_spec.target_sample_rate n_mel_channels = model_cfg.model.mel_spec.n_mel_channels hop_length = model_cfg.model.mel_spec.hop_length win_length = model_cfg.model.mel_spec.win_length n_fft = model_cfg.model.mel_spec.n_fft if testset == "ls_pc_test_clean": metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" librispeech_test_clean_path = args.librispeech_test_clean_path metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path) elif testset == "seedtts_test_zh": metalst = rel_path + "/data/seedtts_testset/zh/meta.lst" metainfo = get_seedtts_testset_metainfo(metalst) elif testset == "seedtts_test_en": metalst = rel_path + "/data/seedtts_testset/en/meta.lst" metainfo = get_seedtts_testset_metainfo(metalst) # path to save genereted wavs output_dir = ( f"{rel_path}/" f"results/{exp_name}_{ckpt_step}/{testset}/" f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}" f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" f"_cfg{cfg_strength}_speed{speed}" f"{'_gt-dur' if use_truth_duration else ''}" f"{'_no-ref-audio' if no_ref_audio else ''}" ) # -------------------------------------------------# prompts_all = get_inference_prompt( metainfo, speed=speed, tokenizer=tokenizer, target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length, mel_spec_type=mel_spec_type, target_rms=target_rms, use_truth_duration=use_truth_duration, infer_batch_size=infer_batch_size, ) # Vocoder model local = args.local if mel_spec_type == 
"vocos": vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" elif mel_spec_type == "bigvgan": vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path) # Tokenizer vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) # Model model = CFM( transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, win_length=win_length, n_mel_channels=n_mel_channels, target_sample_rate=target_sample_rate, mel_spec_type=mel_spec_type, ), odeint_kwargs=dict( method=ode_method, ), vocab_char_map=vocab_char_map, ).to(device) ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}" if os.path.exists(ckpt_prefix + ".pt"): ckpt_path = ckpt_prefix + ".pt" elif os.path.exists(ckpt_prefix + ".safetensors"): ckpt_path = ckpt_prefix + ".safetensors" else: print("Loading from self-organized training checkpoints rather than released pretrained.") ckpt_prefix = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}" if os.path.exists(ckpt_prefix + ".pt"): ckpt_path = ckpt_prefix + ".pt" elif os.path.exists(ckpt_prefix + ".safetensors"): ckpt_path = ckpt_prefix + ".safetensors" else: raise ValueError("The checkpoint does not exist or cannot be found in given location.") dtype = torch.float32 if mel_spec_type == "bigvgan" else None model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) if not os.path.exists(output_dir) and accelerator.is_main_process: os.makedirs(output_dir) # start batch inference accelerator.wait_for_everyone() start = time.time() with accelerator.split_between_processes(prompts_all) as prompts: for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt ref_mels = ref_mels.to(device) ref_mel_lens = torch.tensor(ref_mel_lens, 
dtype=torch.long).to(device) total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) # Inference with torch.inference_mode(): generated, _ = model.sample( cond=ref_mels, text=final_text_list, duration=total_mel_lens, lens=ref_mel_lens, steps=nfe_step, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, no_ref_audio=no_ref_audio, seed=seed, ) # Final result for i, gen in enumerate(generated): gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) if mel_spec_type == "vocos": generated_wave = vocoder.decode(gen_mel_spec).cpu() elif mel_spec_type == "bigvgan": generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() if ref_rms_list[i] < target_rms: generated_wave = generated_wave * ref_rms_list[i] / target_rms torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate) accelerator.wait_for_everyone() if accelerator.is_main_process: timediff = time.time() - start print(f"Done batch inference in {timediff / 60:.2f} minutes.") if __name__ == "__main__": main() ================================================ FILE: src/f5_tts/eval/eval_infer_batch.sh ================================================ #!/bin/bash set -e export PYTHONWARNINGS="ignore::UserWarning,ignore::FutureWarning" # Configuration parameters MODEL_NAME="F5TTS_v1_Base" SEEDS=(0 1 2) CKPTSTEPS=(1250000) TASKS=("seedtts_test_zh" "seedtts_test_en" "ls_pc_test_clean") LS_TEST_CLEAN_PATH="data/LibriSpeech/test-clean" GPUS="[0,1,2,3,4,5,6,7]" OFFLINE_MODE=false # Parse arguments if [ $OFFLINE_MODE = true ]; then LOCAL="--local" else LOCAL="" fi INFER_ONLY=false while [[ $# -gt 0 ]]; do case $1 in --infer-only) INFER_ONLY=true shift ;; *) echo "======== Unknown parameter: $1" exit 1 ;; esac done echo "======== Starting F5-TTS batch evaluation task..." 
if [ "$INFER_ONLY" = true ]; then echo "======== Mode: Execute infer tasks only" else echo "======== Mode: Execute full pipeline (infer + eval)" fi # Function: Execute eval tasks execute_eval_tasks() { local ckptstep=$1 local seed=$2 local task_name=$3 local gen_wav_dir="results/${MODEL_NAME}_${ckptstep}/${task_name}/seed${seed}_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0" echo ">>>>>>>> Starting eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}" case $task_name in "seedtts_test_zh") python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir" ;; "seedtts_test_en") python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir" ;; "ls_pc_test_clean") python src/f5_tts/eval/eval_librispeech_test_clean.py -e wer -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL python src/f5_tts/eval/eval_librispeech_test_clean.py -e sim -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir" ;; esac echo ">>>>>>>> Completed eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}" } # Main execution loop for ckptstep in "${CKPTSTEPS[@]}"; do echo "======== Processing ckptstep: ${ckptstep}" for seed in "${SEEDS[@]}"; do echo "-------- Processing seed: ${seed}" # Store eval task PIDs for current seed (if not infer-only mode) if [ "$INFER_ONLY" = false ]; then declare -a eval_pids fi # Execute each infer task sequentially for task in "${TASKS[@]}"; do echo ">>>>>>>> Executing infer task: accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n \"${MODEL_NAME}\" -t \"${task}\" -c ${ckptstep} $LOCAL" 
# Execute infer task (foreground execution, wait for completion) accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n "${MODEL_NAME}" -t "${task}" -c ${ckptstep} -p "${LS_TEST_CLEAN_PATH}" $LOCAL # If not infer-only mode, launch corresponding eval task if [ "$INFER_ONLY" = false ]; then # Launch corresponding eval task (background execution, non-blocking for next infer) execute_eval_tasks $ckptstep $seed $task & eval_pids+=($!) fi done # If not infer-only mode, wait for all eval tasks of current seed to complete if [ "$INFER_ONLY" = false ]; then echo ">>>>>>>> All infer tasks for seed ${seed} completed, waiting for corresponding eval tasks to finish..." for pid in "${eval_pids[@]}"; do wait $pid done unset eval_pids # Clean up array fi echo "-------- All eval tasks for seed ${seed} completed" done echo "======== Completed ckptstep: ${ckptstep}" echo done echo "======== All tasks completed!" ================================================ FILE: src/f5_tts/eval/eval_infer_batch_example.sh ================================================ #!/bin/bash # e.g. F5-TTS, 16 NFE accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16 accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16 accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 -p data/LibriSpeech/test-clean # e.g. Vanilla E2 TTS, 32 NFE accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0 accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0 accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 -p data/LibriSpeech/test-clean # e.g. 
evaluate F5-TTS 16 NFE result on Seed-TTS test-zh python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 # etc. ================================================ FILE: src/f5_tts/eval/eval_librispeech_test_clean.py ================================================ # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation) import argparse import ast import json import os import sys sys.path.append(os.getcwd()) import multiprocessing as mp from importlib.resources import files import numpy as np from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim rel_path = str(files("f5_tts").joinpath("../../")) def get_args(): parser = argparse.ArgumentParser() parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]) parser.add_argument("-l", "--lang", type=str, default="en") parser.add_argument("-g", "--gen_wav_dir", type=str, required=True) parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True) parser.add_argument( "-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])" ) parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory") return parser.parse_args() def parse_gpu_nums(gpu_nums_str): try: if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"): gpu_list = ast.literal_eval(gpu_nums_str) if isinstance(gpu_list, list): return gpu_list return list(range(int(gpu_nums_str))) except (ValueError, SyntaxError): 
raise argparse.ArgumentTypeError( f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])" ) def main(): args = get_args() eval_task = args.eval_task lang = args.lang librispeech_test_clean_path = args.librispeech_test_clean_path # test-clean path gen_wav_dir = args.gen_wav_dir metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" gpus = parse_gpu_nums(args.gpu_nums) test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path) ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book, ## leading to a low similarity for the ground truth in some cases. # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth local = args.local if local: # use local custom checkpoint dir asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" else: asr_ckpt_dir = "" # auto download to cache dir wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" # -------------------------------------------------------------------------- full_results = [] metrics = [] if eval_task == "wer": with mp.Pool(processes=len(gpus)) as pool: args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_asr_wer, args) for r in results: full_results.extend(r) elif eval_task == "sim": with mp.Pool(processes=len(gpus)) as pool: args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_sim, args) for r in results: full_results.extend(r) else: raise ValueError(f"Unknown metric type: {eval_task}") result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl" with open(result_path, "w") as f: for line in full_results: metrics.append(line[eval_task]) f.write(json.dumps(line, ensure_ascii=False) + "\n") metric = round(np.mean(metrics), 5) f.write(f"\n{eval_task.upper()}: {metric}\n") 
print(f"\nTotal {len(metrics)} samples") print(f"{eval_task.upper()}: {metric}") print(f"{eval_task.upper()} results saved to {result_path}") if __name__ == "__main__": main() ================================================ FILE: src/f5_tts/eval/eval_seedtts_testset.py ================================================ # Evaluate with Seed-TTS testset import argparse import ast import json import os import sys sys.path.append(os.getcwd()) import multiprocessing as mp from importlib.resources import files import numpy as np from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim rel_path = str(files("f5_tts").joinpath("../../")) def get_args(): parser = argparse.ArgumentParser() parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"]) parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"]) parser.add_argument("-g", "--gen_wav_dir", type=str, required=True) parser.add_argument( "-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])" ) parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory") return parser.parse_args() def parse_gpu_nums(gpu_nums_str): try: if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"): gpu_list = ast.literal_eval(gpu_nums_str) if isinstance(gpu_list, list): return gpu_list return list(range(int(gpu_nums_str))) except (ValueError, SyntaxError): raise argparse.ArgumentTypeError( f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])" ) def main(): args = get_args() eval_task = args.eval_task lang = args.lang gen_wav_dir = args.gen_wav_dir metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst" # seed-tts testset # NOTE. 
paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different # zh 1.254 seems a result of 4 workers wer_seed_tts gpus = parse_gpu_nums(args.gpu_nums) test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus) local = args.local if local: # use local custom checkpoint dir if lang == "zh": asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr elif lang == "en": asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3" else: asr_ckpt_dir = "" # auto download to cache dir wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" # -------------------------------------------------------------------------- full_results = [] metrics = [] if eval_task == "wer": with mp.Pool(processes=len(gpus)) as pool: args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_asr_wer, args) for r in results: full_results.extend(r) elif eval_task == "sim": with mp.Pool(processes=len(gpus)) as pool: args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_sim, args) for r in results: full_results.extend(r) else: raise ValueError(f"Unknown metric type: {eval_task}") result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl" with open(result_path, "w") as f: for line in full_results: metrics.append(line[eval_task]) f.write(json.dumps(line, ensure_ascii=False) + "\n") metric = round(np.mean(metrics), 5) f.write(f"\n{eval_task.upper()}: {metric}\n") print(f"\nTotal {len(metrics)} samples") print(f"{eval_task.upper()}: {metric}") print(f"{eval_task.upper()} results saved to {result_path}") if __name__ == "__main__": main() ================================================ FILE: src/f5_tts/eval/eval_utmos.py ================================================ import argparse import json from pathlib import Path import librosa import torch from tqdm import tqdm def main(): parser = argparse.ArgumentParser(description="UTMOS 
Evaluation") parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.") parser.add_argument("--ext", type=str, default="wav", help="Audio extension.") args = parser.parse_args() device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu" predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True) predictor = predictor.to(device) audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}")) utmos_score = 0 utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl" with open(utmos_result_path, "w", encoding="utf-8") as f: for audio_path in tqdm(audio_paths, desc="Processing"): wav, sr = librosa.load(audio_path, sr=None, mono=True) wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0) score = predictor(wav_tensor, sr) line = {} line["wav"], line["utmos"] = str(audio_path.stem), score.item() utmos_score += score.item() f.write(json.dumps(line, ensure_ascii=False) + "\n") avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0 f.write(f"\nUTMOS: {avg_score:.4f}\n") print(f"UTMOS: {avg_score:.4f}") print(f"UTMOS results saved to {utmos_result_path}") if __name__ == "__main__": main() ================================================ FILE: src/f5_tts/eval/utils_eval.py ================================================ import math import os import random import string from pathlib import Path import torch import torch.nn.functional as F import torchaudio from tqdm import tqdm from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL from f5_tts.model.modules import MelSpec from f5_tts.model.utils import convert_char_to_pinyin # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav def get_seedtts_testset_metainfo(metalst): f = open(metalst) lines = f.readlines() f.close() metainfo = [] for line in lines: if len(line.strip().split("|")) == 5: utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") elif 
len(line.strip().split("|")) == 4: utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") if not os.path.isabs(prompt_wav): prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) return metainfo # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path): f = open(metalst) lines = f.readlines() f.close() metainfo = [] for line in lines: ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t") # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc) ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc) gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav)) return metainfo # padded to max length mel batch def padded_mel_batch(ref_mels): max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() padded_ref_mels = [] for mel in ref_mels: padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) padded_ref_mels.append(padded_ref_mel) padded_ref_mels = torch.stack(padded_ref_mels) padded_ref_mels = padded_ref_mels.permute(0, 2, 1) return padded_ref_mels # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav def get_inference_prompt( metainfo, speed=1.0, tokenizer="pinyin", polyphone=True, target_sample_rate=24000, n_fft=1024, win_length=1024, n_mel_channels=100, hop_length=256, mel_spec_type="vocos", target_rms=0.1, use_truth_duration=False, infer_batch_size=1, 
num_buckets=200, min_secs=3, max_secs=40, ): prompts_all = [] min_tokens = min_secs * target_sample_rate // hop_length max_tokens = max_secs * target_sample_rate // hop_length batch_accum = [0] * num_buckets utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( [[] for _ in range(num_buckets)] for _ in range(6) ) mel_spectrogram = MelSpec( n_fft=n_fft, hop_length=hop_length, win_length=win_length, n_mel_channels=n_mel_channels, target_sample_rate=target_sample_rate, mel_spec_type=mel_spec_type, ) for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."): # Audio ref_audio, ref_sr = torchaudio.load(prompt_wav) ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) if ref_rms < target_rms: ref_audio = ref_audio * target_rms / ref_rms assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." if ref_sr != target_sample_rate: resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) ref_audio = resampler(ref_audio) # Text if len(prompt_text[-1].encode("utf-8")) == 1: prompt_text = prompt_text + " " text = [prompt_text + gt_text] if tokenizer == "pinyin": text_list = convert_char_to_pinyin(text, polyphone=polyphone) else: text_list = text # to mel spectrogram ref_mel = mel_spectrogram(ref_audio) ref_mel = ref_mel.squeeze(0) # Duration, mel frame length ref_mel_len = ref_mel.shape[-1] if use_truth_duration: gt_audio, gt_sr = torchaudio.load(gt_wav) if gt_sr != target_sample_rate: resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) gt_audio = resampler(gt_audio) total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) # # test vocoder resynthesis # ref_audio = gt_audio else: ref_text_len = len(prompt_text.encode("utf-8")) gen_text_len = len(gt_text.encode("utf-8")) total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed) # deal with batch assert infer_batch_size > 0, "infer_batch_size 
should be greater than 0." assert min_tokens <= total_mel_len <= max_tokens, ( f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]." ) bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets) utts[bucket_i].append(utt) ref_rms_list[bucket_i].append(ref_rms) ref_mels[bucket_i].append(ref_mel) ref_mel_lens[bucket_i].append(ref_mel_len) total_mel_lens[bucket_i].append(total_mel_len) final_text_list[bucket_i].extend(text_list) batch_accum[bucket_i] += total_mel_len if batch_accum[bucket_i] >= infer_batch_size: # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}") prompts_all.append( ( utts[bucket_i], ref_rms_list[bucket_i], padded_mel_batch(ref_mels[bucket_i]), ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i], ) ) batch_accum[bucket_i] = 0 ( utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i], ) = [], [], [], [], [], [] # add residual for bucket_i, bucket_frames in enumerate(batch_accum): if bucket_frames > 0: prompts_all.append( ( utts[bucket_i], ref_rms_list[bucket_i], padded_mel_batch(ref_mels[bucket_i]), ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i], ) ) # not only leave easy work for last workers random.seed(666) random.shuffle(prompts_all) return prompts_all # get wav_res_ref_text of seed-tts test metalst # https://github.com/BytedanceSpeech/seed-tts-eval def get_seed_tts_test(metalst, gen_wav_dir, gpus): f = open(metalst) lines = f.readlines() f.close() test_set_ = [] for line in tqdm(lines): if len(line.strip().split("|")) == 5: utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") elif len(line.strip().split("|")) == 4: utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")): continue 
gen_wav = os.path.join(gen_wav_dir, utt + ".wav") if not os.path.isabs(prompt_wav): prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) test_set_.append((gen_wav, prompt_wav, gt_text)) num_jobs = len(gpus) if num_jobs == 1: return [(gpus[0], test_set_)] wav_per_job = len(test_set_) // num_jobs + 1 test_set = [] for i in range(num_jobs): test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job])) return test_set # get librispeech test-clean cross sentence test def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False): f = open(metalst) lines = f.readlines() f.close() test_set_ = [] for line in tqdm(lines): ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t") if eval_ground_truth: gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-") gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac") else: if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")): raise FileNotFoundError(f"Generated wav not found: {gen_utt}") gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav") ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-") ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac") test_set_.append((gen_wav, ref_wav, gen_txt)) num_jobs = len(gpus) if num_jobs == 1: return [(gpus[0], test_set_)] wav_per_job = len(test_set_) // num_jobs + 1 test_set = [] for i in range(num_jobs): test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job])) return test_set # load asr model def load_asr_model(lang, ckpt_dir=""): if lang == "zh": from funasr import AutoModel model = AutoModel( model=os.path.join(ckpt_dir, "paraformer-zh"), # vad_model = os.path.join(ckpt_dir, "fsmn-vad"), # punc_model = os.path.join(ckpt_dir, "ct-punc"), # spk_model = os.path.join(ckpt_dir, "cam++"), disable_update=True, ) # following seed-tts setting elif lang == "en": from faster_whisper import 
WhisperModel model_size = "large-v3" if ckpt_dir == "" else ckpt_dir model = WhisperModel(model_size, device="cuda", compute_type="float16") return model # WER Evaluation, the way Seed-TTS does def run_asr_wer(args): rank, lang, test_set, ckpt_dir = args if lang == "zh": import zhconv torch.cuda.set_device(rank) elif lang == "en": os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) else: raise NotImplementedError( "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now." ) asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir) from zhon.hanzi import punctuation punctuation_all = punctuation + string.punctuation wer_results = [] from jiwer import process_words for gen_wav, prompt_wav, truth in tqdm(test_set): if lang == "zh": res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True) hypo = res[0]["text"] hypo = zhconv.convert(hypo, "zh-cn") elif lang == "en": segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en") hypo = "" for segment in segments: hypo = hypo + " " + segment.text raw_truth = truth raw_hypo = hypo for x in punctuation_all: truth = truth.replace(x, "") hypo = hypo.replace(x, "") truth = truth.replace(" ", " ") hypo = hypo.replace(" ", " ") if lang == "zh": truth = " ".join([x for x in truth]) hypo = " ".join([x for x in hypo]) elif lang == "en": truth = truth.lower() hypo = hypo.lower() measures = process_words(truth, hypo) wer = measures.wer # ref_list = truth.split(" ") # subs = measures.substitutions / len(ref_list) # dele = measures.deletions / len(ref_list) # inse = measures.insertions / len(ref_list) wer_results.append( { "wav": Path(gen_wav).stem, "truth": raw_truth, "hypo": raw_hypo, "wer": wer, } ) return wer_results # SIM Evaluation def run_sim(args): rank, test_set, ckpt_dir = args device = f"cuda:{rank}" model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None) state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage) 
model.load_state_dict(state_dict["model"], strict=False) use_gpu = True if torch.cuda.is_available() else False if use_gpu: model = model.cuda(device) model.eval() sim_results = [] for gen_wav, prompt_wav, truth in tqdm(test_set): wav1, sr1 = torchaudio.load(gen_wav) wav2, sr2 = torchaudio.load(prompt_wav) if use_gpu: wav1 = wav1.cuda(device) wav2 = wav2.cuda(device) if sr1 != 16000: resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000) if use_gpu: resample1 = resample1.cuda(device) wav1 = resample1(wav1) if sr2 != 16000: resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000) if use_gpu: resample2 = resample2.cuda(device) wav2 = resample2(wav2) with torch.no_grad(): emb1 = model(wav1) emb2 = model(wav2) sim = F.cosine_similarity(emb1, emb2)[0].item() # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).") sim_results.append( { "wav": Path(gen_wav).stem, "sim": sim, } ) return sim_results ================================================ FILE: src/f5_tts/infer/README.md ================================================ # Inference The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be automatically downloaded when running inference scripts. **More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.** Currently support **30s for a single** generation, which is the **total length** (same logic if `fix_duration`) including both prompt and output audio. However, `infer_cli` and `infer_gradio` will automatically do chunk generation for longer text. Long reference audio will be **clip short to ~12s**. To avoid possible inference failures, make sure you have seen through the following instructions. - Use reference audio <12s and leave proper silence space (e.g. 1s) at the end. 
Otherwise there is a risk of truncating in the middle of a word, leading to suboptimal generation. - Uppercase letters (best in a form like K.F.C.) will be spoken letter by letter; use lowercase letters for common words. - Add some spaces (blank: " ") or punctuation marks (e.g. "," ".") to explicitly introduce pauses. - If English punctuation marks the end of a sentence, make sure there is a space " " after it; otherwise it will not be treated as a sentence boundary when chunking. - Convert numbers to Chinese characters if you want them read in Chinese; otherwise they will be read in English. - If the generated output is blank (pure silence), check your FFmpeg installation. - Try turning off `use_ema` if using an early-stage finetuned checkpoint (one trained for only a few updates). ## Gradio App Currently supported features: - Basic TTS with Chunk Inference - Multi-Style / Multi-Speaker Generation - Voice Chat powered by Qwen2.5-3B-Instruct - [Custom inference with more language support](SHARED.md) The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference. The script loads model checkpoints from Hugging Face. You can also download the files manually and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS models are loaded at startup; the ASR model is loaded to transcribe the reference audio if `ref_text` is not provided, and the LLM is loaded when Voice Chat is used. More flag options: ```bash # Automatically launch the interface in the default web browser f5-tts_infer-gradio --inbrowser # Set the root path of the application, if it's not served from the root ("/") of the domain # For example, if the application is served at "https://example.com/myapp" f5-tts_infer-gradio --root_path "/myapp" ``` It can also be used as a component of a larger application: ```python import gradio as gr from f5_tts.infer.infer_gradio import app with gr.Blocks() as main_app: gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app") # ...
other Gradio components app.render() main_app.launch() ``` ## CLI Inference The CLI command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, which is a command-line tool for inference. The script loads model checkpoints from Hugging Face. You can also download the files manually and use `--ckpt_file` to specify the model you want to load, or update the path directly in `infer_cli.py`. To use a different vocabulary, pass your `vocab.txt` file with `--vocab_file`. Basic inference with flags: ```bash # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage) f5-tts_infer-cli \ --model F5TTS_v1_Base \ --ref_audio "ref_audio.wav" \ --ref_text "The content, subtitle or transcription of reference audio." \ --gen_text "Some text you want TTS model generate for you." # Use BigVGAN as vocoder. Currently only support F5TTS_Base. f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local # Use custom path checkpoint, e.g. f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors # More instructions f5-tts_infer-cli --help ``` A `.toml` file allows more flexible usage. ```bash f5-tts_infer-cli -c custom.toml ``` For example, you can use a `.toml` file to pass in variables; see `src/f5_tts/infer/examples/basic/basic.toml`: ```toml # F5TTS_v1_Base | E2TTS_Base model = "F5TTS_v1_Base" ref_audio = "infer/examples/basic/basic_ref_en.wav" # If an empty "", transcribes the reference audio automatically. ref_text = "Some call me nature, others call me mother nature." gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring." # File with text to generate. Ignores the text above. gen_file = "" remove_silence = false output_dir = "tests" ``` You can also use a `.toml` file for multi-style generation; see `src/f5_tts/infer/examples/multi/story.toml`.
```toml # F5TTS_v1_Base | E2TTS_Base model = "F5TTS_v1_Base" ref_audio = "infer/examples/multi/main.flac" # If an empty "", transcribes the reference audio automatically. ref_text = "" gen_text = "" # File with text to generate. Ignores the text above. gen_file = "infer/examples/multi/story.txt" remove_silence = true output_dir = "tests" [voices.town] ref_audio = "infer/examples/multi/town.flac" ref_text = "" [voices.country] ref_audio = "infer/examples/multi/country.flac" ref_text = "" ``` You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`. ## API Usage ```python from importlib.resources import files from f5_tts.api import F5TTS f5tts = F5TTS() wav, sr, spec = f5tts.infer( ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")), ref_text="some call me nature, others call me mother nature.", gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""", file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")), file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")), seed=None, ) ``` Check [api.py](../api.py) for more details. ## TensorRT-LLM Deployment See [detailed instructions](../runtime/triton_trtllm/README.md) for more information. 
## Socket Real-time Service Real-time voice output with chunked streaming: ```bash # Start socket server python src/f5_tts/socket_server.py # If PyAudio is not installed sudo apt-get install portaudio19-dev pip install pyaudio # Communicate with socket client python src/f5_tts/socket_client.py ``` ## Speech Editing To test speech editing capabilities, use the following command: ```bash python src/f5_tts/infer/speech_edit.py ``` ================================================ FILE: src/f5_tts/infer/SHARED.md ================================================ # Shared Model Cards ### **Prerequisites for use** - This document serves as a quick lookup table for community training/finetuning results in various languages. - The models in this repository are open source and are based on voluntary contributions from contributors. - Use of these models must respect their respective creators; the convenience they bring comes from those creators' efforts. ### **Welcome to share here** - Have a pretrained/finetuned result: a model checkpoint (preferably pruned to facilitate inference, i.e. keeping only `ema_model_state_dict`) and the corresponding vocab file (for tokenization). - Host a public [huggingface model repository](https://huggingface.co/new) and upload the model-related files. - Make a pull request adding a model card to the current page, i.e. `src/f5_tts/infer/SHARED.md`.
### Supported Languages - [Multilingual](#multilingual) - [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts) - [Arabic](#arabic) - [F5-TTS Small @ ar & en @ SILMA AI](#f5-tts-small--ar--en--silma-ai) - [English](#english) - [Finnish](#finnish) - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen) - [French](#french) - [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio) - [German](#german) - [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak) - [Hindi](#hindi) - [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab) - [Italian](#italian) - [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79) - [Japanese](#japanese) - [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica) - [Latvian](#latvian) - [F5-TTS Base @ lv @ RaivisDejus](#f5-tts-base--lv--raivisdejus) - [Mandarin](#mandarin) - [Russian](#russian) - [F5-TTS Base @ ru @ HotDro4illa](#f5-tts-base--ru--hotdro4illa) - [Spanish](#spanish) - [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar) ## Multilingual #### F5-TTS v1 v0 Base @ zh & en @ F5-TTS |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0| ```bash Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors # A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} ``` |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K 
zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0| ```bash Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...* ## Arabic #### F5-TTS Small @ ar & en @ SILMA AI |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Small|[ckpt & vocab](https://huggingface.co/silma-ai/silma-tts)| Tens of thousands EN/AR |Apache-2.0| - Pretrained by [SILMA.AI](https://silma.ai) - [GitHub repo](https://github.com/SILMA-AI/silma-tts), Inference code ## English ## Finnish #### F5-TTS Base @ fi @ AsmoKoskinen |Model|🤗Hugging Face|Data|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0| ```bash Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` ## French #### F5-TTS Base @ fr @ RASPIAUDIO |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0| ```bash Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt Vocab: 
hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french). - [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys). - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434). ## German #### F5-TTS Base @ de @ hvoss-techfak |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0| ```bash Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` - Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak) ## Hindi #### F5-TTS Small @ hi @ SPRINGLab |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0| ```bash Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` - Authors: SPRING Lab, Indian Institute of Technology, Madras - Website: https://asr.iitm.ac.in/ ## Italian #### F5-TTS Base @ it @ alien79 |Model|🤗Hugging 
Face|Data|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0| ```bash Model: hf://alien79/F5-TTS-italian/model_159600.safetensors Vocab: hf://alien79/F5-TTS-italian/vocab.txt Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` - Trained by [Mithril Man](https://github.com/MithrilMan) - Model details on [hf project home](https://huggingface.co/alien79/F5-TTS-italian) - Open to collaborations to further improve the model ## Japanese #### F5-TTS Base @ ja @ Jmica |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_21999120)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0| ```bash Model: hf://Jmica/F5TTS/JA_21999120/model_21999120.pt Vocab: hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` ## Latvian #### F5-TTS Base @ lv @ RaivisDejus |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/RaivisDejus/F5-TTS-Latvian)|[Common voice](https://datacollective.mozillafoundation.org/datasets/cmj8u3pec00flnxxbntvfb4as)|cc0-1.0| ```bash Model: hf://RaivisDejus/F5-TTS-Latvian/model.safetensors Vocab: hf://RaivisDejus/F5-TTS-Latvian/vocab.txt Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` ## Mandarin ## Russian #### 
F5-TTS Base @ ru @ HotDro4illa |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/hotstone228/F5-TTS-Russian)|[Common voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0)|cc-by-nc-4.0| ```bash Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` - Finetuned by [HotDro4illa](https://github.com/HotDro4illa) - Any improvements are welcome ## Spanish #### F5-TTS Base @ es @ jpgallegoar |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0| - @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model. ================================================ FILE: src/f5_tts/infer/examples/basic/basic.toml ================================================ # F5TTS_v1_Base | E2TTS_Base model = "F5TTS_v1_Base" ref_audio = "infer/examples/basic/basic_ref_en.wav" # If an empty "", transcribes the reference audio automatically. ref_text = "Some call me nature, others call me mother nature." gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring." # File with text to generate. Ignores the text above. 
gen_file = "" remove_silence = false output_dir = "tests" output_file = "infer_cli_basic.wav" ================================================ FILE: src/f5_tts/infer/examples/multi/story.toml ================================================ # F5TTS_v1_Base | E2TTS_Base model = "F5TTS_v1_Base" ref_audio = "infer/examples/multi/main.flac" # If an empty "", transcribes the reference audio automatically. ref_text = "" gen_text = "" # File with text to generate. Ignores the text above. gen_file = "infer/examples/multi/story.txt" remove_silence = true output_dir = "tests" output_file = "infer_cli_story.wav" [voices.town] ref_audio = "infer/examples/multi/town.flac" ref_text = "" speed = 0.8 # will ignore global speed [voices.country] ref_audio = "infer/examples/multi/country.flac" ref_text = "" ================================================ FILE: src/f5_tts/infer/examples/multi/story.txt ================================================ A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] "My poor dear friend, you live here no better than the ants! Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land." [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. 
Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] "Goodbye," [main] said he, [country] "I'm off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace." ================================================ FILE: src/f5_tts/infer/examples/vocab.txt ================================================ ! " # $ % & ' ( ) * + , - . / 0 1 2 3 4 5 6 7 8 9 : ; = > ? @ A B C D E F G H I J K L M N O P Q R S T U V W X Y Z [ \ ] _ a a1 ai1 ai2 ai3 ai4 an1 an3 an4 ang1 ang2 ang4 ao1 ao2 ao3 ao4 b ba ba1 ba2 ba3 ba4 bai1 bai2 bai3 bai4 ban1 ban2 ban3 ban4 bang1 bang2 bang3 bang4 bao1 bao2 bao3 bao4 bei bei1 bei2 bei3 bei4 ben1 ben2 ben3 ben4 beng beng1 beng2 beng3 beng4 bi1 bi2 bi3 bi4 bian1 bian2 bian3 bian4 biao1 biao2 biao3 bie1 bie2 bie3 bie4 bin1 bin4 bing1 bing2 bing3 bing4 bo bo1 bo2 bo3 bo4 bu2 bu3 bu4 c ca1 cai1 cai2 cai3 cai4 can1 can2 can3 can4 cang1 cang2 cao1 cao2 cao3 ce4 cen1 cen2 ceng1 ceng2 ceng4 cha1 cha2 cha3 cha4 chai1 chai2 chan1 chan2 chan3 chan4 chang1 chang2 chang3 chang4 chao1 chao2 chao3 che1 che2 che3 che4 chen1 chen2 chen3 chen4 cheng1 cheng2 cheng3 cheng4 chi1 chi2 chi3 chi4 chong1 chong2 chong3 chong4 chou1 chou2 chou3 chou4 chu1 chu2 chu3 chu4 chua1 chuai1 chuai2 chuai3 chuai4 chuan1 chuan2 chuan3 chuan4 chuang1 chuang2 chuang3 chuang4 chui1 chui2 chun1 chun2 chun3 chuo1 chuo4 ci1 ci2 ci3 ci4 cong1 cong2 cou4 cu1 cu4 cuan1 cuan2 cuan4 cui1 cui3 cui4 cun1 cun2 cun4 cuo1 cuo2 cuo4 d da da1 da2 da3 da4 dai1 dai2 dai3 dai4 dan1 dan2 dan3 dan4 dang1 dang2 dang3 dang4 dao1 dao2 dao3 dao4 de de1 de2 dei3 den4 deng1 deng2 deng3 deng4 di1 di2 di3 di4 dia3 dian1 dian2 dian3 dian4 diao1 diao3 diao4 die1 die2 die4 ding1 ding2 ding3 ding4 diu1 dong1 dong3 dong4 dou1 dou2 dou3 dou4 du1 du2 du3 du4 duan1 duan2 duan3 duan4 dui1 dui4 dun1 dun3 dun4 duo1 duo2 duo3 duo4 e e1 e2 
e3 e4 ei2 en1 en4 er er2 er3 er4 f fa1 fa2 fa3 fa4 fan1 fan2 fan3 fan4 fang1 fang2 fang3 fang4 fei1 fei2 fei3 fei4 fen1 fen2 fen3 fen4 feng1 feng2 feng3 feng4 fo2 fou2 fou3 fu1 fu2 fu3 fu4 g ga1 ga2 ga3 ga4 gai1 gai2 gai3 gai4 gan1 gan2 gan3 gan4 gang1 gang2 gang3 gang4 gao1 gao2 gao3 gao4 ge1 ge2 ge3 ge4 gei2 gei3 gen1 gen2 gen3 gen4 geng1 geng3 geng4 gong1 gong3 gong4 gou1 gou2 gou3 gou4 gu gu1 gu2 gu3 gu4 gua1 gua2 gua3 gua4 guai1 guai2 guai3 guai4 guan1 guan2 guan3 guan4 guang1 guang2 guang3 guang4 gui1 gui2 gui3 gui4 gun3 gun4 guo1 guo2 guo3 guo4 h ha1 ha2 ha3 hai1 hai2 hai3 hai4 han1 han2 han3 han4 hang1 hang2 hang4 hao1 hao2 hao3 hao4 he1 he2 he4 hei1 hen2 hen3 hen4 heng1 heng2 heng4 hong1 hong2 hong3 hong4 hou1 hou2 hou3 hou4 hu1 hu2 hu3 hu4 hua1 hua2 hua4 huai2 huai4 huan1 huan2 huan3 huan4 huang1 huang2 huang3 huang4 hui1 hui2 hui3 hui4 hun1 hun2 hun4 huo huo1 huo2 huo3 huo4 i j ji1 ji2 ji3 ji4 jia jia1 jia2 jia3 jia4 jian1 jian2 jian3 jian4 jiang1 jiang2 jiang3 jiang4 jiao1 jiao2 jiao3 jiao4 jie1 jie2 jie3 jie4 jin1 jin2 jin3 jin4 jing1 jing2 jing3 jing4 jiong3 jiu1 jiu2 jiu3 jiu4 ju1 ju2 ju3 ju4 juan1 juan2 juan3 juan4 jue1 jue2 jue4 jun1 jun4 k ka1 ka2 ka3 kai1 kai2 kai3 kai4 kan1 kan2 kan3 kan4 kang1 kang2 kang4 kao1 kao2 kao3 kao4 ke1 ke2 ke3 ke4 ken3 keng1 kong1 kong3 kong4 kou1 kou2 kou3 kou4 ku1 ku2 ku3 ku4 kua1 kua3 kua4 kuai3 kuai4 kuan1 kuan2 kuan3 kuang1 kuang2 kuang4 kui1 kui2 kui3 kui4 kun1 kun3 kun4 kuo4 l la la1 la2 la3 la4 lai2 lai4 lan2 lan3 lan4 lang1 lang2 lang3 lang4 lao1 lao2 lao3 lao4 le le1 le4 lei lei1 lei2 lei3 lei4 leng1 leng2 leng3 leng4 li li1 li2 li3 li4 lia3 lian2 lian3 lian4 liang2 liang3 liang4 liao1 liao2 liao3 liao4 lie1 lie2 lie3 lie4 lin1 lin2 lin3 lin4 ling2 ling3 ling4 liu1 liu2 liu3 liu4 long1 long2 long3 long4 lou1 lou2 lou3 lou4 lu1 lu2 lu3 lu4 luan2 luan3 luan4 lun1 lun2 lun4 luo1 luo2 luo3 luo4 lv2 lv3 lv4 lve3 lve4 m ma ma1 ma2 ma3 ma4 mai2 mai3 mai4 man1 man2 man3 man4 mang2 mang3 mao1 mao2 mao3 mao4 me mei2 
mei3 mei4 men men1 men2 men4 meng meng1 meng2 meng3 meng4 mi1 mi2 mi3 mi4 mian2 mian3 mian4 miao1 miao2 miao3 miao4 mie1 mie4 min2 min3 ming2 ming3 ming4 miu4 mo1 mo2 mo3 mo4 mou1 mou2 mou3 mu2 mu3 mu4 n n2 na1 na2 na3 na4 nai2 nai3 nai4 nan1 nan2 nan3 nan4 nang1 nang2 nang3 nao1 nao2 nao3 nao4 ne ne2 ne4 nei3 nei4 nen4 neng2 ni1 ni2 ni3 ni4 nian1 nian2 nian3 nian4 niang2 niang4 niao2 niao3 niao4 nie1 nie4 nin2 ning2 ning3 ning4 niu1 niu2 niu3 niu4 nong2 nong4 nou4 nu2 nu3 nu4 nuan3 nuo2 nuo4 nv2 nv3 nve4 o o1 o2 ou1 ou2 ou3 ou4 p pa1 pa2 pa4 pai1 pai2 pai3 pai4 pan1 pan2 pan4 pang1 pang2 pang4 pao1 pao2 pao3 pao4 pei1 pei2 pei4 pen1 pen2 pen4 peng1 peng2 peng3 peng4 pi1 pi2 pi3 pi4 pian1 pian2 pian4 piao1 piao2 piao3 piao4 pie1 pie2 pie3 pin1 pin2 pin3 pin4 ping1 ping2 po1 po2 po3 po4 pou1 pu1 pu2 pu3 pu4 q qi1 qi2 qi3 qi4 qia1 qia3 qia4 qian1 qian2 qian3 qian4 qiang1 qiang2 qiang3 qiang4 qiao1 qiao2 qiao3 qiao4 qie1 qie2 qie3 qie4 qin1 qin2 qin3 qin4 qing1 qing2 qing3 qing4 qiong1 qiong2 qiu1 qiu2 qiu3 qu1 qu2 qu3 qu4 quan1 quan2 quan3 quan4 que1 que2 que4 qun2 r ran2 ran3 rang1 rang2 rang3 rang4 rao2 rao3 rao4 re2 re3 re4 ren2 ren3 ren4 reng1 reng2 ri4 rong1 rong2 rong3 rou2 rou4 ru2 ru3 ru4 ruan2 ruan3 rui3 rui4 run4 ruo4 s sa1 sa2 sa3 sa4 sai1 sai4 san1 san2 san3 san4 sang1 sang3 sang4 sao1 sao2 sao3 sao4 se4 sen1 seng1 sha1 sha2 sha3 sha4 shai1 shai2 shai3 shai4 shan1 shan3 shan4 shang shang1 shang3 shang4 shao1 shao2 shao3 shao4 she1 she2 she3 she4 shei2 shen1 shen2 shen3 shen4 sheng1 sheng2 sheng3 sheng4 shi shi1 shi2 shi3 shi4 shou1 shou2 shou3 shou4 shu1 shu2 shu3 shu4 shua1 shua2 shua3 shua4 shuai1 shuai3 shuai4 shuan1 shuan4 shuang1 shuang3 shui2 shui3 shui4 shun3 shun4 shuo1 shuo4 si1 si2 si3 si4 song1 song3 song4 sou1 sou3 sou4 su1 su2 su4 suan1 suan4 sui1 sui2 sui3 sui4 sun1 sun3 suo suo1 suo2 suo3 t ta1 ta2 ta3 ta4 tai1 tai2 tai4 tan1 tan2 tan3 tan4 tang1 tang2 tang3 tang4 tao1 tao2 tao3 tao4 te4 teng2 ti1 ti2 ti3 ti4 tian1 tian2 tian3 tiao1 tiao2 
tiao3 tiao4 tie1 tie2 tie3 tie4 ting1 ting2 ting3 tong1 tong2 tong3 tong4 tou tou1 tou2 tou4 tu1 tu2 tu3 tu4 tuan1 tuan2 tui1 tui2 tui3 tui4 tun1 tun2 tun4 tuo1 tuo2 tuo3 tuo4 u v w wa wa1 wa2 wa3 wa4 wai1 wai3 wai4 wan1 wan2 wan3 wan4 wang1 wang2 wang3 wang4 wei1 wei2 wei3 wei4 wen1 wen2 wen3 wen4 weng1 weng4 wo1 wo2 wo3 wo4 wu1 wu2 wu3 wu4 x xi1 xi2 xi3 xi4 xia1 xia2 xia4 xian1 xian2 xian3 xian4 xiang1 xiang2 xiang3 xiang4 xiao1 xiao2 xiao3 xiao4 xie1 xie2 xie3 xie4 xin1 xin2 xin4 xing1 xing2 xing3 xing4 xiong1 xiong2 xiu1 xiu3 xiu4 xu xu1 xu2 xu3 xu4 xuan1 xuan2 xuan3 xuan4 xue1 xue2 xue3 xue4 xun1 xun2 xun4 y ya ya1 ya2 ya3 ya4 yan1 yan2 yan3 yan4 yang1 yang2 yang3 yang4 yao1 yao2 yao3 yao4 ye1 ye2 ye3 ye4 yi yi1 yi2 yi3 yi4 yin1 yin2 yin3 yin4 ying1 ying2 ying3 ying4 yo1 yong1 yong2 yong3 yong4 you1 you2 you3 you4 yu1 yu2 yu3 yu4 yuan1 yuan2 yuan3 yuan4 yue1 yue4 yun1 yun2 yun3 yun4 z za1 za2 za3 zai1 zai3 zai4 zan1 zan2 zan3 zan4 zang1 zang4 zao1 zao2 zao3 zao4 ze2 ze4 zei2 zen3 zeng1 zeng4 zha1 zha2 zha3 zha4 zhai1 zhai2 zhai3 zhai4 zhan1 zhan2 zhan3 zhan4 zhang1 zhang2 zhang3 zhang4 zhao1 zhao2 zhao3 zhao4 zhe zhe1 zhe2 zhe3 zhe4 zhen1 zhen2 zhen3 zhen4 zheng1 zheng2 zheng3 zheng4 zhi1 zhi2 zhi3 zhi4 zhong1 zhong2 zhong3 zhong4 zhou1 zhou2 zhou3 zhou4 zhu1 zhu2 zhu3 zhu4 zhua1 zhua2 zhua3 zhuai1 zhuai3 zhuai4 zhuan1 zhuan2 zhuan3 zhuan4 zhuang1 zhuang4 zhui1 zhui4 zhun1 zhun2 zhun3 zhuo1 zhuo2 zi zi1 zi2 zi3 zi4 zong1 zong2 zong3 zong4 zou1 zou2 zou3 zou4 zu1 zu2 zu3 zuan1 zuan3 zuan4 zui2 zui3 zui4 zun1 zuo zuo1 zuo2 zuo3 zuo4 { ~ ¡ ¢ £ ¥ § ¨ © « ® ¯ ° ± ² ³ ´ µ · ¹ º » ¼ ½ ¾ ¿ À Á Â Ã Ä Å Æ Ç È É Ê Í Î Ñ Ó Ö × Ø Ú Ü Ý Þ ß à á â ã ä å æ ç è é ê ë ì í î ï ð ñ ò ó ô õ ö ø ù ú û ü ý Ā ā ă ą ć Č č Đ đ ē ė ę ě ĝ ğ ħ ī į İ ı Ł ł ń ņ ň ŋ Ō ō ő œ ř Ś ś Ş ş Š š Ť ť ũ ū ź Ż ż Ž ž ơ ư ǎ ǐ ǒ ǔ ǚ ș ț ɑ ɔ ɕ ə ɛ ɜ ɡ ɣ ɪ ɫ ɴ ɹ ɾ ʃ ʊ ʌ ʒ ʔ ʰ ʷ ʻ ʾ ʿ ˈ ː ˙ ˜ ˢ ́ ̅ Α Β Δ Ε Θ Κ Λ Μ Ξ Π Σ Τ Φ Χ Ψ Ω ά έ ή ί α β γ δ ε ζ η θ ι κ λ μ ν ξ ο π ρ ς σ τ υ φ χ ψ ω ϊ ό ύ ώ ϕ ϵ Ё А Б В Г 
Д Е Ж З И Й К Л М Н О П Р С Т У Ф Х Ц Ч Ш Щ Ы Ь Э Ю Я а б в г д е ж з и й к л м н о п р с т у ф х ц ч ш щ ъ ы ь э ю я ё і ְ ִ ֵ ֶ ַ ָ ֹ ּ ־ ׁ א ב ג ד ה ו ז ח ט י כ ל ם מ ן נ ס ע פ ק ר ש ת أ ب ة ت ج ح د ر ز س ص ط ع ق ك ل م ن ه و ي َ ُ ِ ْ ก ข ง จ ต ท น ป ย ร ว ส ห อ ฮ ั า ี ึ โ ใ ไ ่ ้ ์ ḍ Ḥ ḥ ṁ ṃ ṅ ṇ Ṛ ṛ Ṣ ṣ Ṭ ṭ ạ ả Ấ ấ ầ ậ ắ ằ ẻ ẽ ế ề ể ễ ệ ị ọ ỏ ố ồ ộ ớ ờ ở ụ ủ ứ ữ ἀ ἁ Ἀ ἐ ἔ ἰ ἱ ὀ ὁ ὐ ὲ ὸ ᾶ ᾽ ῆ ῇ ῶ ‎ ‑ ‒ – — ― ‖ † ‡ • … ‧ ‬ ′ ″ ⁄ ⁡ ⁰ ⁴ ⁵ ⁶ ⁷ ⁸ ⁹ ₁ ₂ ₃ € ₱ ₹ ₽ ℃ ℏ ℓ № ℝ ™ ⅓ ⅔ ⅛ → ∂ ∈ ∑ − ∗ √ ∞ ∫ ≈ ≠ ≡ ≤ ≥ ⋅ ⋯ █ ♪ ⟨ ⟩ 、 。 《 》 「 」 【 】 あ う え お か が き ぎ く ぐ け げ こ ご さ し じ す ず せ ぜ そ ぞ た だ ち っ つ で と ど な に ね の は ば ひ ぶ へ べ ま み む め も ゃ や ゆ ょ よ ら り る れ ろ わ を ん ァ ア ィ イ ウ ェ エ オ カ ガ キ ク ケ ゲ コ ゴ サ ザ シ ジ ス ズ セ ゾ タ ダ チ ッ ツ テ デ ト ド ナ ニ ネ ノ バ パ ビ ピ フ プ ヘ ベ ペ ホ ボ ポ マ ミ ム メ モ ャ ヤ ュ ユ ョ ヨ ラ リ ル レ ロ ワ ン ・ ー ㄋ ㄍ ㄎ ㄏ ㄓ ㄕ ㄚ ㄜ ㄟ ㄤ ㄥ ㄧ ㄱ ㄴ ㄷ ㄹ ㅁ ㅂ ㅅ ㅈ ㅍ ㅎ ㅏ ㅓ ㅗ ㅜ ㅡ ㅣ 㗎 가 각 간 갈 감 갑 갓 갔 강 같 개 거 건 걸 겁 것 겉 게 겠 겨 결 겼 경 계 고 곤 골 곱 공 과 관 광 교 구 국 굴 귀 귄 그 근 글 금 기 긴 길 까 깍 깔 깜 깨 께 꼬 꼭 꽃 꾸 꿔 끔 끗 끝 끼 나 난 날 남 납 내 냐 냥 너 넘 넣 네 녁 년 녕 노 녹 놀 누 눈 느 는 늘 니 님 닙 다 닥 단 달 닭 당 대 더 덕 던 덥 데 도 독 동 돼 됐 되 된 될 두 둑 둥 드 들 등 디 따 딱 딸 땅 때 떤 떨 떻 또 똑 뚱 뛰 뜻 띠 라 락 란 람 랍 랑 래 랜 러 런 럼 렇 레 려 력 렵 렸 로 록 롬 루 르 른 를 름 릉 리 릴 림 마 막 만 많 말 맑 맙 맛 매 머 먹 멍 메 면 명 몇 모 목 몸 못 무 문 물 뭐 뭘 미 민 밌 밑 바 박 밖 반 받 발 밤 밥 방 배 백 밸 뱀 버 번 벌 벚 베 벼 벽 별 병 보 복 본 볼 봐 봤 부 분 불 비 빔 빛 빠 빨 뼈 뽀 뿅 쁘 사 산 살 삼 샀 상 새 색 생 서 선 설 섭 섰 성 세 셔 션 셨 소 속 손 송 수 숙 순 술 숫 숭 숲 쉬 쉽 스 슨 습 슷 시 식 신 실 싫 심 십 싶 싸 써 쓰 쓴 씌 씨 씩 씬 아 악 안 않 알 야 약 얀 양 얘 어 언 얼 엄 업 없 었 엉 에 여 역 연 염 엽 영 옆 예 옛 오 온 올 옷 옹 와 왔 왜 요 욕 용 우 운 울 웃 워 원 월 웠 위 윙 유 육 윤 으 은 을 음 응 의 이 익 인 일 읽 임 입 있 자 작 잔 잖 잘 잡 잤 장 재 저 전 점 정 제 져 졌 조 족 좀 종 좋 죠 주 준 줄 중 줘 즈 즐 즘 지 진 집 짜 짝 쩌 쪼 쪽 쫌 쭈 쯔 찌 찍 차 착 찾 책 처 천 철 체 쳐 쳤 초 촌 추 출 춤 춥 춰 치 친 칠 침 칩 칼 커 켓 코 콩 쿠 퀴 크 큰 큽 키 킨 타 태 터 턴 털 테 토 통 투 트 특 튼 틀 티 팀 파 팔 패 페 펜 펭 평 포 폭 표 품 풍 프 플 피 필 하 학 한 할 함 합 항 해 햇 했 행 허 험 형 혜 호 혼 홀 화 회 획 후 휴 흐 흔 희 히 힘 ﷺ ﷻ ! , ? 
� 𠮶 ================================================ FILE: src/f5_tts/infer/infer_cli.py ================================================ import argparse import codecs import os import re from datetime import datetime from importlib.resources import files from pathlib import Path import numpy as np import soundfile as sf import tomli from cached_path import cached_path from hydra.utils import get_class from omegaconf import OmegaConf from unidecode import unidecode from f5_tts.infer.utils_infer import ( cfg_strength, cross_fade_duration, device, fix_duration, infer_process, load_model, load_vocoder, mel_spec_type, nfe_step, preprocess_ref_audio_text, remove_silence_for_generated_wav, speed, sway_sampling_coef, target_rms, ) parser = argparse.ArgumentParser( prog="python3 infer-cli.py", description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.", epilog="Specify options above to override one or more settings from config.", ) parser.add_argument( "-c", "--config", type=str, default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"), help="The configuration file, default see infer/examples/basic/basic.toml", ) # Note. 
Not to provide default value here in order to read default from config file parser.add_argument( "-m", "--model", type=str, help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.", ) parser.add_argument( "-mc", "--model_cfg", type=str, help="The path to F5-TTS model config file .yaml", ) parser.add_argument( "-p", "--ckpt_file", type=str, help="The path to model checkpoint .pt, leave blank to use default", ) parser.add_argument( "-v", "--vocab_file", type=str, help="The path to vocab file .txt, leave blank to use default", ) parser.add_argument( "-r", "--ref_audio", type=str, help="The reference audio file.", ) parser.add_argument( "-s", "--ref_text", type=str, help="The transcript/subtitle for the reference audio", ) parser.add_argument( "-t", "--gen_text", type=str, help="The text to make model synthesize a speech", ) parser.add_argument( "-f", "--gen_file", type=str, help="The file with text to generate, will ignore --gen_text", ) parser.add_argument( "-o", "--output_dir", type=str, help="The path to output folder", ) parser.add_argument( "-w", "--output_file", type=str, help="The name of output file", ) parser.add_argument( "--save_chunk", action="store_true", help="To save each audio chunks during inference", ) parser.add_argument( "--no_legacy_text", action="store_false", help="Not to use lossy ASCII transliterations of unicode text in saved file names.", ) parser.add_argument( "--remove_silence", action="store_true", help="To remove long silence found in ouput", ) parser.add_argument( "--load_vocoder_from_local", action="store_true", help="To load vocoder from local dir, default to ../checkpoints/vocos-mel-24khz", ) parser.add_argument( "--vocoder_name", type=str, choices=["vocos", "bigvgan"], help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}", ) parser.add_argument( "--target_rms", type=float, help=f"Target output speech loudness normalization value, default {target_rms}", ) parser.add_argument( "--cross_fade_duration", 
type=float, help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}", ) parser.add_argument( "--nfe_step", type=int, help=f"The number of function evaluation (denoising steps), default {nfe_step}", ) parser.add_argument( "--cfg_strength", type=float, help=f"Classifier-free guidance strength, default {cfg_strength}", ) parser.add_argument( "--sway_sampling_coef", type=float, help=f"Sway Sampling coefficient, default {sway_sampling_coef}", ) parser.add_argument( "--speed", type=float, help=f"The speed of the generated audio, default {speed}", ) parser.add_argument( "--fix_duration", type=float, help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}", ) parser.add_argument( "--device", type=str, help="Specify the device to run on", ) args = parser.parse_args() # config file config = tomli.load(open(args.config, "rb")) # command-line interface parameters model = args.model or config.get("model", "F5TTS_v1_Base") ckpt_file = args.ckpt_file or config.get("ckpt_file", "") vocab_file = args.vocab_file or config.get("vocab_file", "") ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav") ref_text = ( args.ref_text if args.ref_text is not None else config.get("ref_text", "Some call me nature, others call me mother nature.") ) gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.") gen_file = args.gen_file or config.get("gen_file", "") output_dir = args.output_dir or config.get("output_dir", "tests") output_file = args.output_file or config.get( "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav" ) save_chunk = args.save_chunk or config.get("save_chunk", False) use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False) # no_legacy_text is a store_false arg if save_chunk and use_legacy_text: print( "\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) 
file names, --no_legacy_text to disable.\n" ) remove_silence = args.remove_silence or config.get("remove_silence", False) load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False) vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type) target_rms = args.target_rms or config.get("target_rms", target_rms) cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration) nfe_step = args.nfe_step or config.get("nfe_step", nfe_step) cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength) sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef) speed = args.speed or config.get("speed", speed) fix_duration = args.fix_duration or config.get("fix_duration", fix_duration) device = args.device or config.get("device", device) # patches for pip pkg user if "infer/examples/" in ref_audio: ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}")) if "infer/examples/" in gen_file: gen_file = str(files("f5_tts").joinpath(f"{gen_file}")) if "voices" in config: for voice in config["voices"]: voice_ref_audio = config["voices"][voice]["ref_audio"] if "infer/examples/" in voice_ref_audio: config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}")) # ignore gen_text if gen_file provided if gen_file: gen_text = codecs.open(gen_file, "r", "utf-8").read() # output path wave_path = Path(output_dir) / output_file # spectrogram_path = Path(output_dir) / "infer_cli_out.png" if save_chunk: output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks") if not os.path.exists(output_chunk_dir): os.makedirs(output_chunk_dir) # load vocoder if vocoder_name == "vocos": vocoder_local_path = "../checkpoints/vocos-mel-24khz" elif vocoder_name == "bigvgan": vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" vocoder = load_vocoder( vocoder_name=vocoder_name, 
is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device ) # load TTS model model_cfg = OmegaConf.load( args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) ) model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors" if model != "F5TTS_Base": assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type # override for previous models if model == "F5TTS_Base": if vocoder_name == "vocos": ckpt_step = 1200000 elif vocoder_name == "bigvgan": model = "F5TTS_Base_bigvgan" ckpt_type = "pt" elif model == "E2TTS_Base": repo_name = "E2-TTS" ckpt_step = 1200000 if not ckpt_file: ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")) elif ckpt_file.startswith("hf://"): ckpt_file = str(cached_path(ckpt_file)) if vocab_file.startswith("hf://"): vocab_file = str(cached_path(vocab_file)) print(f"Using {model}...") ema_model = load_model( model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device ) # inference process def main(): main_voice = {"ref_audio": ref_audio, "ref_text": ref_text} if "voices" not in config: voices = {"main": main_voice} else: voices = config["voices"] voices["main"] = main_voice for voice in voices: print("Voice:", voice) print("ref_audio ", voices[voice]["ref_audio"]) voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text( voices[voice]["ref_audio"], voices[voice]["ref_text"] ) print("ref_audio_", voices[voice]["ref_audio"], "\n\n") generated_audio_segments = [] reg1 = r"(?=\[\w+\])" chunks = re.split(reg1, gen_text) reg2 = r"\[(\w+)\]" for text in chunks: if not text.strip(): continue match = re.match(reg2, text) if match: voice = match[1] else: print("No voice tag found, using main.") voice = "main" if voice not in voices: print(f"Voice {voice} not found, using main.") voice = 
"main" text = re.sub(reg2, "", text) ref_audio_ = voices[voice]["ref_audio"] ref_text_ = voices[voice]["ref_text"] local_speed = voices[voice].get("speed", speed) gen_text_ = text.strip() print(f"Voice: {voice}") audio_segment, final_sample_rate, spectrogram = infer_process( ref_audio_, ref_text_, gen_text_, ema_model, vocoder, mel_spec_type=vocoder_name, target_rms=target_rms, cross_fade_duration=cross_fade_duration, nfe_step=nfe_step, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, speed=local_speed, fix_duration=fix_duration, device=device, ) generated_audio_segments.append(audio_segment) if save_chunk: if len(gen_text_) > 200: gen_text_ = gen_text_[:200] + " ... " if use_legacy_text: gen_text_ = unidecode(gen_text_) sf.write( os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"), audio_segment, final_sample_rate, ) if generated_audio_segments: final_wave = np.concatenate(generated_audio_segments) if not os.path.exists(output_dir): os.makedirs(output_dir) with open(wave_path, "wb") as f: sf.write(f.name, final_wave, final_sample_rate) # Remove silence if remove_silence: remove_silence_for_generated_wav(f.name) print(f.name) if __name__ == "__main__": main() ================================================ FILE: src/f5_tts/infer/infer_gradio.py ================================================ # ruff: noqa: E402 # Above allows ruff to ignore E402: module level import not at top of file import gc import json import os import re import tempfile from collections import OrderedDict from functools import lru_cache from importlib.resources import files import click import gradio as gr import numpy as np import soundfile as sf import torch import torchaudio from cached_path import cached_path from transformers import AutoModelForCausalLM, AutoTokenizer try: import spaces USING_SPACES = True except ImportError: USING_SPACES = False def gpu_decorator(func): if USING_SPACES: return spaces.GPU(func) else: return func from 
f5_tts.infer.utils_infer import ( infer_process, load_model, load_vocoder, preprocess_ref_audio_text, remove_silence_for_generated_wav, save_spectrogram, tempfile_kwargs, ) from f5_tts.model import DiT, UNetT DEFAULT_TTS_MODEL = "F5-TTS_v1" tts_model_choice = DEFAULT_TTS_MODEL DEFAULT_TTS_MODEL_CFG = [ "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors", "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt", json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)), ] # load models vocoder = load_vocoder() def load_f5tts(): ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0])) F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2]) return load_model(DiT, F5TTS_model_cfg, ckpt_path) def load_e2tts(): ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")) E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1) return load_model(UNetT, E2TTS_model_cfg, ckpt_path) def load_custom(ckpt_path: str, vocab_path="", model_cfg=None): ckpt_path, vocab_path = ckpt_path.strip(), vocab_path.strip() if ckpt_path.startswith("hf://"): ckpt_path = str(cached_path(ckpt_path)) if vocab_path.startswith("hf://"): vocab_path = str(cached_path(vocab_path)) if model_cfg is None: model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2]) elif isinstance(model_cfg, str): model_cfg = json.loads(model_cfg) return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path) F5TTS_ema_model = load_f5tts() E2TTS_ema_model = load_e2tts() if USING_SPACES else None custom_ema_model, pre_custom_path = None, "" chat_model_state = None chat_tokenizer_state = None @gpu_decorator def chat_model_inference(messages, model, tokenizer): """Generate response using Qwen""" text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) model_inputs = tokenizer([text], return_tensors="pt").to(model.device) generated_ids = model.generate( **model_inputs, max_new_tokens=512, 
temperature=0.7, top_p=0.95, ) generated_ids = [ output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] @gpu_decorator def load_text_from_file(file): if file: with open(file, "r", encoding="utf-8") as f: text = f.read().strip() else: text = "" return gr.update(value=text) @lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable @gpu_decorator def infer( ref_audio_orig, ref_text, gen_text, model, remove_silence, seed, cross_fade_duration=0.15, nfe_step=32, speed=1, show_info=gr.Info, ): if not ref_audio_orig: gr.Warning("Please provide reference audio.") return gr.update(), gr.update(), ref_text # Set inference seed if seed < 0 or seed > 2**31 - 1: gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.") seed = np.random.randint(0, 2**31 - 1) torch.manual_seed(seed) used_seed = seed if not gen_text.strip(): gr.Warning("Please enter text to generate or upload a text file.") return gr.update(), gr.update(), ref_text ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info) if model == DEFAULT_TTS_MODEL: ema_model = F5TTS_ema_model elif model == "E2-TTS": global E2TTS_ema_model if E2TTS_ema_model is None: show_info("Loading E2-TTS model...") E2TTS_ema_model = load_e2tts() ema_model = E2TTS_ema_model elif isinstance(model, tuple) and model[0] == "Custom": assert not USING_SPACES, "Only official checkpoints allowed in Spaces." 
global custom_ema_model, pre_custom_path if pre_custom_path != model[1]: show_info("Loading Custom TTS model...") custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3]) pre_custom_path = model[1] ema_model = custom_ema_model final_wave, final_sample_rate, combined_spectrogram = infer_process( ref_audio, ref_text, gen_text, ema_model, vocoder, cross_fade_duration=cross_fade_duration, nfe_step=nfe_step, speed=speed, show_info=show_info, progress=gr.Progress(), ) # Remove silence if remove_silence: with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f: temp_path = f.name try: sf.write(temp_path, final_wave, final_sample_rate) remove_silence_for_generated_wav(f.name) final_wave, _ = torchaudio.load(f.name) finally: os.unlink(temp_path) final_wave = final_wave.squeeze().cpu().numpy() # Save the spectrogram with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram: spectrogram_path = tmp_spectrogram.name save_spectrogram(combined_spectrogram, spectrogram_path) return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed with gr.Blocks() as app_tts: gr.Markdown("# Batched TTS") ref_audio_input = gr.Audio(label="Reference Audio", type="filepath") with gr.Row(): gen_text_input = gr.Textbox( label="Text to Generate", lines=10, max_lines=40, scale=4, ) gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1) generate_btn = gr.Button("Synthesize", variant="primary") with gr.Accordion("Advanced Settings", open=True) as adv_settn: with gr.Row(): ref_text_input = gr.Textbox( label="Reference Text", info="Leave blank to automatically transcribe the reference audio. 
If you enter text or upload a file, it will override automatic transcription.", lines=2, scale=4, ) ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1) with gr.Row(): randomize_seed = gr.Checkbox( label="Randomize Seed", info="Check to use a random seed for each generation. Uncheck to use the seed specified.", value=True, scale=3, ) seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1) with gr.Column(scale=4): remove_silence = gr.Checkbox( label="Remove Silences", info="If undesired long silence(s) produced, turn on to automatically detect and crop.", value=False, ) speed_slider = gr.Slider( label="Speed", minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed of the audio.", ) nfe_slider = gr.Slider( label="NFE Steps", minimum=4, maximum=64, value=32, step=2, info="Set the number of denoising steps.", ) cross_fade_duration_slider = gr.Slider( label="Cross-Fade Duration (s)", minimum=0.0, maximum=1.0, value=0.15, step=0.01, info="Set the duration of the cross-fade between audio clips.", ) def collapse_accordion(): return gr.Accordion(open=False) # Workaround for https://github.com/SWivid/F5-TTS/issues/1239#issuecomment-3677987413 # i.e. 
to set gr.Accordion(open=True) by default, then collapse manually Blocks loaded app_tts.load( fn=collapse_accordion, inputs=None, outputs=adv_settn, ) audio_output = gr.Audio(label="Synthesized Audio") spectrogram_output = gr.Image(label="Spectrogram") @gpu_decorator def basic_tts( ref_audio_input, ref_text_input, gen_text_input, remove_silence, randomize_seed, seed_input, cross_fade_duration_slider, nfe_slider, speed_slider, ): if randomize_seed: seed_input = np.random.randint(0, 2**31 - 1) audio_out, spectrogram_path, ref_text_out, used_seed = infer( ref_audio_input, ref_text_input, gen_text_input, tts_model_choice, remove_silence, seed=seed_input, cross_fade_duration=cross_fade_duration_slider, nfe_step=nfe_slider, speed=speed_slider, ) return audio_out, spectrogram_path, ref_text_out, used_seed gen_text_file.upload( load_text_from_file, inputs=[gen_text_file], outputs=[gen_text_input], ) ref_text_file.upload( load_text_from_file, inputs=[ref_text_file], outputs=[ref_text_input], ) ref_audio_input.clear( lambda: [None, None], None, [ref_text_input, ref_text_file], ) generate_btn.click( basic_tts, inputs=[ ref_audio_input, ref_text_input, gen_text_input, remove_silence, randomize_seed, seed_input, cross_fade_duration_slider, nfe_slider, speed_slider, ], outputs=[audio_output, spectrogram_output, ref_text_input, seed_input], ) def parse_speechtypes_text(gen_text): # Pattern to find {str} or {"name": str, "seed": int, "speed": float} pattern = r"(\{.*?\})" # Split the text by the pattern tokens = re.split(pattern, gen_text) segments = [] current_type_dict = { "name": "Regular", "seed": -1, "speed": 1.0, } for i in range(len(tokens)): if i % 2 == 0: # This is text text = tokens[i].strip() if text: current_type_dict["text"] = text segments.append(current_type_dict) else: # This is type type_str = tokens[i].strip() try: # if type dict current_type_dict = json.loads(type_str) except json.decoder.JSONDecodeError: type_str = type_str[1:-1] # remove brace {} 
current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0} return segments with gr.Blocks() as app_multistyle: # New section for multistyle generation gr.Markdown( """ # Multiple Speech-Type Generation This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, or upload a .txt file with the same format. The system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified. """ ) with gr.Row(): gr.Markdown( """ **Example Input:**
{Regular} Hello, I'd like to order a sandwich please.
{Surprised} What do you mean you're out of bread?
{Sad} I really wanted a sandwich though...
{Angry} You know what, darn you and your little shop!
{Whisper} I'll just go back home and cry now.
{Shouting} Why me?! """ ) gr.Markdown( """ **Example Input 2:**
{"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please.
{"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread.
{"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though...
{"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding. """ ) gr.Markdown( 'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.' ) # Regular speech type (mandatory) with gr.Row(variant="compact") as regular_row: with gr.Column(scale=1, min_width=160): regular_name = gr.Textbox(value="Regular", label="Speech Type Name") regular_insert = gr.Button("Insert Label", variant="secondary") with gr.Column(scale=3): regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath") with gr.Column(scale=3): regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4) with gr.Row(): regular_seed_slider = gr.Slider( show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random" ) regular_speed_slider = gr.Slider( show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed" ) with gr.Column(scale=1, min_width=160): regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"]) # Regular speech type (max 100) max_speech_types = 100 speech_type_rows = [regular_row] speech_type_names = [regular_name] speech_type_audios = [regular_audio] speech_type_ref_texts = [regular_ref_text] speech_type_ref_text_files = [regular_ref_text_file] speech_type_seeds = [regular_seed_slider] speech_type_speeds = [regular_speed_slider] speech_type_delete_btns = [None] speech_type_insert_btns = [regular_insert] # Additional speech types (99 more) for i in range(max_speech_types - 1): with gr.Row(variant="compact", visible=False) as row: with gr.Column(scale=1, min_width=160): name_input = gr.Textbox(label="Speech Type Name") insert_btn = gr.Button("Insert Label", variant="secondary") delete_btn = gr.Button("Delete Type", variant="stop") with gr.Column(scale=3): audio_input = gr.Audio(label="Reference Audio", type="filepath") with 
gr.Column(scale=3): ref_text_input = gr.Textbox(label="Reference Text", lines=4) with gr.Row(): seed_input = gr.Slider( show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random" ) speed_input = gr.Slider( show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed" ) with gr.Column(scale=1, min_width=160): ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"]) speech_type_rows.append(row) speech_type_names.append(name_input) speech_type_audios.append(audio_input) speech_type_ref_texts.append(ref_text_input) speech_type_ref_text_files.append(ref_text_file_input) speech_type_seeds.append(seed_input) speech_type_speeds.append(speed_input) speech_type_delete_btns.append(delete_btn) speech_type_insert_btns.append(insert_btn) # Global logic for all speech types for i in range(max_speech_types): speech_type_audios[i].clear( lambda: [None, None], None, [speech_type_ref_texts[i], speech_type_ref_text_files[i]], ) speech_type_ref_text_files[i].upload( load_text_from_file, inputs=[speech_type_ref_text_files[i]], outputs=[speech_type_ref_texts[i]], ) # Button to add speech type add_speech_type_btn = gr.Button("Add Speech Type") # Keep track of autoincrement of speech types, no roll back speech_type_count = 1 # Function to add a speech type def add_speech_type_fn(): row_updates = [gr.update() for _ in range(max_speech_types)] global speech_type_count if speech_type_count < max_speech_types: row_updates[speech_type_count] = gr.update(visible=True) speech_type_count += 1 else: gr.Warning("Exhausted maximum number of speech types. 
Consider restart the app.") return row_updates add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows) # Function to delete a speech type def delete_speech_type_fn(): return gr.update(visible=False), None, None, None, None # Update delete button clicks and ref text file changes for i in range(1, len(speech_type_delete_btns)): speech_type_delete_btns[i].click( delete_speech_type_fn, outputs=[ speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i], speech_type_ref_text_files[i], ], ) # Text input for the prompt with gr.Row(): gen_text_input_multistyle = gr.Textbox( label="Text to Generate", lines=10, max_lines=40, scale=4, placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!", ) gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1) def make_insert_speech_type_fn(index): def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed): current_text = current_text or "" if not speech_type_name: gr.Warning("Please enter speech type name before insert.") return current_text speech_type_dict = { "name": speech_type_name, "seed": speech_type_seed, "speed": speech_type_speed, } updated_text = current_text + json.dumps(speech_type_dict) + " " return updated_text return insert_speech_type_fn for i, insert_btn in enumerate(speech_type_insert_btns): insert_fn = make_insert_speech_type_fn(i) insert_btn.click( insert_fn, inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]], outputs=gen_text_input_multistyle, ) with gr.Accordion("Advanced Settings", open=True): with gr.Row(): 
with gr.Column(): show_cherrypick_multistyle = gr.Checkbox( label="Show Cherry-pick Interface", info="Turn on to show interface, picking seeds from previous generations.", value=False, ) with gr.Column(): remove_silence_multistyle = gr.Checkbox( label="Remove Silences", info="Turn on to automatically detect and crop long silences.", value=True, ) # Generate button generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary") # Output audio audio_output_multistyle = gr.Audio(label="Synthesized Audio") # Used seed gallery cherrypick_interface_multistyle = gr.Textbox( label="Cherry-pick Interface", lines=10, max_lines=40, buttons=["copy"], # show_copy_button=True if gradio<6.0 interactive=False, visible=False, ) # Logic control to show/hide the cherrypick interface show_cherrypick_multistyle.change( lambda is_visible: gr.update(visible=is_visible), show_cherrypick_multistyle, cherrypick_interface_multistyle, ) # Function to load text to generate from file gen_text_file_multistyle.upload( load_text_from_file, inputs=[gen_text_file_multistyle], outputs=[gen_text_input_multistyle], ) @gpu_decorator def generate_multistyle_speech( gen_text, *args, ): speech_type_names_list = args[:max_speech_types] speech_type_audios_list = args[max_speech_types : 2 * max_speech_types] speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types] remove_silence = args[3 * max_speech_types] # Collect the speech types and their audios into a dict speech_types = OrderedDict() ref_text_idx = 0 for name_input, audio_input, ref_text_input in zip( speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list ): if name_input and audio_input: speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input} else: speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""} ref_text_idx += 1 # Parse the gen_text into segments segments = parse_speechtypes_text(gen_text) # For each segment, generate speech 
generated_audio_segments = [] current_type_name = "Regular" inference_meta_data = "" for segment in segments: name = segment["name"] seed_input = segment["seed"] speed = segment["speed"] text = segment["text"] if name in speech_types: current_type_name = name else: gr.Warning(f"Type {name} is not available, will use Regular as default.") current_type_name = "Regular" try: ref_audio = speech_types[current_type_name]["audio"] except KeyError: gr.Warning(f"Please provide reference audio for type {current_type_name}.") return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None] ref_text = speech_types[current_type_name].get("ref_text", "") if seed_input == -1: seed_input = np.random.randint(0, 2**31 - 1) # Generate or retrieve speech for this segment audio_out, _, ref_text_out, used_seed = infer( ref_audio, ref_text, text, tts_model_choice, remove_silence, seed=seed_input, cross_fade_duration=0, speed=speed, show_info=print, # no pull to top when generating ) sr, audio_data = audio_out generated_audio_segments.append(audio_data) speech_types[current_type_name]["ref_text"] = ref_text_out inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n" # Concatenate all audio segments if generated_audio_segments: final_audio_data = np.concatenate(generated_audio_segments) return ( [(sr, final_audio_data)] + [speech_types[name]["ref_text"] for name in speech_types] + [inference_meta_data] ) else: gr.Warning("No audio generated.") return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None] generate_multistyle_btn.click( generate_multistyle_speech, inputs=[ gen_text_input_multistyle, ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [ remove_silence_multistyle, ], outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle], ) # Validation function to disable Generate button if speech types are missing def validate_speech_types(gen_text, 
regular_name, *args): speech_type_names_list = args # Collect the speech types names speech_types_available = set() if regular_name: speech_types_available.add(regular_name) for name_input in speech_type_names_list: if name_input: speech_types_available.add(name_input) # Parse the gen_text to get the speech types used segments = parse_speechtypes_text(gen_text) speech_types_in_text = set(segment["name"] for segment in segments) # Check if all speech types in text are available missing_speech_types = speech_types_in_text - speech_types_available if missing_speech_types: # Disable the generate button return gr.update(interactive=False) else: # Enable the generate button return gr.update(interactive=True) gen_text_input_multistyle.change( validate_speech_types, inputs=[gen_text_input_multistyle, regular_name] + speech_type_names, outputs=generate_multistyle_btn, ) with gr.Blocks() as app_chat: gr.Markdown( """ # Voice Chat Have a conversation with an AI using your reference voice! 1. Upload a reference audio clip and optionally its transcript (via text or .txt file). 2. Load the chat model. 3. Record your message through your microphone or type it. 4. The AI will respond using the reference voice. 
""" ) chat_model_name_list = [ "Qwen/Qwen2.5-3B-Instruct", "microsoft/Phi-4-mini-instruct", ] @gpu_decorator def load_chat_model(chat_model_name): show_info = gr.Info global chat_model_state, chat_tokenizer_state if chat_model_state is not None: chat_model_state = None chat_tokenizer_state = None gc.collect() torch.cuda.empty_cache() show_info(f"Loading chat model: {chat_model_name}") chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto") chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name) show_info(f"Chat model {chat_model_name} loaded successfully!") return gr.update(visible=False), gr.update(visible=True) if USING_SPACES: load_chat_model(chat_model_name_list[0]) chat_model_name_input = gr.Dropdown( choices=chat_model_name_list, value=chat_model_name_list[0], label="Chat Model Name", info="Enter the name of a HuggingFace chat model", allow_custom_value=not USING_SPACES, ) load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES) chat_interface_container = gr.Column(visible=USING_SPACES) chat_model_name_input.change( lambda: gr.update(visible=True), None, load_chat_model_btn, show_progress="hidden", ) load_chat_model_btn.click( load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container] ) with chat_interface_container: with gr.Row(): with gr.Column(): ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath") with gr.Column(): with gr.Accordion("Advanced Settings", open=False): with gr.Row(): ref_text_chat = gr.Textbox( label="Reference Text", info="Optional: Leave blank to auto-transcribe", lines=2, scale=3, ) ref_text_file_chat = gr.File( label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1 ) with gr.Row(): randomize_seed_chat = gr.Checkbox( label="Randomize Seed", value=True, info="Uncheck to use the seed specified.", scale=3, ) seed_input_chat = gr.Number(show_label=False, 
value=0, precision=0, scale=1) remove_silence_chat = gr.Checkbox( label="Remove Silences", value=True, ) system_prompt_chat = gr.Textbox( label="System Prompt", value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.", lines=2, ) chatbot_interface = gr.Chatbot( label="Conversation" ) # type="messages" hard-coded and no need to pass in since gradio 6.0 with gr.Row(): with gr.Column(): audio_input_chat = gr.Microphone( label="Speak your message", type="filepath", ) audio_output_chat = gr.Audio(autoplay=True) with gr.Column(): text_input_chat = gr.Textbox( label="Type your message", lines=1, ) send_btn_chat = gr.Button("Send Message") clear_btn_chat = gr.Button("Clear Conversation") # Modify process_audio_input to generate user input @gpu_decorator def process_audio_input(conv_state, audio_path, text): """Handle audio or text input from user""" if not audio_path and not text.strip(): return conv_state if audio_path: text = preprocess_ref_audio_text(audio_path, text)[1] if not text.strip(): return conv_state conv_state.append({"role": "user", "content": text}) return conv_state # Use model and tokenizer from state to get text response @gpu_decorator def generate_text_response(conv_state, system_prompt): """Generate text response from AI""" for single_state in conv_state: if isinstance(single_state["content"], list): assert len(single_state["content"]) == 1 and single_state["content"][0]["type"] == "text" single_state["content"] = single_state["content"][0]["text"] system_prompt_state = [{"role": "system", "content": system_prompt}] response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state) conv_state.append({"role": "assistant", "content": response}) return conv_state @gpu_decorator def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input): """Generate TTS audio for 
AI response""" if not conv_state or not ref_audio: return None, ref_text, seed_input last_ai_response = conv_state[-1]["content"][0]["text"] if not last_ai_response or conv_state[-1]["role"] != "assistant": return None, ref_text, seed_input if randomize_seed: seed_input = np.random.randint(0, 2**31 - 1) audio_result, _, ref_text_out, used_seed = infer( ref_audio, ref_text, last_ai_response, tts_model_choice, remove_silence, seed=seed_input, cross_fade_duration=0.15, speed=1.0, show_info=print, # show_info=print no pull to top when generating ) return audio_result, ref_text_out, used_seed def clear_conversation(): """Reset the conversation""" return [], None ref_text_file_chat.upload( load_text_from_file, inputs=[ref_text_file_chat], outputs=[ref_text_chat], ) for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]: user_operation( process_audio_input, inputs=[chatbot_interface, audio_input_chat, text_input_chat], outputs=[chatbot_interface], ).then( generate_text_response, inputs=[chatbot_interface, system_prompt_chat], outputs=[chatbot_interface], ).then( generate_audio_response, inputs=[ chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat, randomize_seed_chat, seed_input_chat, ], outputs=[audio_output_chat, ref_text_chat, seed_input_chat], ).then( lambda: [None, None], None, [audio_input_chat, text_input_chat], ) # Handle clear button or system prompt change and reset conversation for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]: user_operation( clear_conversation, outputs=[chatbot_interface, audio_output_chat], ) with gr.Blocks() as app_credits: gr.Markdown(""" # Credits * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS) * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration * [jpgallegoar](https://github.com/jpgallegoar) 
for multiple speech-type generation & voice chat """) with gr.Blocks() as app: gr.Markdown( f""" # F5-TTS Demo Space This is {"a local web UI for [F5-TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models: * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching) * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS) The checkpoints currently support English and Chinese. If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 12s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result). **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<12s). Ensure the audio is fully uploaded before generating.** """ ) last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt") def load_last_used_custom(): try: custom = [] with open(last_used_custom, "r", encoding="utf-8") as f: for line in f: custom.append(line.strip()) return custom except FileNotFoundError: last_used_custom.parent.mkdir(parents=True, exist_ok=True) return DEFAULT_TTS_MODEL_CFG def switch_tts_model(new_choice): global tts_model_choice if new_choice == "Custom": # override in case webpage is refreshed custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom() tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg) return ( gr.update(visible=True, value=custom_ckpt_path), gr.update(visible=True, value=custom_vocab_path), gr.update(visible=True, value=custom_model_cfg), ) else: tts_model_choice = new_choice return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) def 
set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg): global tts_model_choice tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg) with open(last_used_custom, "w", encoding="utf-8") as f: f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n") with gr.Row(): if not USING_SPACES: choose_tts_model = gr.Radio( choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL ) else: choose_tts_model = gr.Radio( choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL ) custom_ckpt_path = gr.Dropdown( choices=[DEFAULT_TTS_MODEL_CFG[0]], value=load_last_used_custom()[0], allow_custom_value=True, label="Model: local_path | hf://user_id/repo_id/model_ckpt", visible=False, ) custom_vocab_path = gr.Dropdown( choices=[DEFAULT_TTS_MODEL_CFG[1]], value=load_last_used_custom()[1], allow_custom_value=True, label="Vocab: local_path | hf://user_id/repo_id/vocab_file", visible=False, ) custom_model_cfg = gr.Dropdown( choices=[ DEFAULT_TTS_MODEL_CFG[2], json.dumps( dict( dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1, ) ), json.dumps( dict( dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1, ) ), ], value=load_last_used_custom()[2], allow_custom_value=True, label="Config: in a dictionary form", visible=False, ) choose_tts_model.change( switch_tts_model, inputs=[choose_tts_model], outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) custom_ckpt_path.change( set_custom_model, inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) custom_vocab_path.change( set_custom_model, inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg], show_progress="hidden", ) custom_model_cfg.change( set_custom_model, inputs=[custom_ckpt_path, custom_vocab_path, 
custom_model_cfg], show_progress="hidden", ) gr.TabbedInterface( [app_tts, app_multistyle, app_chat, app_credits], ["Basic-TTS", "Multi-Speech", "Voice-Chat", "Credits"], )


@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option(
    "--share",
    "-s",
    default=False,
    is_flag=True,
    help="Share the app via Gradio share link",
)
# NOTE(review): with is_flag=True and default=True, some click versions derive the flag value
# as `not default`, which would make passing -a/--api turn the API *off*. Confirm against the
# installed click version before relying on this flag.
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
@click.option(
    "--root_path",
    "-r",
    default=None,
    type=str,
    help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
)
@click.option(
    "--inbrowser",
    "-i",
    is_flag=True,
    default=False,
    help="Automatically launch the interface in the default web browser",
)
def main(port, host, share, api, root_path, inbrowser):
    """Launch the Gradio ``app`` with queuing enabled.

    ``api`` is passed to ``app.queue(api_open=...)``; the remaining CLI options are
    forwarded unchanged to ``launch``.
    """
    global app
    print("Starting app...")
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        root_path=root_path,
        inbrowser=inbrowser,
    )


if __name__ == "__main__":
    # USING_SPACES is defined earlier in this file (presumably a hosted-Spaces switch);
    # when set, the CLI is bypassed and the app is launched with default settings.
    if not USING_SPACES:
        main()
    else:
        app.queue().launch()


================================================
FILE: src/f5_tts/infer/speech_edit.py
================================================
# Script: regenerate selected time spans of a reference recording (audio_to_edit) so that the
# utterance reads target_text instead of origin_text, keeping the untouched audio as condition.
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility

from importlib.resources import files

import torch
import torch.nn.functional as F
import torchaudio
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer


# Pick the first available accelerator, falling back to CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "xpu"
    if torch.xpu.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)


# ---------------------- infer setting ---------------------- #

seed = None  # int | None

exp_name = "F5TTS_v1_Base"  # F5TTS_v1_Base | E2TTS_Base
ckpt_step = 1250000

nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"  # euler | midpoint
sway_sampling_coef = -1.0
speed = 1.0
target_rms = 0.1


# Backbone class, architecture kwargs, tokenizer and mel settings all come from the bundled
# YAML config named after exp_name.
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer

mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft


# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
output_dir = "tests"

# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
# [write the origin_text into a file, e.g. tests/test_edit.txt]
# ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
# [result will be saved at same path of audio file]
# [--language "zho" for Chinese, "eng" for English]
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]

audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist." parts_to_edit = [ [1.42, 2.44], [4.04, 4.9], ] # stard_ends of "nature" & "mother nature", in seconds fix_duration = [ 1.2, 1, ] # fix duration for "optimist" & "realist", in seconds # audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav" # origin_text = "对,这就是我,万人敬仰的太乙真人。" # target_text = "对,那就是你,万人敬仰的太白金星。" # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ] # fix_duration = None # use origin text duration # audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav" # origin_text = "对,这就是我,万人敬仰的太乙真人。" # target_text = "对,这就是你,万人敬仰的李白金星。" # parts_to_edit = [[1.500, 2.784], [4.083, 6.760]] # fix_duration = [1.284, 2.677] # -------------------------------------------------# use_ema = True if not os.path.exists(output_dir): os.makedirs(output_dir) # Vocoder model local = False if mel_spec_type == "vocos": vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" elif mel_spec_type == "bigvgan": vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path) # Tokenizer vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) # Model model = CFM( transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, win_length=win_length, n_mel_channels=n_mel_channels, target_sample_rate=target_sample_rate, mel_spec_type=mel_spec_type, ), odeint_kwargs=dict( method=ode_method, ), vocab_char_map=vocab_char_map, ).to(device) dtype = torch.float32 if mel_spec_type == "bigvgan" else None model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) # Audio audio, sr = torchaudio.load(audio_to_edit) if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) rms = torch.sqrt(torch.mean(torch.square(audio))) if rms < target_rms: audio = audio * target_rms / rms if sr 
!= target_sample_rate: resampler = torchaudio.transforms.Resample(sr, target_sample_rate) audio = resampler(audio) # Convert to mel spectrogram FIRST (on clean original audio) # This avoids boundary artifacts from mel windows straddling zeros and real audio audio = audio.to(device) with torch.inference_mode(): original_mel = model.mel_spec(audio) # (batch, n_mel, n_frames) original_mel = original_mel.permute(0, 2, 1) # (batch, n_frames, n_mel) # Build mel_cond and edit_mask at FRAME level # Insert zero frames in mel domain instead of zero samples in wav domain offset_frame = 0 mel_cond = torch.zeros(1, 0, n_mel_channels, device=device) edit_mask = torch.zeros(1, 0, dtype=torch.bool, device=device) fix_dur_list = fix_duration.copy() if fix_duration is not None else None for part in parts_to_edit: start, end = part part_dur_sec = end - start if fix_dur_list is None else fix_dur_list.pop(0) # Convert to frames (this is the authoritative unit) start_frame = round(start * target_sample_rate / hop_length) end_frame = round(end * target_sample_rate / hop_length) part_dur_frames = round(part_dur_sec * target_sample_rate / hop_length) # Number of frames for the kept (non-edited) region keep_frames = start_frame - offset_frame # Build mel_cond: original mel frames + zero frames for edit region mel_cond = torch.cat( ( mel_cond, original_mel[:, offset_frame:start_frame, :], torch.zeros(1, part_dur_frames, n_mel_channels, device=device), ), dim=1, ) edit_mask = torch.cat( ( edit_mask, torch.ones(1, keep_frames, dtype=torch.bool, device=device), torch.zeros(1, part_dur_frames, dtype=torch.bool, device=device), ), dim=-1, ) offset_frame = end_frame # Append remaining mel frames after last edit mel_cond = torch.cat((mel_cond, original_mel[:, offset_frame:, :]), dim=1) edit_mask = F.pad(edit_mask, (0, mel_cond.shape[1] - edit_mask.shape[-1]), value=True) # Text text_list = [target_text] if tokenizer == "pinyin": final_text_list = convert_char_to_pinyin(text_list) else: 
final_text_list = [text_list] print(f"text : {text_list}") print(f"pinyin: {final_text_list}") # Duration - use mel_cond length (not raw audio length) duration = mel_cond.shape[1] # Inference - pass mel_cond directly (not wav) with torch.inference_mode(): generated, trajectory = model.sample( cond=mel_cond, # Now passing mel directly, not wav text=final_text_list, duration=duration, steps=nfe_step, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, seed=seed, edit_mask=edit_mask, ) print(f"Generated mel: {generated.shape}") # Final result generated = generated.to(torch.float32) gen_mel_spec = generated.permute(0, 2, 1) if mel_spec_type == "vocos": generated_wave = vocoder.decode(gen_mel_spec).cpu() elif mel_spec_type == "bigvgan": generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() if rms < target_rms: generated_wave = generated_wave * rms / target_rms save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png") torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate) print(f"Generated wav: {generated_wave.shape}") ================================================ FILE: src/f5_tts/infer/utils_infer.py ================================================ # A unified script for inference process # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format import os import sys from concurrent.futures import ThreadPoolExecutor os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/") import hashlib import re import tempfile from importlib.resources import files import matplotlib matplotlib.use("Agg") import matplotlib.pylab as plt import numpy as np import torch import torchaudio import tqdm from huggingface_hub import hf_hub_download from pydub import AudioSegment, silence from transformers import pipeline from vocos import Vocos from 
f5_tts.model import CFM from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer _ref_audio_cache = {} _ref_text_cache = {} device = ( "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False} # ----------------------------------------- target_sample_rate = 24000 n_mel_channels = 100 hop_length = 256 win_length = 1024 n_fft = 1024 mel_spec_type = "vocos" target_rms = 0.1 cross_fade_duration = 0.15 ode_method = "euler" nfe_step = 32 # 16, 32 cfg_strength = 2.0 sway_sampling_coef = -1.0 speed = 1.0 fix_duration = None # ----------------------------------------- # chunk text into smaller pieces def chunk_text(text, max_chars=135): """ Splits the input text into chunks, each with a maximum number of characters. Args: text (str): The text to be split. max_chars (int): The maximum number of characters per chunk. Returns: List[str]: A list of text chunks. 
""" chunks = [] current_chunk = "" # Split the text into sentences based on punctuation followed by whitespace sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text) for sentence in sentences: if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars: current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence else: if current_chunk: chunks.append(current_chunk.strip()) current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence if current_chunk: chunks.append(current_chunk.strip()) return chunks # load vocoder def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None): if vocoder_name == "vocos": # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) if is_local: print(f"Load vocos from local path {local_path}") config_path = f"{local_path}/config.yaml" model_path = f"{local_path}/pytorch_model.bin" else: print("Download Vocos from huggingface charactr/vocos-mel-24khz") repo_id = "charactr/vocos-mel-24khz" config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") vocoder = Vocos.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) from vocos.feature_extractors import EncodecFeatures if isinstance(vocoder.feature_extractor, EncodecFeatures): encodec_parameters = { "feature_extractor.encodec." 
+ key: value for key, value in vocoder.feature_extractor.encodec.state_dict().items() } state_dict.update(encodec_parameters) vocoder.load_state_dict(state_dict) vocoder = vocoder.eval().to(device) elif vocoder_name == "bigvgan": try: from third_party.BigVGAN import bigvgan except ImportError: print("You need to follow the README to init submodule and change the BigVGAN source code.") if is_local: # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False) else: vocoder = bigvgan.BigVGAN.from_pretrained( "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir ) vocoder.remove_weight_norm() vocoder = vocoder.eval().to(device) return vocoder # load asr pipeline asr_pipe = None def initialize_asr_pipeline(device: str = device, dtype=None): if dtype is None: dtype = ( torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 7 and not torch.cuda.get_device_name().endswith("[ZLUDA]") else torch.float32 ) global asr_pipe asr_pipe = pipeline( "automatic-speech-recognition", model="openai/whisper-large-v3-turbo", torch_dtype=dtype, device=device, ) # transcribe def transcribe(ref_audio, language=None): global asr_pipe if asr_pipe is None: initialize_asr_pipeline(device=device) return asr_pipe( ref_audio, chunk_length_s=30, batch_size=128, generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"}, return_timestamps=False, )["text"].strip() # load model checkpoint for inference def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True): if dtype is None: dtype = ( torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 7 and not torch.cuda.get_device_name().endswith("[ZLUDA]") else torch.float32 ) model = model.to(dtype) ckpt_type = ckpt_path.split(".")[-1] if ckpt_type == "safetensors": from safetensors.torch 
import load_file checkpoint = load_file(ckpt_path, device=device) else: checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) if use_ema: if ckpt_type == "safetensors": checkpoint = {"ema_model_state_dict": checkpoint} checkpoint["model_state_dict"] = { k.replace("ema_model.", ""): v for k, v in checkpoint["ema_model_state_dict"].items() if k not in ["initted", "step"] } # patch for backward compatibility, 305e3ea for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]: if key in checkpoint["model_state_dict"]: del checkpoint["model_state_dict"][key] model.load_state_dict(checkpoint["model_state_dict"]) else: if ckpt_type == "safetensors": checkpoint = {"model_state_dict": checkpoint} model.load_state_dict(checkpoint["model_state_dict"]) del checkpoint torch.cuda.empty_cache() return model.to(device) # load model for inference def load_model( model_cls, model_cfg, ckpt_path, mel_spec_type=mel_spec_type, vocab_file="", ode_method=ode_method, use_ema=True, device=device, ): if vocab_file == "": vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt")) tokenizer = "custom" print("\nvocab : ", vocab_file) print("token : ", tokenizer) print("model : ", ckpt_path, "\n") vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer) model = CFM( transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, win_length=win_length, n_mel_channels=n_mel_channels, target_sample_rate=target_sample_rate, mel_spec_type=mel_spec_type, ), odeint_kwargs=dict( method=ode_method, ), vocab_char_map=vocab_char_map, ).to(device) dtype = torch.float32 if mel_spec_type == "bigvgan" else None model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) return model def remove_silence_edges(audio, silence_threshold=-42): # Remove silence from the start non_silent_start_idx = silence.detect_leading_silence(audio, 
silence_threshold=silence_threshold) audio = audio[non_silent_start_idx:] # Remove silence from the end non_silent_end_duration = audio.duration_seconds for ms in reversed(audio): if ms.dBFS > silence_threshold: break non_silent_end_duration -= 0.001 trimmed_audio = audio[: int(non_silent_end_duration * 1000)] return trimmed_audio # preprocess reference audio and text def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print): show_info("Converting audio...") # Compute a hash of the reference audio file with open(ref_audio_orig, "rb") as audio_file: audio_data = audio_file.read() audio_hash = hashlib.md5(audio_data).hexdigest() global _ref_audio_cache if audio_hash in _ref_audio_cache: show_info("Using cached preprocessed reference audio...") ref_audio = _ref_audio_cache[audio_hash] else: # first pass, do preprocess with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f: temp_path = f.name aseg = AudioSegment.from_file(ref_audio_orig) # 1. try to find long silence for clipping non_silent_segs = silence.split_on_silence( aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10 ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000: show_info("Audio is over 12s, clipping short. (1)") break non_silent_wave += non_silent_seg # 2. try to find short silence for clipping if 1. failed if len(non_silent_wave) > 12000: non_silent_segs = silence.split_on_silence( aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10 ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000: show_info("Audio is over 12s, clipping short. (2)") break non_silent_wave += non_silent_seg aseg = non_silent_wave # 3. 
if no proper silence found for clipping if len(aseg) > 12000: aseg = aseg[:12000] show_info("Audio is over 12s, clipping short. (3)") aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50) aseg.export(temp_path, format="wav") ref_audio = temp_path # Cache the processed reference audio _ref_audio_cache[audio_hash] = ref_audio if not ref_text.strip(): global _ref_text_cache if audio_hash in _ref_text_cache: # Use cached asr transcription show_info("Using cached reference text...") ref_text = _ref_text_cache[audio_hash] else: show_info("No reference text provided, transcribing reference audio...") ref_text = transcribe(ref_audio) # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak) _ref_text_cache[audio_hash] = ref_text else: show_info("Using custom reference text...") # Ensure ref_text ends with a proper sentence-ending punctuation if not ref_text.endswith(". ") and not ref_text.endswith("。"): if ref_text.endswith("."): ref_text += " " else: ref_text += ". " print("\nref_text ", ref_text) return ref_audio, ref_text # infer process: chunk text -> infer batches [i.e. 
infer_batch_process()] def infer_process( ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, show_info=print, progress=tqdm, target_rms=target_rms, cross_fade_duration=cross_fade_duration, nfe_step=nfe_step, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, speed=speed, fix_duration=fix_duration, device=device, ): # Split the input text into batches audio, sr = torchaudio.load(ref_audio) max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed) gen_text_batches = chunk_text(gen_text, max_chars=max_chars) for i, gen_text in enumerate(gen_text_batches): print(f"gen_text {i}", gen_text) print("\n") show_info(f"Generating audio in {len(gen_text_batches)} batches...") return next( infer_batch_process( (audio, sr), ref_text, gen_text_batches, model_obj, vocoder, mel_spec_type=mel_spec_type, progress=progress, target_rms=target_rms, cross_fade_duration=cross_fade_duration, nfe_step=nfe_step, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, speed=speed, fix_duration=fix_duration, device=device, ) ) # infer batches def infer_batch_process( ref_audio, ref_text, gen_text_batches, model_obj, vocoder, mel_spec_type="vocos", progress=tqdm, target_rms=0.1, cross_fade_duration=0.15, nfe_step=32, cfg_strength=2.0, sway_sampling_coef=-1, speed=1, fix_duration=None, device=None, streaming=False, chunk_size=2048, ): audio, sr = ref_audio if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) rms = torch.sqrt(torch.mean(torch.square(audio))) if rms < target_rms: audio = audio * target_rms / rms if sr != target_sample_rate: resampler = torchaudio.transforms.Resample(sr, target_sample_rate) audio = resampler(audio) audio = audio.to(device) generated_waves = [] spectrograms = [] if len(ref_text[-1].encode("utf-8")) == 1: ref_text = ref_text + " " def process_batch(gen_text): local_speed = speed if len(gen_text.encode("utf-8")) < 10: local_speed = 0.3 # Prepare 
the text text_list = [ref_text + gen_text] final_text_list = convert_char_to_pinyin(text_list) ref_audio_len = audio.shape[-1] // hop_length if fix_duration is not None: duration = int(fix_duration * target_sample_rate / hop_length) else: # Calculate duration ref_text_len = len(ref_text.encode("utf-8")) gen_text_len = len(gen_text.encode("utf-8")) duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed) # inference with torch.inference_mode(): generated, _ = model_obj.sample( cond=audio, text=final_text_list, duration=duration, steps=nfe_step, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, ) del _ generated = generated.to(torch.float32) # generated mel spectrogram generated = generated[:, ref_audio_len:, :] generated = generated.permute(0, 2, 1) if mel_spec_type == "vocos": generated_wave = vocoder.decode(generated) elif mel_spec_type == "bigvgan": generated_wave = vocoder(generated) if rms < target_rms: generated_wave = generated_wave * rms / target_rms # wav -> numpy generated_wave = generated_wave.squeeze().cpu().numpy() if streaming: for j in range(0, len(generated_wave), chunk_size): yield generated_wave[j : j + chunk_size], target_sample_rate else: generated_cpu = generated[0].cpu().numpy() del generated yield generated_wave, generated_cpu if streaming: for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches: for chunk in process_batch(gen_text): yield chunk else: with ThreadPoolExecutor() as executor: futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches] for future in progress.tqdm(futures) if progress is not None else futures: result = future.result() if result: generated_wave, generated_mel_spec = next(result) generated_waves.append(generated_wave) spectrograms.append(generated_mel_spec) if generated_waves: if cross_fade_duration <= 0: # Simply concatenate final_wave = np.concatenate(generated_waves) else: # Combine all generated waves 
with cross-fading final_wave = generated_waves[0] for i in range(1, len(generated_waves)): prev_wave = final_wave next_wave = generated_waves[i] # Calculate cross-fade samples, ensuring it does not exceed wave lengths cross_fade_samples = int(cross_fade_duration * target_sample_rate) cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave)) if cross_fade_samples <= 0: # No overlap possible, concatenate final_wave = np.concatenate([prev_wave, next_wave]) continue # Overlapping parts prev_overlap = prev_wave[-cross_fade_samples:] next_overlap = next_wave[:cross_fade_samples] # Fade out and fade in fade_out = np.linspace(1, 0, cross_fade_samples) fade_in = np.linspace(0, 1, cross_fade_samples) # Cross-faded overlap cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in # Combine new_wave = np.concatenate( [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]] ) final_wave = new_wave # Create a combined spectrogram combined_spectrogram = np.concatenate(spectrograms, axis=1) yield final_wave, target_sample_rate, combined_spectrogram else: yield None, target_sample_rate, None # remove silence from generated wav def remove_silence_for_generated_wav(filename): aseg = AudioSegment.from_file(filename) non_silent_segs = silence.split_on_silence( aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500, seek_step=10 ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: non_silent_wave += non_silent_seg aseg = non_silent_wave aseg.export(filename, format="wav") # save spectrogram def save_spectrogram(spectrogram, path): plt.figure(figsize=(12, 4)) plt.imshow(spectrogram, origin="lower", aspect="auto") plt.colorbar() plt.savefig(path) plt.close() ================================================ FILE: src/f5_tts/model/__init__.py ================================================ from f5_tts.model.backbones.dit import DiT from f5_tts.model.backbones.mmdit import MMDiT 
from f5_tts.model.backbones.unett import UNetT from f5_tts.model.cfm import CFM from f5_tts.model.trainer import Trainer __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"] ================================================ FILE: src/f5_tts/model/backbones/README.md ================================================ ## Backbones quick introduction ### unett.py - flat unet transformer - structure same as in e2-tts & voicebox paper except using rotary pos emb - possible abs pos emb & convnextv2 blocks for embedded text before concat ### dit.py - adaln-zero dit - embedded timestep as condition - concatted noised_input + masked_cond + embedded_text, linear proj in - possible abs pos emb & convnextv2 blocks for embedded text before concat - possible long skip connection (first layer to last layer) ### mmdit.py - stable diffusion 3 block structure - timestep as condition - left stream: text embedded and applied a abs pos emb - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett ================================================ FILE: src/f5_tts/model/backbones/dit.py ================================================ """ ein notation: b - batch n - sequence nt - text sequence nw - raw wave length d - dimension """ # ruff: noqa: F722 F821 from __future__ import annotations import torch import torch.nn.functional as F from torch import nn from x_transformers.x_transformers import RotaryEmbedding from f5_tts.model.modules import ( AdaLayerNorm_Final, ConvNeXtV2Block, ConvPositionEmbedding, DiTBlock, TimestepEmbedding, precompute_freqs_cis, ) # Text embedding class TextEmbedding(nn.Module): def __init__( self, text_num_embeds, text_dim, mask_padding=True, average_upsampling=False, conv_layers=0, conv_mult=2 ): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token self.mask_padding = mask_padding # mask filler and batch padding tokens or not self.average_upsampling = average_upsampling # 
zipvoice-style text late average upsampling (after text encoder) if average_upsampling: assert mask_padding, "text_embedding_average_upsampling requires text_mask_padding to be True" if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 8192 # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) self.text_blocks = nn.Sequential( *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] ) else: self.extra_modeling = False def average_upsample_text_by_mask(self, text, text_mask, target_lens): batch, max_seq_len, text_dim = text.shape text_lens = text_mask.sum(dim=1) # [batch] upsampled_text = torch.zeros_like(text) for i in range(batch): text_len = int(text_lens[i].item()) audio_len = int(target_lens[i].item()) if text_len == 0 or audio_len <= 0: continue valid_ind = torch.where(text_mask[i])[0] valid_data = text[i, valid_ind, :] # [text_len, text_dim] base_repeat = audio_len // text_len remainder = audio_len % text_len indices = [] for j in range(text_len): repeat_count = base_repeat + (1 if j >= text_len - remainder else 0) indices.extend([j] * repeat_count) indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long) upsampled = valid_data[indices] # [audio_len, text_dim] upsampled_text[i, :audio_len, :] = upsampled return upsampled_text def forward(self, text: int["b nt"], seq_len, drop_text=False): text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() valid_pos_mask = None if torch.is_tensor(seq_len): seq_len = seq_len.to(device=text.device, dtype=torch.long) max_seq_len = int(seq_len.max().item()) else: max_seq_len = int(seq_len) text = text[:, :max_seq_len] # curtail if character tokens are more than the mel spec tokens text = F.pad(text, (0, max_seq_len - text.shape[1]), value=0) if torch.is_tensor(seq_len): seq_pos = torch.arange(max_seq_len, device=text.device).unsqueeze(0) valid_pos_mask = seq_pos < seq_len.unsqueeze(1) text = text.masked_fill(~valid_pos_mask, 0) if self.mask_padding: text_mask = text == 0 if drop_text: # cfg for text text = torch.zeros_like(text) text = self.text_embed(text) # b n -> b n d if valid_pos_mask is not None: # Keep short-sample tail strictly zero (equivalent to per-sample pad_sequence(..., 0)). text = text.masked_fill(~valid_pos_mask.unsqueeze(-1), 0.0) # possible extra modeling if self.extra_modeling: # sinus pos emb; for variable seq lengths, only add positions within each sample's valid range. 
freqs = self.freqs_cis[:max_seq_len, :] if valid_pos_mask is not None: freqs = freqs.unsqueeze(0) * valid_pos_mask.unsqueeze(-1).to(freqs.dtype) text = text + freqs # convnextv2 blocks if self.mask_padding: text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) for block in self.text_blocks: text = block(text) text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) else: text = self.text_blocks(text) if self.average_upsampling: if torch.is_tensor(seq_len): target_lens = seq_len.to(device=text.device, dtype=torch.long) else: target_lens = torch.full((text.shape[0],), int(seq_len), device=text.device, dtype=torch.long) text = self.average_upsample_text_by_mask(text, ~text_mask, target_lens) return text # noised input audio and context mixing embedding class InputEmbedding(nn.Module): def __init__(self, mel_dim, text_dim, out_dim): super().__init__() self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) def forward( self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False, audio_mask: bool["b n"] | None = None, ): if drop_audio_cond: # cfg for cond audio cond = torch.zeros_like(cond) x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) x = self.conv_pos_embed(x, mask=audio_mask) + x return x # Transformer backbone using DiT blocks class DiT(nn.Module): def __init__( self, *, dim, depth=8, heads=8, dim_head=64, dropout=0.1, ff_mult=4, mel_dim=100, text_num_embeds=256, text_dim=None, text_mask_padding=True, text_embedding_average_upsampling=False, qk_norm=None, conv_layers=0, pe_attn_head=None, attn_backend="torch", # "torch" | "flash_attn" attn_mask_enabled=False, long_skip_connection=False, checkpoint_activations=False, ): super().__init__() self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim self.text_embed = TextEmbedding( text_num_embeds, text_dim, mask_padding=text_mask_padding, 
average_upsampling=text_embedding_average_upsampling, conv_layers=conv_layers, ) self.text_cond, self.text_uncond = None, None # text cache self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) self.dim = dim self.depth = depth self.transformer_blocks = nn.ModuleList( [ DiTBlock( dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, qk_norm=qk_norm, pe_attn_head=pe_attn_head, attn_backend=attn_backend, attn_mask_enabled=attn_mask_enabled, ) for _ in range(depth) ] ) self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None self.norm_out = AdaLayerNorm_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) self.checkpoint_activations = checkpoint_activations self.initialize_weights() def initialize_weights(self): # Zero-out AdaLN layers in DiT blocks: for block in self.transformer_blocks: nn.init.constant_(block.attn_norm.linear.weight, 0) nn.init.constant_(block.attn_norm.linear.bias, 0) # Zero-out output layers: nn.init.constant_(self.norm_out.linear.weight, 0) nn.init.constant_(self.norm_out.linear.bias, 0) nn.init.constant_(self.proj_out.weight, 0) nn.init.constant_(self.proj_out.bias, 0) def ckpt_wrapper(self, module): # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py def ckpt_forward(*inputs): outputs = module(*inputs) return outputs return ckpt_forward def get_input_embed( self, x, # b n d cond, # b n d text, # b nt drop_audio_cond: bool = False, drop_text: bool = False, cache: bool = True, audio_mask: bool["b n"] | None = None, ): if self.text_uncond is None or self.text_cond is None or not cache: if audio_mask is None: seq_len = x.shape[1] else: seq_len = audio_mask.sum(dim=1) # per-sample valid speech length text_embed = self.text_embed(text, seq_len=seq_len, drop_text=drop_text) if cache: if drop_text: self.text_uncond = text_embed else: self.text_cond = text_embed if cache: if drop_text: text_embed = self.text_uncond 
else: text_embed = self.text_cond x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond, audio_mask=audio_mask) return x def clear_cache(self): self.text_cond, self.text_uncond = None, None def forward( self, x: float["b n d"], # nosied input audio cond: float["b n d"], # masked cond audio text: int["b nt"], # text time: float["b"] | float[""], # time step mask: bool["b n"] | None = None, drop_audio_cond: bool = False, # cfg for cond audio drop_text: bool = False, # cfg for text cfg_infer: bool = False, # cfg inference, pack cond & uncond forward cache: bool = False, ): batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: time = time.repeat(batch) # t: conditioning time, text: text, x: noised audio + cond audio + text t = self.time_embed(time) if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d x_cond = self.get_input_embed( x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache, audio_mask=mask ) x_uncond = self.get_input_embed( x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache, audio_mask=mask ) x = torch.cat((x_cond, x_uncond), dim=0) t = torch.cat((t, t), dim=0) mask = torch.cat((mask, mask), dim=0) if mask is not None else None else: x = self.get_input_embed( x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache, audio_mask=mask ) rope = self.rotary_embed.forward_from_seq_len(seq_len) if self.long_skip_connection is not None: residual = x for block in self.transformer_blocks: if self.checkpoint_activations: # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False) else: x = block(x, t, mask=mask, rope=rope) if self.long_skip_connection is not None: x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) x = self.norm_out(x, t) output = self.proj_out(x) return output ================================================ FILE: 
src/f5_tts/model/backbones/mmdit.py ================================================ """ ein notation: b - batch n - sequence nt - text sequence nw - raw wave length d - dimension """ # ruff: noqa: F722 F821 from __future__ import annotations import torch from torch import nn from x_transformers.x_transformers import RotaryEmbedding from f5_tts.model.modules import ( AdaLayerNorm_Final, ConvPositionEmbedding, MMDiTBlock, TimestepEmbedding, get_pos_embed_indices, precompute_freqs_cis, ) # text embedding class TextEmbedding(nn.Module): def __init__(self, out_dim, text_num_embeds, mask_padding=True): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token self.mask_padding = mask_padding # mask filler and batch padding tokens or not self.precompute_max_pos = 1024 self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() if self.mask_padding: text_mask = text == 0 if drop_text: # cfg for text text = torch.zeros_like(text) text = self.text_embed(text) # b nt -> b nt d # sinus pos emb batch_start = torch.zeros((text.shape[0],), dtype=torch.long) batch_text_len = text.shape[1] pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos) text_pos_embed = self.freqs_cis[pos_idx] text = text + text_pos_embed if self.mask_padding: text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) return text # noised input & masked cond audio embedding class AudioEmbedding(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.linear = nn.Linear(2 * in_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(out_dim) def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): if drop_audio_cond: cond = torch.zeros_like(cond) x = torch.cat((x, cond), dim=-1) x = self.linear(x) x = self.conv_pos_embed(x) + x return x # Transformer backbone using MM-DiT blocks class MMDiT(nn.Module): def __init__( self, *, dim, depth=8, heads=8, dim_head=64, dropout=0.1, ff_mult=4, mel_dim=100, text_num_embeds=256, text_mask_padding=True, qk_norm=None, checkpoint_activations=False, attn_backend="torch", attn_mask_enabled=False, ): super().__init__() self.time_embed = TimestepEmbedding(dim) self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding) self.text_cond, self.text_uncond = None, None # text cache self.audio_embed = AudioEmbedding(mel_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) self.dim = dim self.depth = depth self.transformer_blocks = nn.ModuleList( [ MMDiTBlock( dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, ff_mult=ff_mult, context_pre_only=i == depth - 1, qk_norm=qk_norm, attn_backend=attn_backend, attn_mask_enabled=attn_mask_enabled, ) for i in range(depth) ] ) self.norm_out = AdaLayerNorm_Final(dim) # 
final modulation self.proj_out = nn.Linear(dim, mel_dim) self.checkpoint_activations = checkpoint_activations self.initialize_weights() def initialize_weights(self): # Zero-out AdaLN layers in MMDiT blocks: for block in self.transformer_blocks: nn.init.constant_(block.attn_norm_x.linear.weight, 0) nn.init.constant_(block.attn_norm_x.linear.bias, 0) nn.init.constant_(block.attn_norm_c.linear.weight, 0) nn.init.constant_(block.attn_norm_c.linear.bias, 0) # Zero-out output layers: nn.init.constant_(self.norm_out.linear.weight, 0) nn.init.constant_(self.norm_out.linear.bias, 0) nn.init.constant_(self.proj_out.weight, 0) nn.init.constant_(self.proj_out.bias, 0) def ckpt_wrapper(self, module): def ckpt_forward(*inputs): outputs = module(*inputs) return outputs return ckpt_forward def get_input_embed( self, x, # b n d cond, # b n d text, # b nt drop_audio_cond: bool = False, drop_text: bool = False, cache: bool = True, ): if cache: if drop_text: if self.text_uncond is None: self.text_uncond = self.text_embed(text, drop_text=True) c = self.text_uncond else: if self.text_cond is None: self.text_cond = self.text_embed(text, drop_text=False) c = self.text_cond else: c = self.text_embed(text, drop_text=drop_text) x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond) return x, c def clear_cache(self): self.text_cond, self.text_uncond = None, None def forward( self, x: float["b n d"], # nosied input audio cond: float["b n d"], # masked cond audio text: int["b nt"], # text time: float["b"] | float[""], # time step mask: bool["b n"] | None = None, drop_audio_cond: bool = False, # cfg for cond audio drop_text: bool = False, # cfg for text cfg_infer: bool = False, # cfg inference, pack cond & uncond forward cache: bool = False, ): batch = x.shape[0] if time.ndim == 0: time = time.repeat(batch) # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) c_mask = (text + 1) != 0 # True = valid, False = padding (-1 tokens) 
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) x = torch.cat((x_cond, x_uncond), dim=0) c = torch.cat((c_cond, c_uncond), dim=0) t = torch.cat((t, t), dim=0) mask = torch.cat((mask, mask), dim=0) if mask is not None else None c_mask = torch.cat((c_mask, c_mask), dim=0) else: x, c = self.get_input_embed( x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache ) seq_len = x.shape[1] text_len = text.shape[1] rope_audio = self.rotary_embed.forward_from_seq_len(seq_len) rope_text = self.rotary_embed.forward_from_seq_len(text_len) for block in self.transformer_blocks: if self.checkpoint_activations: c, x = torch.utils.checkpoint.checkpoint( self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, c_mask, use_reentrant=False ) else: c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text, c_mask=c_mask) x = self.norm_out(x, t) output = self.proj_out(x) return output ================================================ FILE: src/f5_tts/model/backbones/unett.py ================================================ """ ein notation: b - batch n - sequence nt - text sequence nw - raw wave length d - dimension """ # ruff: noqa: F722 F821 from __future__ import annotations from typing import Literal import torch import torch.nn.functional as F from torch import nn from x_transformers import RMSNorm from x_transformers.x_transformers import RotaryEmbedding from f5_tts.model.modules import ( Attention, AttnProcessor, ConvNeXtV2Block, ConvPositionEmbedding, FeedForward, TimestepEmbedding, get_pos_embed_indices, precompute_freqs_cis, ) # Text embedding class TextEmbedding(nn.Module): def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2): super().__init__() self.text_embed = 
nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token self.mask_padding = mask_padding # mask filler and batch padding tokens or not if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) self.text_blocks = nn.Sequential( *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] ) else: self.extra_modeling = False def forward(self, text: int["b nt"], seq_len, drop_text=False): text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) if self.mask_padding: text_mask = text == 0 if drop_text: # cfg for text text = torch.zeros_like(text) text = self.text_embed(text) # b n -> b n d # possible extra modeling if self.extra_modeling: # sinus pos emb batch_start = torch.zeros((batch,), dtype=torch.long) pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) text_pos_embed = self.freqs_cis[pos_idx] text = text + text_pos_embed # convnextv2 blocks if self.mask_padding: text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) for block in self.text_blocks: text = block(text) text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) else: text = self.text_blocks(text) return text # noised input audio and context mixing embedding class InputEmbedding(nn.Module): def __init__(self, mel_dim, text_dim, out_dim): super().__init__() self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): if drop_audio_cond: # cfg for cond audio cond 
= torch.zeros_like(cond) x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) x = self.conv_pos_embed(x) + x return x # Flat UNet Transformer backbone class UNetT(nn.Module): def __init__( self, *, dim, depth=8, heads=8, dim_head=64, dropout=0.1, ff_mult=4, mel_dim=100, text_num_embeds=256, text_dim=None, text_mask_padding=True, qk_norm=None, conv_layers=0, pe_attn_head=None, attn_backend="torch", # "torch" | "flash_attn" attn_mask_enabled=False, skip_connect_type: Literal["add", "concat", "none"] = "concat", ): super().__init__() assert depth % 2 == 0, "UNet-Transformer's depth should be even." self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim self.text_embed = TextEmbedding( text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers ) self.text_cond, self.text_uncond = None, None # text cache self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) # transformer layers & skip connections self.dim = dim self.skip_connect_type = skip_connect_type needs_skip_proj = skip_connect_type == "concat" self.depth = depth self.layers = nn.ModuleList([]) for idx in range(depth): is_later_half = idx >= (depth // 2) attn_norm = RMSNorm(dim) attn = Attention( processor=AttnProcessor( pe_attn_head=pe_attn_head, attn_backend=attn_backend, attn_mask_enabled=attn_mask_enabled, ), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, qk_norm=qk_norm, ) ff_norm = RMSNorm(dim) ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None self.layers.append( nn.ModuleList( [ skip_proj, attn_norm, attn, ff_norm, ff, ] ) ) self.norm_out = RMSNorm(dim) self.proj_out = nn.Linear(dim, mel_dim) def get_input_embed( self, x, # b n d cond, # b n d text, # b nt drop_audio_cond: bool = False, drop_text: bool = False, cache: bool = True, ): seq_len = x.shape[1] if cache: if 
drop_text: if self.text_uncond is None: self.text_uncond = self.text_embed(text, seq_len, drop_text=True) text_embed = self.text_uncond else: if self.text_cond is None: self.text_cond = self.text_embed(text, seq_len, drop_text=False) text_embed = self.text_cond else: text_embed = self.text_embed(text, seq_len, drop_text=drop_text) x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) return x def clear_cache(self): self.text_cond, self.text_uncond = None, None def forward( self, x: float["b n d"], # nosied input audio cond: float["b n d"], # masked cond audio text: int["b nt"], # text time: float["b"] | float[""], # time step mask: bool["b n"] | None = None, drop_audio_cond: bool = False, # cfg for cond audio drop_text: bool = False, # cfg for text cfg_infer: bool = False, # cfg inference, pack cond & uncond forward cache: bool = False, ): batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: time = time.repeat(batch) # t: conditioning time, c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) x = torch.cat((x_cond, x_uncond), dim=0) t = torch.cat((t, t), dim=0) mask = torch.cat((mask, mask), dim=0) if mask is not None else None else: x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) # postfix time t to input x, [b n d] -> [b n+1 d] x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x if mask is not None: mask = F.pad(mask, (1, 0), value=1) rope = self.rotary_embed.forward_from_seq_len(seq_len + 1) # flat unet transformer skip_connect_type = self.skip_connect_type skips = [] for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers): layer = idx + 1 # skip connection logic 
is_first_half = layer <= (self.depth // 2) is_later_half = not is_first_half if is_first_half: skips.append(x) if is_later_half: skip = skips.pop() if skip_connect_type == "concat": x = torch.cat((x, skip), dim=-1) x = maybe_skip_proj(x) elif skip_connect_type == "add": x = x + skip # attention and feedforward blocks x = attn(attn_norm(x), rope=rope, mask=mask) + x x = ff(ff_norm(x)) + x assert len(skips) == 0 x = self.norm_out(x)[:, 1:, :] # unpack t from x return self.proj_out(x) ================================================ FILE: src/f5_tts/model/cfm.py ================================================ """ ein notation: b - batch n - sequence nt - text sequence nw - raw wave length d - dimension """ # ruff: noqa: F722 F821 from __future__ import annotations from random import random from typing import Callable import torch import torch.nn.functional as F from torch import nn from torch.nn.utils.rnn import pad_sequence from torchdiffeq import odeint from f5_tts.model.modules import MelSpec from f5_tts.model.utils import ( default, exists, get_epss_timesteps, lens_to_mask, list_str_to_idx, list_str_to_tensor, mask_from_frac_lengths, ) class CFM(nn.Module): def __init__( self, transformer: nn.Module, sigma=0.0, odeint_kwargs: dict = dict( # atol = 1e-5, # rtol = 1e-5, method="euler" # 'midpoint' ), audio_drop_prob=0.3, cond_drop_prob=0.2, num_channels=None, mel_spec_module: nn.Module | None = None, mel_spec_kwargs: dict = dict(), frac_lengths_mask: tuple[float, float] = (0.7, 1.0), vocab_char_map: dict[str:int] | None = None, ): super().__init__() self.frac_lengths_mask = frac_lengths_mask # mel spec self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) num_channels = default(num_channels, self.mel_spec.n_mel_channels) self.num_channels = num_channels # classifier-free guidance self.audio_drop_prob = audio_drop_prob self.cond_drop_prob = cond_drop_prob # transformer self.transformer = transformer dim = transformer.dim self.dim = dim # conditional 
flow related self.sigma = sigma # sampling related self.odeint_kwargs = odeint_kwargs # vocab map for tokenization self.vocab_char_map = vocab_char_map @property def device(self): return next(self.parameters()).device @torch.no_grad() def sample( self, cond: float["b n d"] | float["b nw"], text: int["b nt"] | list[str], duration: int | int["b"], *, lens: int["b"] | None = None, steps=32, cfg_strength=1.0, sway_sampling_coef=None, seed: int | None = None, max_duration=65536, vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, use_epss=True, no_ref_audio=False, duplicate_test=False, t_inter=0.1, edit_mask=None, ): self.eval() # raw wave if cond.ndim == 2: cond = self.mel_spec(cond) cond = cond.permute(0, 2, 1) assert cond.shape[-1] == self.num_channels cond = cond.to(next(self.parameters()).dtype) batch, cond_seq_len, device = *cond.shape[:2], cond.device if not exists(lens): lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) # text if isinstance(text, list): if exists(self.vocab_char_map): text = list_str_to_idx(text, self.vocab_char_map).to(device) else: text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch # duration cond_mask = lens_to_mask(lens) if edit_mask is not None: cond_mask = cond_mask & edit_mask if isinstance(duration, int): duration = torch.full((batch,), duration, device=device, dtype=torch.long) duration = torch.maximum( torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration ) # duration at least text/audio prompt length plus one token, so something is generated duration = duration.clamp(max=max_duration) max_duration = duration.amax() # duplicate test corner for inner time step oberservation if duplicate_test: test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) if no_ref_audio: cond = torch.zeros_like(cond) cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), 
value=False) cond_mask = cond_mask.unsqueeze(-1) step_cond = torch.where( cond_mask, cond, torch.zeros_like(cond) ) # allow direct control (cut cond audio) with lens passed in if batch > 1: mask = lens_to_mask(duration) else: # save memory and speed up, as single inference need no mask currently mask = None # neural ode def fn(t, x): # at each step, conditioning is fixed # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # predict flow (cond) if cfg_strength < 1e-5: pred = self.transformer( x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True, ) return pred # predict flow (cond and uncond), for classifier-free guidance pred_cfg = self.transformer( x=x, cond=step_cond, text=text, time=t, mask=mask, cfg_infer=True, cache=True, ) pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) return pred + (pred - null_pred) * cfg_strength # noise input # to make sure batch inference result is same with different batch size, and for sure single inference # still some difference maybe due to convolutional layers y0 = [] for dur in duration: if exists(seed): torch.manual_seed(seed) y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) y0 = pad_sequence(y0, padding_value=0, batch_first=True) t_start = 0 # duplicate test corner for inner time step oberservation if duplicate_test: t_start = t_inter y0 = (1 - t_start) * y0 + t_start * test_cond steps = int(steps * (1 - t_start)) if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) else: t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) trajectory = odeint(fn, y0, t, **self.odeint_kwargs) self.transformer.clear_cache() sampled = trajectory[-1] out = sampled out = torch.where(cond_mask, cond, out) if 
exists(vocoder): out = out.permute(0, 2, 1) out = vocoder(out) return out, trajectory def forward( self, inp: float["b n d"] | float["b nw"], # mel or raw wave text: int["b nt"] | list[str], *, lens: int["b"] | None = None, noise_scheduler: str | None = None, ): # handle raw wave if inp.ndim == 2: inp = self.mel_spec(inp) inp = inp.permute(0, 2, 1) assert inp.shape[-1] == self.num_channels batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma # handle text as string if isinstance(text, list): if exists(self.vocab_char_map): text = list_str_to_idx(text, self.vocab_char_map).to(device) else: text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch # lens and mask if not exists(lens): # if lens not acquired by trainer from collate_fn lens = torch.full((batch,), seq_len, device=device) mask = lens_to_mask(lens, length=seq_len) # get a random span to mask out for training conditionally frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) if exists(mask): rand_span_mask &= mask # mel is x1 x1 = inp # x0 is gaussian noise x0 = torch.randn_like(x1) # time step time = torch.rand((batch,), dtype=dtype, device=self.device) # TODO. 
noise_scheduler # sample xt (φ_t(x) in the paper) t = time.unsqueeze(-1).unsqueeze(-1) φ = (1 - t) * x0 + t * x1 flow = x1 - x0 # only predict what is within the random mask span for infilling cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) # transformer and cfg training with a drop rate drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper if random() < self.cond_drop_prob: # p_uncond in voicebox paper drop_audio_cond = True drop_text = True else: drop_text = False # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold pred = self.transformer( x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask ) # flow matching loss loss = F.mse_loss(pred, flow, reduction="none") loss = loss[rand_span_mask] return loss.mean(), cond, pred ================================================ FILE: src/f5_tts/model/dataset.py ================================================ import json from importlib.resources import files import torch import torch.nn.functional as F import torchaudio from datasets import Dataset as Dataset_ from datasets import load_from_disk from torch import nn from torch.utils.data import Dataset, Sampler from tqdm import tqdm from f5_tts.model.modules import MelSpec from f5_tts.model.utils import default class HFDataset(Dataset): def __init__( self, hf_dataset: Dataset, target_sample_rate=24_000, n_mel_channels=100, hop_length=256, n_fft=1024, win_length=1024, mel_spec_type="vocos", ): self.data = hf_dataset self.target_sample_rate = target_sample_rate self.hop_length = hop_length self.mel_spectrogram = MelSpec( n_fft=n_fft, hop_length=hop_length, win_length=win_length, n_mel_channels=n_mel_channels, target_sample_rate=target_sample_rate, mel_spec_type=mel_spec_type, ) def get_frame_len(self, index): row = self.data[index] audio = row["audio"]["array"] sample_rate = row["audio"]["sampling_rate"] return audio.shape[-1] / 
sample_rate * self.target_sample_rate / self.hop_length def __len__(self): return len(self.data) def __getitem__(self, index): row = self.data[index] audio = row["audio"]["array"] # logger.info(f"Audio shape: {audio.shape}") sample_rate = row["audio"]["sampling_rate"] duration = audio.shape[-1] / sample_rate if duration > 30 or duration < 0.3: return self.__getitem__((index + 1) % len(self.data)) audio_tensor = torch.from_numpy(audio).float() if sample_rate != self.target_sample_rate: resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) audio_tensor = resampler(audio_tensor) audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t') mel_spec = self.mel_spectrogram(audio_tensor) mel_spec = mel_spec.squeeze(0) # '1 d t -> d t' text = row["text"] return dict( mel_spec=mel_spec, text=text, ) class CustomDataset(Dataset): def __init__( self, custom_dataset: Dataset, durations=None, target_sample_rate=24_000, hop_length=256, n_mel_channels=100, n_fft=1024, win_length=1024, mel_spec_type="vocos", preprocessed_mel=False, mel_spec_module: nn.Module | None = None, ): self.data = custom_dataset self.durations = durations self.target_sample_rate = target_sample_rate self.hop_length = hop_length self.n_fft = n_fft self.win_length = win_length self.mel_spec_type = mel_spec_type self.preprocessed_mel = preprocessed_mel if not preprocessed_mel: self.mel_spectrogram = default( mel_spec_module, MelSpec( n_fft=n_fft, hop_length=hop_length, win_length=win_length, n_mel_channels=n_mel_channels, target_sample_rate=target_sample_rate, mel_spec_type=mel_spec_type, ), ) def get_frame_len(self, index): if ( self.durations is not None ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM return self.durations[index] * self.target_sample_rate / self.hop_length return self.data[index]["duration"] * self.target_sample_rate / self.hop_length def __len__(self): return len(self.data) def __getitem__(self, index): while True: row = 
self.data[index] audio_path = row["audio_path"] text = row["text"] duration = row["duration"] # filter by given length if 0.3 <= duration <= 30: break # valid index = (index + 1) % len(self.data) if self.preprocessed_mel: mel_spec = torch.tensor(row["mel_spec"]) else: audio, source_sample_rate = torchaudio.load(audio_path) # make sure mono input if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) # resample if necessary if source_sample_rate != self.target_sample_rate: resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) audio = resampler(audio) # to mel spectrogram mel_spec = self.mel_spectrogram(audio) mel_spec = mel_spec.squeeze(0) # '1 d t -> d t' return { "mel_spec": mel_spec, "text": text, } # Dynamic Batch Sampler class DynamicBatchSampler(Sampler[list[int]]): """Extension of Sampler that will do the following: 1. Change the batch size (essentially number of sequences) in a batch to ensure that the total number of frames are less than a certain threshold. 2. Make sure the padding efficiency in the batch is high. 3. Shuffle batches each epoch while maintaining reproducibility. """ def __init__( self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False ): self.sampler = sampler self.frames_threshold = frames_threshold self.max_samples = max_samples self.random_seed = random_seed self.epoch = 0 indices, batches = [], [] data_source = self.sampler.data_source for idx in tqdm( self.sampler, desc="Sorting with sampler... 
if slow, check whether dataset is provided with duration" ): indices.append((idx, data_source.get_frame_len(idx))) indices.sort(key=lambda elem: elem[1]) batch = [] batch_frames = 0 for idx, frame_len in tqdm( indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu" ): if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples): batch.append(idx) batch_frames += frame_len else: if len(batch) > 0: batches.append(batch) if frame_len <= self.frames_threshold: batch = [idx] batch_frames = frame_len else: batch = [] batch_frames = 0 if not drop_residual and len(batch) > 0: batches.append(batch) del indices self.batches = batches # Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting self.drop_last = True def set_epoch(self, epoch: int) -> None: """Sets the epoch for this sampler.""" self.epoch = epoch def __iter__(self): # Use both random_seed and epoch for deterministic but different shuffling per epoch if self.random_seed is not None: g = torch.Generator() g.manual_seed(self.random_seed + self.epoch) # Use PyTorch's random permutation for better reproducibility across PyTorch versions indices = torch.randperm(len(self.batches), generator=g).tolist() batches = [self.batches[i] for i in indices] else: batches = self.batches return iter(batches) def __len__(self): return len(self.batches) # Load dataset def load_dataset( dataset_name: str, tokenizer: str = "pinyin", dataset_type: str = "CustomDataset", audio_type: str = "raw", mel_spec_module: nn.Module | None = None, mel_spec_kwargs: dict = dict(), ) -> CustomDataset | HFDataset: """ dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer """ print("Loading dataset ...") if dataset_type == "CustomDataset": rel_data_path = 
str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}")) if audio_type == "raw": try: train_dataset = load_from_disk(f"{rel_data_path}/raw") except: # noqa: E722 train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow") preprocessed_mel = False elif audio_type == "mel": train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow") preprocessed_mel = True with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f: data_dict = json.load(f) durations = data_dict["duration"] train_dataset = CustomDataset( train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, mel_spec_module=mel_spec_module, **mel_spec_kwargs, ) elif dataset_type == "CustomDatasetPath": try: train_dataset = load_from_disk(f"{dataset_name}/raw") except: # noqa: E722 train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow") with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f: data_dict = json.load(f) durations = data_dict["duration"] train_dataset = CustomDataset( train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs ) elif dataset_type == "HFDataset": print( "Should manually modify the path of huggingface dataset to your need.\n" + "May also the corresponding script cuz different dataset may have different format." 
) pre, post = dataset_name.split("_") train_dataset = HFDataset( load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))), ) return train_dataset # collation def collate_fn(batch): mel_specs = [item["mel_spec"].squeeze(0) for item in batch] mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs]) max_mel_length = mel_lengths.amax() padded_mel_specs = [] for spec in mel_specs: padding = (0, max_mel_length - spec.size(-1)) padded_spec = F.pad(spec, padding, value=0) padded_mel_specs.append(padded_spec) mel_specs = torch.stack(padded_mel_specs) text = [item["text"] for item in batch] text_lengths = torch.LongTensor([len(item) for item in text]) return dict( mel=mel_specs, mel_lengths=mel_lengths, # records for padding mask text=text, text_lengths=text_lengths, ) ================================================ FILE: src/f5_tts/model/modules.py ================================================ """ ein notation: b - batch n - sequence nt - text sequence nw - raw wave length d - dimension """ # ruff: noqa: F722 F821 from __future__ import annotations import math import warnings from typing import Optional import torch import torch.nn.functional as F import torchaudio from librosa.filters import mel as librosa_mel_fn from torch import nn from x_transformers.x_transformers import apply_rotary_pos_emb from f5_tts.model.utils import is_package_available # raw wav to mel spec mel_basis_cache = {} hann_window_cache = {} def get_bigvgan_mel_spectrogram( waveform, n_fft=1024, n_mel_channels=100, target_sample_rate=24000, hop_length=256, win_length=1024, fmin=0, fmax=None, center=False, ): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main device = waveform.device key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" if key not in mel_basis_cache: mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax) 
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()? hann_window_cache[key] = torch.hann_window(win_length).to(device) mel_basis = mel_basis_cache[key] hann_window = hann_window_cache[key] padding = (n_fft - hop_length) // 2 waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) spec = torch.stft( waveform, n_fft, hop_length=hop_length, win_length=win_length, window=hann_window, center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True, ) spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) mel_spec = torch.matmul(mel_basis, spec) mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5)) return mel_spec def get_vocos_mel_spectrogram( waveform, n_fft=1024, n_mel_channels=100, target_sample_rate=24000, hop_length=256, win_length=1024, ): mel_stft = torchaudio.transforms.MelSpectrogram( sample_rate=target_sample_rate, n_fft=n_fft, win_length=win_length, hop_length=hop_length, n_mels=n_mel_channels, power=1, center=True, normalized=False, norm=None, ).to(waveform.device) if len(waveform.shape) == 3: waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' assert len(waveform.shape) == 2 mel = mel_stft(waveform) mel = mel.clamp(min=1e-5).log() return mel class MelSpec(nn.Module): def __init__( self, n_fft=1024, hop_length=256, win_length=1024, n_mel_channels=100, target_sample_rate=24_000, mel_spec_type="vocos", ): super().__init__() assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan") self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length self.n_mel_channels = n_mel_channels self.target_sample_rate = target_sample_rate if mel_spec_type == "vocos": self.extractor = get_vocos_mel_spectrogram elif mel_spec_type == "bigvgan": self.extractor = get_bigvgan_mel_spectrogram self.register_buffer("dummy", torch.tensor(0), persistent=False) def forward(self, wav): if 
self.dummy.device != wav.device: self.to(wav.device) mel = self.extractor( waveform=wav, n_fft=self.n_fft, n_mel_channels=self.n_mel_channels, target_sample_rate=self.target_sample_rate, hop_length=self.hop_length, win_length=self.win_length, ) return mel # sinusoidal position embedding class SinusPositionEmbedding(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x, scale=1000): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb # convolutional position embedding class ConvPositionEmbedding(nn.Module): def __init__(self, dim, kernel_size=31, groups=16): super().__init__() assert kernel_size % 2 != 0 self.conv1d = nn.Sequential( nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), nn.Mish(), nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), nn.Mish(), ) self.layer_need_mask_idx = [i for i, layer in enumerate(self.conv1d) if isinstance(layer, nn.Conv1d)] def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): if mask is not None: mask = mask.unsqueeze(1) # [B 1 N] x = x.permute(0, 2, 1) # [B D N] if mask is not None: x = x.masked_fill(~mask, 0.0) for i, block in enumerate(self.conv1d): x = block(x) if mask is not None and i in self.layer_need_mask_idx: x = x.masked_fill(~mask, 0.0) x = x.permute(0, 2, 1) # [B N D] return x # rotary positional embedding related def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ # 
https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py theta *= theta_rescale_factor ** (dim / (dim - 2)) freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore freqs_cos = torch.cos(freqs) # real part freqs_sin = torch.sin(freqs) # imaginary part return torch.cat([freqs_cos, freqs_sin], dim=-1) def get_pos_embed_indices(start, length, max_pos, scale=1.0): # length = length if isinstance(length, int) else length.max() scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar pos = ( start.unsqueeze(1) + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() ) # avoid extra long error. pos = torch.where(pos < max_pos, pos, max_pos - 1) return pos # Global Response Normalization layer (Instance Normalization ?) class GRN(nn.Module): def __init__(self, dim): super().__init__() self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) self.beta = nn.Parameter(torch.zeros(1, 1, dim)) def forward(self, x): Gx = torch.norm(x, p=2, dim=1, keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * Nx) + self.beta + x # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 class ConvNeXtV2Block(nn.Module): def __init__( self, dim: int, intermediate_dim: int, dilation: int = 1, ): super().__init__() padding = (dilation * (7 - 1)) // 2 self.dwconv = nn.Conv1d( dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation ) # depthwise conv self.norm = nn.LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.grn = GRN(intermediate_dim) self.pwconv2 = 
nn.Linear(intermediate_dim, dim) def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x x = x.transpose(1, 2) # b n d -> b d n x = self.dwconv(x) x = x.transpose(1, 2) # b d n -> b n d x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.grn(x) x = self.pwconv2(x) return residual + x # RMSNorm class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) self.native_rms_norm = float(torch.__version__[:3]) >= 2.4 def forward(self, x): if self.native_rms_norm: if self.weight.dtype in [torch.float16, torch.bfloat16]: x = x.to(self.weight.dtype) x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps) else: variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.eps) if self.weight.dtype in [torch.float16, torch.bfloat16]: x = x.to(self.weight.dtype) x = x * self.weight return x # AdaLayerNorm # return with modulated x for attn input, and params for later mlp modulation class AdaLayerNorm(nn.Module): def __init__(self, dim): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(dim, dim * 6) self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) def forward(self, x, emb=None): emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa, shift_mlp, scale_mlp, gate_mlp # AdaLayerNorm for final layer # return only with modulated x for attn input, cuz no more mlp modulation class AdaLayerNorm_Final(nn.Module): def __init__(self, dim): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(dim, dim * 2) self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) def forward(self, x, emb): emb = self.linear(self.silu(emb)) scale, shift = torch.chunk(emb, 2, dim=1) x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, 
None, :] return x # FeedForward class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim activation = nn.GELU(approximate=approximate) project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.ff(x) # Attention with possible joint part # modified from diffusers/src/diffusers/models/attention_processor.py class Attention(nn.Module): def __init__( self, processor: JointAttnProcessor | AttnProcessor, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, context_dim: Optional[int] = None, # if not None -> joint attention context_pre_only: bool = False, qk_norm: Optional[str] = None, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.processor = processor self.dim = dim self.heads = heads self.inner_dim = dim_head * heads self.dropout = dropout self.context_dim = context_dim self.context_pre_only = context_pre_only self.to_q = nn.Linear(dim, self.inner_dim) self.to_k = nn.Linear(dim, self.inner_dim) self.to_v = nn.Linear(dim, self.inner_dim) if qk_norm is None: self.q_norm = None self.k_norm = None elif qk_norm == "rms_norm": self.q_norm = RMSNorm(dim_head, eps=1e-6) self.k_norm = RMSNorm(dim_head, eps=1e-6) else: raise ValueError(f"Unimplemented qk_norm: {qk_norm}") if self.context_dim is not None: self.to_q_c = nn.Linear(context_dim, self.inner_dim) self.to_k_c = nn.Linear(context_dim, self.inner_dim) self.to_v_c = nn.Linear(context_dim, self.inner_dim) if qk_norm is None: self.c_q_norm = None self.c_k_norm = None elif qk_norm == "rms_norm": self.c_q_norm = RMSNorm(dim_head, eps=1e-6) self.c_k_norm = RMSNorm(dim_head, eps=1e-6) self.to_out = 
nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, dim)) self.to_out.append(nn.Dropout(dropout)) if self.context_dim is not None and not self.context_pre_only: self.to_out_c = nn.Linear(self.inner_dim, context_dim) def forward( self, x: float["b n d"], # noised input x c: float["b n d"] = None, # context c mask: bool["b n"] | None = None, rope=None, # rotary position embedding for x c_rope=None, # rotary position embedding for c c_mask: bool["b nt"] | None = None, # text mask ) -> torch.Tensor: if c is not None: return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask) else: return self.processor(self, x, mask=mask, rope=rope) # Attention processor if is_package_available("flash_attn"): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import pad_input, unpad_input class AttnProcessor: def __init__( self, pe_attn_head: int | None = None, # number of attention head to apply rope, None for all attn_backend: str = "torch", # "torch" or "flash_attn" attn_mask_enabled: bool = True, ): if attn_backend == "flash_attn": assert is_package_available("flash_attn"), "Please install flash-attn first." if attn_backend == "torch" and attn_mask_enabled: warnings.warn( "attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. 
" "Please switch attn_backend to 'flash_attn'.", UserWarning, ) self.pe_attn_head = pe_attn_head self.attn_backend = attn_backend self.attn_mask_enabled = attn_mask_enabled def __call__( self, attn: Attention, x: float["b n d"], # noised input x mask: bool["b n"] | None = None, rope=None, # rotary position embedding ) -> torch.FloatTensor: batch_size = x.shape[0] # `sample` projections query = attn.to_q(x) key = attn.to_k(x) value = attn.to_v(x) # attention inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # qk norm if attn.q_norm is not None: query = attn.q_norm(query) if attn.k_norm is not None: key = attn.k_norm(key) # apply rotary position embedding if rope is not None: freqs, xpos_scale = rope q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) if self.pe_attn_head is not None: pn = self.pe_attn_head query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale) key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale) else: query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) if self.attn_backend == "torch": # mask. e.g. 
inference got a batch with different target durations, mask out the padding if self.attn_mask_enabled and mask is not None: attn_mask = mask attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) else: attn_mask = None x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) elif self.attn_backend == "flash_attn": query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d] key = key.transpose(1, 2) value = value.transpose(1, 2) if self.attn_mask_enabled and mask is not None: query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask) key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask) value, _, _, _, _ = unpad_input(value, mask) x = flash_attn_varlen_func( query, key, value, q_cu_seqlens, k_cu_seqlens, q_max_seqlen_in_batch, k_max_seqlen_in_batch, ) x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch) x = x.reshape(batch_size, -1, attn.heads * head_dim) else: x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False) x = x.reshape(batch_size, -1, attn.heads * head_dim) x = x.to(query.dtype) # linear proj x = attn.to_out[0](x) # dropout x = attn.to_out[1](x) if mask is not None: mask = mask.unsqueeze(-1) x = x.masked_fill(~mask, 0.0) return x # Joint Attention processor for MM-DiT # modified from diffusers/src/diffusers/models/attention_processor.py class JointAttnProcessor: def __init__( self, attn_backend: str = "torch", # "torch" or "flash_attn" attn_mask_enabled: bool = True, ): if attn_backend == "flash_attn": assert is_package_available("flash_attn"), "Please install flash-attn first." if attn_backend == "torch" and attn_mask_enabled: warnings.warn( "attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. 
" "Please switch attn_backend to 'flash_attn'.", UserWarning, ) self.attn_backend = attn_backend self.attn_mask_enabled = attn_mask_enabled def __call__( self, attn: Attention, x: float["b n d"], # noised input x c: float["b nt d"] = None, # context c, here text mask: bool["b n"] | None = None, rope=None, # rotary position embedding for x c_rope=None, # rotary position embedding for c c_mask: bool["b nt"] | None = None, # text mask ) -> torch.FloatTensor: residual = x audio_mask = mask batch_size = c.shape[0] # `sample` projections query = attn.to_q(x) key = attn.to_k(x) value = attn.to_v(x) # `context` projections c_query = attn.to_q_c(c) c_key = attn.to_k_c(c) c_value = attn.to_v_c(c) # attention inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # qk norm if attn.q_norm is not None: query = attn.q_norm(query) if attn.k_norm is not None: key = attn.k_norm(key) if attn.c_q_norm is not None: c_query = attn.c_q_norm(c_query) if attn.c_k_norm is not None: c_key = attn.c_k_norm(c_key) # apply rope for context and noised input independently if rope is not None: freqs, xpos_scale = rope q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) if c_rope is not None: freqs, xpos_scale = c_rope q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) c_key = 
apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) # joint attention query = torch.cat([query, c_query], dim=2) key = torch.cat([key, c_key], dim=2) value = torch.cat([value, c_value], dim=2) # build combined mask for joint attention: audio mask + text mask if self.attn_mask_enabled and mask is not None: if c_mask is not None: mask = torch.cat([mask, c_mask], dim=1) else: mask = F.pad(mask, (0, c.shape[1]), value=True) if self.attn_backend == "torch": # mask. e.g. inference got a batch with different target durations, mask out the padding if self.attn_mask_enabled and mask is not None: attn_mask = mask attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) else: attn_mask = None x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) elif self.attn_backend == "flash_attn": query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d] key = key.transpose(1, 2) value = value.transpose(1, 2) if self.attn_mask_enabled and mask is not None: total_seq_len = query.shape[1] query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask) key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask) value, _, _, _, _ = unpad_input(value, mask) x = flash_attn_varlen_func( query, key, value, q_cu_seqlens, k_cu_seqlens, q_max_seqlen_in_batch, k_max_seqlen_in_batch, ) x = pad_input(x, indices, batch_size, total_seq_len) x = x.reshape(batch_size, -1, attn.heads * head_dim) else: x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False) x = x.reshape(batch_size, -1, attn.heads * head_dim) x = x.to(query.dtype) # Split the attention outputs. 
x, c = ( x[:, : residual.shape[1]], x[:, residual.shape[1] :], ) # linear proj x = attn.to_out[0](x) # dropout x = attn.to_out[1](x) if not attn.context_pre_only: c = attn.to_out_c(c) if audio_mask is not None: x = x.masked_fill(~audio_mask.unsqueeze(-1), 0.0) if c_mask is not None: c = c.masked_fill(~c_mask.unsqueeze(-1), 0.0) return x, c # DiT Block class DiTBlock(nn.Module): def __init__( self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None, attn_backend="torch", # "torch" or "flash_attn" attn_mask_enabled=True, ): super().__init__() self.attn_norm = AdaLayerNorm(dim) self.attn = Attention( processor=AttnProcessor( pe_attn_head=pe_attn_head, attn_backend=attn_backend, attn_mask_enabled=attn_mask_enabled, ), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, qk_norm=qk_norm, ) self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding # pre-norm & modulation for attention input norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) # attention attn_output = self.attn(x=norm, mask=mask, rope=rope) # process attention output for input x x = x + gate_msa.unsqueeze(1) * attn_output norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm) x = x + gate_mlp.unsqueeze(1) * ff_output return x # MMDiT Block https://arxiv.org/abs/2403.03206 class MMDiTBlock(nn.Module): r""" modified from diffusers/src/diffusers/models/attention.py notes. _c: context related. text, cond, etc. (left part in sd3 fig2.b) _x: noised input related. 
(right part) context_pre_only: last layer only do prenorm + modulation cuz no more ffn """ def __init__( self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None, attn_backend="torch", attn_mask_enabled=False, ): super().__init__() if context_dim is None: context_dim = dim self.context_pre_only = context_pre_only self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim) self.attn_norm_x = AdaLayerNorm(dim) self.attn = Attention( processor=JointAttnProcessor( attn_backend=attn_backend, attn_mask_enabled=attn_mask_enabled, ), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, context_dim=context_dim, context_pre_only=context_pre_only, qk_norm=qk_norm, ) if not context_pre_only: self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6) self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh") else: self.ff_norm_c = None self.ff_c = None self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") def forward( self, x, c, t, mask=None, rope=None, c_rope=None, c_mask=None ): # x: noised input, c: context, t: time embedding # pre-norm & modulation for attention input if self.context_pre_only: norm_c = self.attn_norm_c(c, t) else: norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t) norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t) # attention x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask) # process attention output for context c if self.context_pre_only: c = None else: # if not last layer c = c + c_gate_msa.unsqueeze(1) * c_attn_output norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] c_ff_output = self.ff_c(norm_c) c = c + c_gate_mlp.unsqueeze(1) * c_ff_output # 
process attention output for input x x = x + x_gate_msa.unsqueeze(1) * x_attn_output norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] x_ff_output = self.ff_x(norm_x) x = x + x_gate_mlp.unsqueeze(1) * x_ff_output return c, x # time step conditioning embedding class TimestepEmbedding(nn.Module): def __init__(self, dim, freq_embed_dim=256): super().__init__() self.time_embed = SinusPositionEmbedding(freq_embed_dim) self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) def forward(self, timestep: float["b"]): time_hidden = self.time_embed(timestep) time_hidden = time_hidden.to(timestep.dtype) time = self.time_mlp(time_hidden) # b d return time ================================================ FILE: src/f5_tts/model/trainer.py ================================================ from __future__ import annotations import gc import math import os import torch import torchaudio import wandb from accelerate import Accelerator from accelerate.utils import DistributedDataParallelKwargs from ema_pytorch import EMA from torch.optim import AdamW from torch.optim.lr_scheduler import LinearLR, SequentialLR from torch.utils.data import DataLoader, Dataset, SequentialSampler from tqdm import tqdm from f5_tts.model import CFM from f5_tts.model.dataset import DynamicBatchSampler, collate_fn from f5_tts.model.utils import default, exists # trainer class Trainer: def __init__( self, model: CFM, epochs, learning_rate, num_warmup_updates=20000, save_per_updates=1000, keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints checkpoint_path=None, batch_size_per_gpu=32, batch_size_type: str = "sample", max_samples=32, grad_accumulation_steps=1, max_grad_norm=1.0, noise_scheduler: str | None = None, duration_predictor: torch.nn.Module | None = None, logger: str | None = "wandb", # "wandb" | "tensorboard" | None wandb_project="test_f5-tts", wandb_run_name="test_run", 
wandb_resume_id: str = None, log_samples: bool = False, last_per_updates=None, accelerate_kwargs: dict = dict(), ema_kwargs: dict = dict(), bnb_optimizer: bool = False, mel_spec_type: str = "vocos", # "vocos" | "bigvgan" is_local_vocoder: bool = False, # use local path vocoder local_vocoder_path: str = "", # local vocoder path model_cfg_dict: dict = dict(), # training config ): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) if logger == "wandb" and not wandb.api.api_key: logger = None self.log_samples = log_samples self.accelerator = Accelerator( log_with=logger if logger == "wandb" else None, kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps=grad_accumulation_steps, **accelerate_kwargs, ) self.logger = logger if self.logger == "wandb": if exists(wandb_resume_id): init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}} else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} if not model_cfg_dict: model_cfg_dict = { "epochs": epochs, "learning_rate": learning_rate, "num_warmup_updates": num_warmup_updates, "batch_size_per_gpu": batch_size_per_gpu, "batch_size_type": batch_size_type, "max_samples": max_samples, "grad_accumulation_steps": grad_accumulation_steps, "max_grad_norm": max_grad_norm, "noise_scheduler": noise_scheduler, "bnb_optimizer": bnb_optimizer, } model_cfg_dict["gpus"] = self.accelerator.num_processes self.accelerator.init_trackers( project_name=wandb_project, init_kwargs=init_kwargs, config=model_cfg_dict, ) elif self.logger == "tensorboard": from torch.utils.tensorboard import SummaryWriter self.writer = None if self.accelerator.is_main_process: self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}") self.model = model if self.is_main: self.ema_model = EMA(model, include_online_model=False, **ema_kwargs) self.ema_model.to(self.accelerator.device) print(f"Using logger: {logger}") if grad_accumulation_steps > 1: print( "Gradient accumulation checkpointing 
with per_updates now, old logic per_steps used with before f992c4e" ) self.epochs = epochs self.num_warmup_updates = num_warmup_updates self.save_per_updates = save_per_updates self.keep_last_n_checkpoints = keep_last_n_checkpoints self.last_per_updates = default(last_per_updates, save_per_updates) self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts") self.batch_size_per_gpu = batch_size_per_gpu self.batch_size_type = batch_size_type self.max_samples = max_samples self.grad_accumulation_steps = grad_accumulation_steps self.max_grad_norm = max_grad_norm # mel vocoder config self.vocoder_name = mel_spec_type self.is_local_vocoder = is_local_vocoder self.local_vocoder_path = local_vocoder_path self.noise_scheduler = noise_scheduler self.duration_predictor = duration_predictor if bnb_optimizer: import bitsandbytes as bnb self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate) else: self.optimizer = AdamW(model.parameters(), lr=learning_rate, fused=True) self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) @property def is_main(self): return self.accelerator.is_main_process def save_checkpoint(self, update, last=False): self.accelerator.wait_for_everyone() if self.is_main: checkpoint = dict( model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), optimizer_state_dict=self.optimizer.state_dict(), ema_model_state_dict=self.ema_model.state_dict(), scheduler_state_dict=self.scheduler.state_dict(), update=update, ) if not os.path.exists(self.checkpoint_path): os.makedirs(self.checkpoint_path) if last: self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") print(f"Saved last checkpoint at update {update}") else: if self.keep_last_n_checkpoints == 0: return self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt") if self.keep_last_n_checkpoints > 0: # Updated logic to exclude pretrained model from rotation checkpoints = [ f for f in 
os.listdir(self.checkpoint_path) if f.startswith("model_") and not f.startswith("pretrained_") # Exclude pretrained models and f.endswith(".pt") and f != "model_last.pt" ] checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0])) while len(checkpoints) > self.keep_last_n_checkpoints: oldest_checkpoint = checkpoints.pop(0) os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint)) print(f"Removed old checkpoint: {oldest_checkpoint}") def load_checkpoint(self): if ( not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path)) ): return 0 self.accelerator.wait_for_everyone() if "model_last.pt" in os.listdir(self.checkpoint_path): latest_checkpoint = "model_last.pt" else: # Updated to consider pretrained models for loading but prioritize training checkpoints all_checkpoints = [ f for f in os.listdir(self.checkpoint_path) if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors")) ] # First try to find regular training checkpoints training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"] if training_checkpoints: latest_checkpoint = sorted( training_checkpoints, key=lambda x: int("".join(filter(str.isdigit, x))), )[-1] else: # If no training checkpoints, use pretrained model latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_")) if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint from safetensors.torch import load_file checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu") checkpoint = {"ema_model_state_dict": checkpoint} elif latest_checkpoint.endswith(".pt"): # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ checkpoint = torch.load( 
f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu" ) # patch for backward compatibility, 305e3ea for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]: if key in checkpoint["ema_model_state_dict"]: del checkpoint["ema_model_state_dict"][key] if self.is_main: self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"]) if "update" in checkpoint or "step" in checkpoint: # patch for backward compatibility, with before f992c4e if "step" in checkpoint: checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps if self.grad_accumulation_steps > 1 and self.is_main: print( "F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour." ) # patch for backward compatibility, 305e3ea for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]: if key in checkpoint["model_state_dict"]: del checkpoint["model_state_dict"][key] self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) if self.scheduler: self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) update = checkpoint["update"] else: checkpoint["model_state_dict"] = { k.replace("ema_model.", ""): v for k, v in checkpoint["ema_model_state_dict"].items() if k not in ["initted", "update", "step"] } self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) update = 0 del checkpoint gc.collect() return update def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None): if self.log_samples: from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef vocoder = load_vocoder( vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path ) 
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate log_samples_path = f"{self.checkpoint_path}/samples" os.makedirs(log_samples_path, exist_ok=True) if exists(resumable_with_seed): generator = torch.Generator() generator.manual_seed(resumable_with_seed) else: generator = None if self.batch_size_type == "sample": train_dataloader = DataLoader( train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True, batch_size=self.batch_size_per_gpu, shuffle=True, generator=generator, ) elif self.batch_size_type == "frame": self.accelerator.even_batches = False sampler = SequentialSampler(train_dataset) batch_sampler = DynamicBatchSampler( sampler, self.batch_size_per_gpu, max_samples=self.max_samples, random_seed=resumable_with_seed, # This enables reproducible shuffling drop_residual=False, ) train_dataloader = DataLoader( train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True, batch_sampler=batch_sampler, ) else: raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") # accelerator.prepare() dispatches batches to devices; # which means the length of dataloader calculated before, should consider the number of devices warmup_updates = ( self.num_warmup_updates * self.accelerator.num_processes ) # consider a fixed warmup steps while using accelerate multi-gpu ddp # otherwise by default with split_batches=False, warmup steps change with num_processes total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs decay_updates = total_updates - warmup_updates warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates) decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates) self.scheduler = SequentialLR( self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], 
milestones=[warmup_updates] ) train_dataloader, self.scheduler = self.accelerator.prepare( train_dataloader, self.scheduler ) # actual multi_gpu updates = single_gpu updates / gpu nums start_update = self.load_checkpoint() global_update = start_update if exists(resumable_with_seed): orig_epoch_step = len(train_dataloader) start_step = start_update * self.grad_accumulation_steps skipped_epoch = int(start_step // orig_epoch_step) skipped_batch = start_step % orig_epoch_step skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch) else: skipped_epoch = 0 for epoch in range(skipped_epoch, self.epochs): self.model.train() if exists(resumable_with_seed) and epoch == skipped_epoch: progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps) current_dataloader = skipped_dataloader else: progress_bar_initial = 0 current_dataloader = train_dataloader # Set epoch for the batch sampler if it exists if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"): train_dataloader.batch_sampler.set_epoch(epoch) progress_bar = tqdm( range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)), desc=f"Epoch {epoch + 1}/{self.epochs}", unit="update", disable=not self.accelerator.is_local_main_process, initial=progress_bar_initial, ) for batch in current_dataloader: with self.accelerator.accumulate(self.model): text_inputs = batch["text"] mel_spec = batch["mel"].permute(0, 2, 1) mel_lengths = batch["mel_lengths"] # TODO. 
add duration predictor training if self.duration_predictor is not None and self.accelerator.is_local_main_process: dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations")) self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update) loss, cond, pred = self.model( mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler ) self.accelerator.backward(loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() if self.accelerator.sync_gradients: if self.is_main: self.ema_model.update() global_update += 1 progress_bar.update(1) progress_bar.set_postfix(update=str(global_update), loss=loss.item()) if self.accelerator.is_local_main_process: self.accelerator.log( {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update ) if self.logger == "tensorboard" and self.accelerator.is_main_process: self.writer.add_scalar("loss", loss.item(), global_update) self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update) if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients: self.save_checkpoint(global_update, last=True) if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients: self.save_checkpoint(global_update) if self.log_samples and self.accelerator.is_local_main_process: ref_audio_len = mel_lengths[0] infer_text = [ text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0] ] with torch.inference_mode(), self.accelerator.autocast(): generated, _ = self.accelerator.unwrap_model(self.model).sample( cond=mel_spec[0][:ref_audio_len].unsqueeze(0), text=infer_text, duration=ref_audio_len * 2, steps=nfe_step, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, ) generated = generated.to(torch.float32) gen_mel_spec = generated[:, ref_audio_len:, 
:].permute(0, 2, 1).to(self.accelerator.device) ref_mel_spec = batch["mel"][0, :, :ref_audio_len].unsqueeze(0) if self.vocoder_name == "vocos": gen_audio = vocoder.decode(gen_mel_spec).cpu() ref_audio = vocoder.decode(ref_mel_spec).cpu() elif self.vocoder_name == "bigvgan": gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu() ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu() torchaudio.save( f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate ) torchaudio.save( f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate ) self.model.train() self.save_checkpoint(global_update, last=True) self.accelerator.end_training() ================================================ FILE: src/f5_tts/model/utils.py ================================================ # ruff: noqa: F722 F821 from __future__ import annotations import os import random from collections import defaultdict from importlib.resources import files import rjieba import torch from pypinyin import Style, lazy_pinyin from torch.nn.utils.rnn import pad_sequence # seed everything def seed_everything(seed=0): random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # helpers def exists(v): return v is not None def default(v, d): return v if exists(v) else d def is_package_available(package_name: str) -> bool: try: import importlib package_exists = importlib.util.find_spec(package_name) is not None return package_exists except Exception: return False # tensor helpers def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: if not exists(length): length = t.amax() seq = torch.arange(length, device=t.device) return seq[None, :] < t[:, None] def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): max_seq_len = seq_len.max().item() seq = 
torch.arange(max_seq_len, device=start.device).long() start_mask = seq[None, :] >= start[:, None] end_mask = seq[None, :] < end[:, None] return start_mask & end_mask def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): lengths = (frac_lengths * seq_len).long() max_start = seq_len - lengths rand = torch.rand_like(frac_lengths) start = (max_start * rand).long().clamp(min=0) end = start + lengths return mask_from_start_end_indices(seq_len, start, end) def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: if not exists(mask): return t.mean(dim=1) t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) num = t.sum(dim=1) den = mask.float().sum(dim=1) return num / den.clamp(min=1.0) # simple utf-8 tokenizer, since paper went character based def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) return text # char tokenizer, based on custom dataset's extracted .txt file def list_str_to_idx( text: list[str] | list[list[str]], vocab_char_map: dict[str, int], # {char: idx} padding_value=-1, ) -> int["b nt"]: list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) return text # Get tokenizer def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): """ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file - "char" for char-wise tokenizer, need .txt vocab_file - "byte" for utf-8 tokenizer - "custom" if you're directly passing in a path to the vocab.txt you want to use vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols - if use "char", derived from unfiltered character & symbol counts of custom dataset - if use 
"byte", set to 256 (unicode byte range) """ if tokenizer in ["pinyin", "char"]: tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt") with open(tokenizer_path, "r", encoding="utf-8") as f: vocab_char_map = {} for i, char in enumerate(f): vocab_char_map[char[:-1]] = i vocab_size = len(vocab_char_map) assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" elif tokenizer == "byte": vocab_char_map = None vocab_size = 256 elif tokenizer == "custom": with open(dataset_name, "r", encoding="utf-8") as f: vocab_char_map = {} for i, char in enumerate(f): vocab_char_map[char[:-1]] = i vocab_size = len(vocab_char_map) return vocab_char_map, vocab_size # convert char to pinyin def convert_char_to_pinyin(text_list, polyphone=True): final_text_list = [] custom_trans = str.maketrans( {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} ) # add custom trans here, to address oov def is_chinese(c): return ( "\u3100" <= c <= "\u9fff" # common chinese characters ) for text in text_list: char_list = [] text = text.translate(custom_trans) for seg in rjieba.cut(text): seg_byte_len = len(bytes(seg, "UTF-8")) if seg_byte_len == len(seg): # if pure alphabets and symbols if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": char_list.append(" ") char_list.extend(seg) elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) for i, c in enumerate(seg): if is_chinese(c): char_list.append(" ") char_list.append(seg_[i]) else: # if mixed characters, alphabets and symbols for c in seg: if ord(c) < 256: char_list.extend(c) elif is_chinese(c): char_list.append(" ") char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) else: char_list.append(c) final_text_list.append(char_list) return final_text_list # filter func for dirty data with many repetitions def repetition_found(text, 
length=2, tolerance=10): pattern_count = defaultdict(int) for i in range(len(text) - length + 1): pattern = text[i : i + length] pattern_count[pattern] += 1 for pattern, count in pattern_count.items(): if count > tolerance: return True return False # get the empirically pruned step for sampling def get_epss_timesteps(n, device, dtype): dt = 1 / 32 predefined_timesteps = { 5: [0, 2, 4, 8, 16, 32], 6: [0, 2, 4, 6, 8, 16, 32], 7: [0, 2, 4, 6, 8, 16, 24, 32], 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], } t = predefined_timesteps.get(n, []) if not t: return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) return dt * torch.tensor(t, device=device, dtype=dtype) ================================================ FILE: src/f5_tts/runtime/triton_trtllm/.gitignore ================================================ # runtime/triton_trtllm related model.cache model_repo/ ================================================ FILE: src/f5_tts/runtime/triton_trtllm/Dockerfile.server ================================================ FROM nvcr.io/nvidia/tritonserver:24.12-py3 RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 rjieba pypinyin librosa vocos WORKDIR /workspace ================================================ FILE: src/f5_tts/runtime/triton_trtllm/README.md ================================================ ## Triton Inference Serving Best Practice for F5-TTS ### Setup #### Option 1: Quick Start ```sh # Directly launch the service using docker compose MODEL=F5TTS_v1_Base docker compose up ``` #### Option 2: Build from scratch ```sh # Build the docker image docker build . 
-f Dockerfile.server -t soar97/triton-f5-tts:24.12 # Create Docker Container your_mount_dir=/mnt:/mnt docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12 ``` ### Build TensorRT-LLM Engines and Launch Server Inside the docker container, follow the official TensorRT-LLM guide to build the F5-TTS TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper) for a reference example. ```sh # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small bash run.sh 0 4 F5TTS_v1_Base ``` > [!NOTE] > If you use a custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`. > Remember to use the matching model version (`F5TTS_v1_*` for v1, `F5TTS_*` for v0). > > If you use a checkpoint with a different structure, see `scripts/convert_checkpoint.py` and modify it if necessary. > [!IMPORTANT] > If you trained or fine-tuned with fp32, add the `--dtype float32` flag when converting the checkpoint in `run.sh` phase 1. ### HTTP Client ```sh python3 client_http.py ``` ### Benchmarking #### Using Client-Server Mode ```sh # bash run.sh 5 5 F5TTS_v1_Base num_task=2 python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts ``` #### Using Offline TRT-LLM Mode ```sh # bash run.sh 7 7 F5TTS_v1_Base batch_size=1 split_name=wenetspeech4tts backend_type=trt log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type} rm -r $log_dir torchrun --nproc_per_node=1 \ benchmark.py --output-dir $log_dir \ --batch-size $batch_size \ --enable-warmup \ --split-name $split_name \ --model-path $ckpt_file \ --vocab-file $vocab_file \ --vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \ --backend-type $backend_type \ --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1 ``` ### Benchmark Results Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode | |---------------------|----------------|-------------|--------|-----------------| | F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server | | F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM | | F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch | ### Credits 1. [Yuekai Zhang](https://github.com/yuekaizhang) 2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm) ================================================ FILE: src/f5_tts/runtime/triton_trtllm/benchmark.py ================================================ # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) # 2025 (authors: Yuekai Zhang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py """ Example Usage torchrun --nproc_per_node=1 \ benchmark.py --output-dir $log_dir \ --batch-size $batch_size \ --enable-warmup \ --split-name $split_name \ --model-path $CKPT_DIR/$model/model_1200000.pt \ --vocab-file $CKPT_DIR/$model/vocab.txt \ --vocoder-trt-engine-path $vocoder_trt_engine_path \ --backend-type $backend_type \ --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1 """ import argparse import importlib import json import os import sys import time import datasets import tensorrt as trt import torch import torch.distributed as dist import torch.nn.functional as F import torchaudio from datasets import load_dataset from huggingface_hub import hf_hub_download from tensorrt_llm._utils import trt_dtype_to_torch from tensorrt_llm.logger import logger from tensorrt_llm.runtime.session import Session, TensorInfo from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm from vocos import Vocos sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/") from f5_tts.eval.utils_eval import padded_mel_batch from f5_tts.model.modules import get_vocos_mel_spectrogram from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS torch.manual_seed(0) def get_args(): parser = argparse.ArgumentParser(description="extract speech code") parser.add_argument( "--split-name", type=str, default="wenetspeech4tts", choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], help="huggingface dataset split name", ) parser.add_argument("--output-dir", required=True, type=str, help="dir to save result") parser.add_argument( "--vocab-file", required=True, type=str, help="vocab file", ) parser.add_argument( "--model-path", required=True, type=str, help="model path, to load text embedding", ) parser.add_argument( "--tllm-model-dir", required=True, 
type=str, help="tllm model dir", ) parser.add_argument( "--batch-size", required=True, type=int, help="batch size (per-device) for inference", ) parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader") parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader") parser.add_argument( "--vocoder", default="vocos", type=str, help="vocoder name", ) parser.add_argument( "--vocoder-trt-engine-path", default=None, type=str, help="vocoder trt engine path", ) parser.add_argument("--enable-warmup", action="store_true") parser.add_argument("--remove-input-padding", action="store_true") parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance") parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type") args = parser.parse_args() return args def data_collator(batch, vocab_char_map, device="cuda", use_perf=False): if use_perf: torch.cuda.nvtx.range_push("data_collator") target_sample_rate = 24000 target_rms = 0.1 ( ids, ref_rms_list, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list, ) = ( [], [], [], [], [], [], ) for i, item in enumerate(batch): item_id, prompt_text, target_text = ( item["id"], item["prompt_text"], item["target_text"], ) ids.append(item_id) reference_target_texts_list.append(prompt_text + target_text) ref_audio_org, ref_sr = ( item["prompt_audio"]["array"], item["prompt_audio"]["sampling_rate"], ) ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) ref_rms_list.append(ref_rms) if ref_rms < target_rms: ref_audio_org = ref_audio_org * target_rms / ref_rms if ref_sr != target_sample_rate: resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) ref_audio = resampler(ref_audio_org) else: ref_audio = ref_audio_org if use_perf: torch.cuda.nvtx.range_push(f"mel_spectrogram {i}") 
ref_audio = ref_audio.to("cuda") ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0) if use_perf: torch.cuda.nvtx.range_pop() ref_mel_len = ref_mel.shape[-1] assert ref_mel.shape[0] == 100 ref_mel_list.append(ref_mel) ref_mel_len_list.append(ref_mel_len) estimated_reference_target_mel_len.append( int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8")))) ) ref_mel_batch = padded_mel_batch(ref_mel_list) ref_mel_len_batch = torch.LongTensor(ref_mel_len_list) pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True) text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map) if use_perf: torch.cuda.nvtx.range_pop() return { "ids": ids, "ref_rms_list": ref_rms_list, "ref_mel_batch": ref_mel_batch, "ref_mel_len_batch": ref_mel_len_batch, "text_pad_sequence": text_pad_sequence, "estimated_reference_target_mel_len": estimated_reference_target_mel_len, } def init_distributed(): world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) rank = int(os.environ.get("RANK", 0)) print( "Inference on multiple gpus, this gpu {}".format(local_rank) + ", rank {}, world_size {}".format(rank, world_size) ) torch.cuda.set_device(local_rank) # Initialize process group with explicit device IDs dist.init_process_group( "nccl", ) return world_size, local_rank, rank def load_vocoder( vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None ): if vocoder_name == "vocos": if vocoder_trt_engine_path is not None: vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path) else: # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) if is_local: print(f"Load vocos from local path {local_path}") config_path = f"{local_path}/config.yaml" model_path = f"{local_path}/pytorch_model.bin" else: print("Download Vocos from huggingface charactr/vocos-mel-24khz") repo_id = "charactr/vocos-mel-24khz" config_path = 
hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") vocoder = Vocos.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) from vocos.feature_extractors import EncodecFeatures if isinstance(vocoder.feature_extractor, EncodecFeatures): encodec_parameters = { "feature_extractor.encodec." + key: value for key, value in vocoder.feature_extractor.encodec.state_dict().items() } state_dict.update(encodec_parameters) vocoder.load_state_dict(state_dict) vocoder = vocoder.eval().to(device) elif vocoder_name == "bigvgan": raise NotImplementedError("BigVGAN is not implemented yet") return vocoder class VocosTensorRT: def __init__(self, engine_path="./vocos_vocoder.plan", stream=None): TRT_LOGGER = trt.Logger(trt.Logger.WARNING) trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="") logger.info(f"Loading vocoder engine from {engine_path}") self.engine_path = engine_path with open(engine_path, "rb") as f: engine_buffer = f.read() self.session = Session.from_serialized_engine(engine_buffer) self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream def decode(self, mels): mels = mels.contiguous() inputs = {"mel": mels} output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)]) outputs = { t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info } ok = self.session.run(inputs, outputs, self.stream) assert ok, "Runtime execution failed for vae session" samples = outputs["waveform"] return samples def main(): args = get_args() os.makedirs(args.output_dir, exist_ok=True) assert torch.cuda.is_available() world_size, local_rank, rank = init_distributed() device = torch.device(f"cuda:{local_rank}") vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom") tllm_model_dir = args.tllm_model_dir 
with open(os.path.join(tllm_model_dir, "config.json")) as f: tllm_model_config = json.load(f) if args.backend_type == "trt": model = F5TTS( tllm_model_config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size, ) elif args.backend_type == "pytorch": from f5_tts.infer.utils_infer import load_model from f5_tts.model import DiT pretrained_config = tllm_model_config["pretrained_config"] pt_model_config = dict( dim=pretrained_config["hidden_size"], depth=pretrained_config["num_hidden_layers"], heads=pretrained_config["num_attention_heads"], ff_mult=pretrained_config["ff_mult"], text_dim=pretrained_config["text_dim"], text_mask_padding=pretrained_config["text_mask_padding"], conv_layers=pretrained_config["conv_layers"], pe_attn_head=pretrained_config["pe_attn_head"], # attn_backend="flash_attn", # attn_mask_enabled=True, ) model = load_model(DiT, pt_model_config, args.model_path) vocoder = load_vocoder( vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path ) dataset = load_dataset( "yuekai/seed_tts", split=args.split_name, trust_remote_code=True, ) def add_estimated_duration(example): prompt_audio_len = example["prompt_audio"]["array"].shape[0] scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"]) estimated_duration = prompt_audio_len * scale_factor example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"] return example dataset = dataset.map(add_estimated_duration) dataset = dataset.sort("estimated_duration", reverse=True) if args.use_perf: # dataset_list = [dataset.select(range(1)) for i in range(16)] # seq_len 1000 dataset_list_short = [dataset.select([24]) for i in range(8)] # seq_len 719 # dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002 # dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long) dataset = datasets.concatenate_datasets(dataset_list_short) if world_size > 
1: sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) else: # This would disable shuffling sampler = None dataloader = DataLoader( dataset, batch_size=args.batch_size, sampler=sampler, shuffle=False, num_workers=args.num_workers, prefetch_factor=args.prefetch, collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf), ) total_steps = len(dataset) if args.enable_warmup: for batch in dataloader: ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device) text_pad_seq = batch["text_pad_sequence"].to(device) total_mel_lens = batch["estimated_reference_target_mel_len"] cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0)) if args.backend_type == "trt": _ = model.sample( text_pad_seq, cond_pad_seq, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding, ) elif args.backend_type == "pytorch": total_mel_lens = torch.tensor(total_mel_lens, device=device) with torch.inference_mode(): generated, _ = model.sample( cond=ref_mels, text=text_pad_seq, duration=total_mel_lens, steps=32, cfg_strength=2.0, sway_sampling_coef=-1, ) if rank == 0: progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") decoding_time = 0 vocoder_time = 0 total_duration = 0 if args.use_perf: torch.cuda.cudart().cudaProfilerStart() total_decoding_time = time.time() for batch in dataloader: if args.use_perf: torch.cuda.nvtx.range_push("data sample") ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device) text_pad_seq = batch["text_pad_sequence"].to(device) total_mel_lens = batch["estimated_reference_target_mel_len"] cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0)) if args.use_perf: torch.cuda.nvtx.range_pop() if args.backend_type == "trt": generated, cost_time = model.sample( text_pad_seq, cond_pad_seq, ref_mel_lens, total_mel_lens, 
remove_input_padding=args.remove_input_padding, use_perf=args.use_perf, ) elif args.backend_type == "pytorch": total_mel_lens = torch.tensor(total_mel_lens, device=device) with torch.inference_mode(): start_time = time.time() generated, _ = model.sample( cond=ref_mels, text=text_pad_seq, duration=total_mel_lens, lens=ref_mel_lens, steps=32, cfg_strength=2.0, sway_sampling_coef=-1, ) cost_time = time.time() - start_time decoding_time += cost_time vocoder_start_time = time.time() target_rms = 0.1 target_sample_rate = 24000 for i, gen in enumerate(generated): gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) if args.vocoder == "vocos": if args.use_perf: torch.cuda.nvtx.range_push("vocoder decode") generated_wave = vocoder.decode(gen_mel_spec).cpu() if args.use_perf: torch.cuda.nvtx.range_pop() else: generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() if batch["ref_rms_list"][i] < target_rms: generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms utt = batch["ids"][i] torchaudio.save( f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate, ) total_duration += generated_wave.shape[1] / target_sample_rate vocoder_time += time.time() - vocoder_start_time if rank == 0: progress_bar.update(world_size * len(batch["ids"])) total_decoding_time = time.time() - total_decoding_time if rank == 0: progress_bar.close() rtf = total_decoding_time / total_duration s = f"RTF: {rtf:.4f}\n" s += f"total_duration: {total_duration:.3f} seconds\n" s += f"({total_duration / 3600:.2f} hours)\n" s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n" s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n" s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n" s += f"batch size: {args.batch_size}\n" print(s) with open(f"{args.output_dir}/rtf.txt", "w") as f: f.write(s) dist.barrier() 
dist.destroy_process_group() if __name__ == "__main__": main() ================================================ FILE: src/f5_tts/runtime/triton_trtllm/client_grpc.py ================================================ #!/usr/bin/env python3 # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) # 2023 Nvidia (authors: Yuekai Zhang) # 2023 Recurrent.ai (authors: Songtao Shi) # See LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This script supports to load dataset from huggingface and sends it to the server for decoding, in parallel. Usage: num_task=2 # For offline F5-TTS python3 client_grpc.py \ --server-addr localhost \ --model-name f5_tts \ --num-tasks $num_task \ --huggingface-dataset yuekai/seed_tts \ --split-name test_zh \ --log-dir ./log_concurrent_tasks_${num_task} """ import argparse import asyncio import json import os import time import types from pathlib import Path import numpy as np import soundfile as sf import tritonclient import tritonclient.grpc.aio as grpcclient from tritonclient.utils import np_to_triton_dtype def write_triton_stats(stats, summary_file): with open(summary_file, "w") as summary_f: model_stats = stats["model_stats"] # write a note, the log is from triton_client.get_inference_statistics(), to better human readability summary_f.write( "The log is parsing from triton_client.get_inference_statistics(), to better human readability. 
\n" ) summary_f.write("To learn more about the log, please refer to: \n") summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n") summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n") summary_f.write( "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n" ) summary_f.write( "However, there is a trade-off between the increased queue time and the increased batch size. \n" ) summary_f.write( "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n" ) summary_f.write( "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n" ) for model_state in model_stats: if "last_inference" not in model_state: continue summary_f.write(f"model name is {model_state['name']} \n") model_inference_stats = model_state["inference_stats"] total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9 total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9 total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9 summary_f.write( f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa ) model_batch_stats = model_state["batch_stats"] for batch in model_batch_stats: batch_size = int(batch["batch_size"]) compute_input = batch["compute_input"] compute_output = batch["compute_output"] compute_infer = batch["compute_infer"] batch_count = int(compute_infer["count"]) assert compute_infer["count"] == compute_output["count"] == compute_input["count"] compute_infer_time_ms = int(compute_infer["ns"]) / 
1e6 compute_input_time_ms = int(compute_input["ns"]) / 1e6 compute_output_time_ms = int(compute_output["ns"]) / 1e6 summary_f.write( f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa ) summary_f.write( f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa ) summary_f.write( f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa ) def get_args(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--server-addr", type=str, default="localhost", help="Address of the server", ) parser.add_argument( "--server-port", type=int, default=8001, help="Grpc port of the triton server, default is 8001", ) parser.add_argument( "--reference-audio", type=str, default=None, help="Path to a single audio file. 
It can't be specified at the same time with --manifest-dir", ) parser.add_argument( "--reference-text", type=str, default="", help="", ) parser.add_argument( "--target-text", type=str, default="", help="", ) parser.add_argument( "--huggingface-dataset", type=str, default="yuekai/seed_tts", help="dataset name in huggingface dataset hub", ) parser.add_argument( "--split-name", type=str, default="wenetspeech4tts", choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], help="dataset split name, default is 'test'", ) parser.add_argument( "--manifest-path", type=str, default=None, help="Path to the manifest dir which includes wav.scp trans.txt files.", ) parser.add_argument( "--model-name", type=str, default="f5_tts", help="triton model_repo module name to request", ) parser.add_argument( "--num-tasks", type=int, default=1, help="Number of concurrent tasks for sending", ) parser.add_argument( "--log-interval", type=int, default=5, help="Controls how frequently we print the log.", ) parser.add_argument( "--compute-wer", action="store_true", default=False, help="""True to compute WER. 
""", ) parser.add_argument( "--log-dir", type=str, required=False, default="./tests/client_grpc", help="log directory", ) parser.add_argument( "--batch-size", type=int, default=1, help="Inference batch_size per request for offline mode.", ) return parser.parse_args() def load_audio(wav_path, target_sample_rate=24000): assert target_sample_rate == 24000, "hard coding in server" if isinstance(wav_path, dict): waveform = wav_path["array"] sample_rate = wav_path["sampling_rate"] else: waveform, sample_rate = sf.read(wav_path) if sample_rate != target_sample_rate: from scipy.signal import resample waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate))) return waveform, target_sample_rate async def send( manifest_item_list: list, name: str, triton_client: tritonclient.grpc.aio.InferenceServerClient, protocol_client: types.ModuleType, log_interval: int, model_name: str, padding_duration: int = None, audio_save_dir: str = "./", save_sample_rate: int = 24000, ): total_duration = 0.0 latency_data = [] task_id = int(name[5:]) print(f"manifest_item_list: {manifest_item_list}") for i, item in enumerate(manifest_item_list): if i % log_interval == 0: print(f"{name}: {i}/{len(manifest_item_list)}") waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000) duration = len(waveform) / sample_rate lengths = np.array([[len(waveform)]], dtype=np.int32) reference_text, target_text = item["reference_text"], item["target_text"] estimated_target_duration = duration / len(reference_text) * len(target_text) if padding_duration: # padding to nearset 10 seconds samples = np.zeros( ( 1, padding_duration * sample_rate * ((int(estimated_target_duration + duration) // padding_duration) + 1), ), dtype=np.float32, ) samples[0, : len(waveform)] = waveform else: samples = waveform samples = samples.reshape(1, -1).astype(np.float32) inputs = [ protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)), 
protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)), protocol_client.InferInput("reference_text", [1, 1], "BYTES"), protocol_client.InferInput("target_text", [1, 1], "BYTES"), ] inputs[0].set_data_from_numpy(samples) inputs[1].set_data_from_numpy(lengths) input_data_numpy = np.array([reference_text], dtype=object) input_data_numpy = input_data_numpy.reshape((1, 1)) inputs[2].set_data_from_numpy(input_data_numpy) input_data_numpy = np.array([target_text], dtype=object) input_data_numpy = input_data_numpy.reshape((1, 1)) inputs[3].set_data_from_numpy(input_data_numpy) outputs = [protocol_client.InferRequestedOutput("waveform")] sequence_id = 100000000 + i + task_id * 10 start = time.time() response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs) audio = response.as_numpy("waveform").reshape(-1) end = time.time() - start audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") sf.write(audio_save_path, audio, save_sample_rate, "PCM_16") actual_duration = len(audio) / save_sample_rate latency_data.append((end, actual_duration)) total_duration += actual_duration return total_duration, latency_data def load_manifests(manifest_path): with open(manifest_path, "r") as f: manifest_list = [] for line in f: assert len(line.strip().split("|")) == 4 utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") utt = Path(utt).stem # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav") if not os.path.isabs(prompt_wav): prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav) manifest_list.append( { "audio_filepath": prompt_wav, "reference_text": prompt_text, "target_text": gt_text, "target_audio_path": utt, } ) return manifest_list def split_data(data, k): n = len(data) if n < k: print(f"Warning: the length of the input list ({n}) is less than k ({k}). 
Setting k to {n}.") k = n quotient = n // k remainder = n % k result = [] start = 0 for i in range(k): if i < remainder: end = start + quotient + 1 else: end = start + quotient result.append(data[start:end]) start = end return result async def main(): args = get_args() url = f"{args.server_addr}:{args.server_port}" triton_client = grpcclient.InferenceServerClient(url=url, verbose=False) protocol_client = grpcclient if args.reference_audio: args.num_tasks = 1 args.log_interval = 1 manifest_item_list = [ { "reference_text": args.reference_text, "target_text": args.target_text, "audio_filepath": args.reference_audio, "target_audio_path": "test", } ] elif args.huggingface_dataset: import datasets dataset = datasets.load_dataset( args.huggingface_dataset, split=args.split_name, trust_remote_code=True, ) manifest_item_list = [] for i in range(len(dataset)): manifest_item_list.append( { "audio_filepath": dataset[i]["prompt_audio"], "reference_text": dataset[i]["prompt_text"], "target_audio_path": dataset[i]["id"], "target_text": dataset[i]["target_text"], } ) else: manifest_item_list = load_manifests(args.manifest_path) args.num_tasks = min(args.num_tasks, len(manifest_item_list)) manifest_item_list = split_data(manifest_item_list, args.num_tasks) os.makedirs(args.log_dir, exist_ok=True) tasks = [] start_time = time.time() for i in range(args.num_tasks): task = asyncio.create_task( send( manifest_item_list[i], name=f"task-{i}", triton_client=triton_client, protocol_client=protocol_client, log_interval=args.log_interval, model_name=args.model_name, audio_save_dir=args.log_dir, padding_duration=1, save_sample_rate=24000, ) ) tasks.append(task) ans_list = await asyncio.gather(*tasks) end_time = time.time() elapsed = end_time - start_time total_duration = 0.0 latency_data = [] for ans in ans_list: total_duration += ans[0] latency_data += ans[1] rtf = elapsed / total_duration s = f"RTF: {rtf:.4f}\n" s += f"total_duration: {total_duration:.3f} seconds\n" s += f"({total_duration 
/ 3600:.2f} hours)\n" s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n" latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0 s += f"latency_variance: {latency_variance:.2f}\n" s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n" s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n" s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n" s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n" s += f"average_latency_ms: {latency_ms:.2f}\n" print(s) if args.manifest_path: name = Path(args.manifest_path).stem elif args.split_name: name = args.split_name with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: f.write(s) stats = await triton_client.get_inference_statistics(model_name="", as_json=True) write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True) with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: json.dump(metadata, f, indent=4) if __name__ == "__main__": asyncio.run(main()) ================================================ FILE: src/f5_tts/runtime/triton_trtllm/client_http.py ================================================ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. 
# * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of NVIDIA CORPORATION nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse import os import numpy as np import requests import soundfile as sf def get_args(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--server-url", type=str, default="localhost:8000", help="Address of the server", ) parser.add_argument( "--reference-audio", type=str, default="../../infer/examples/basic/basic_ref_en.wav", help="Path to a single audio file. It can't be specified at the same time with --manifest-dir", ) parser.add_argument( "--reference-text", type=str, default="Some call me nature, others call me mother nature.", help="", ) parser.add_argument( "--target-text", type=str, default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. 
But always remember, I am mighty and enduring.", help="", ) parser.add_argument( "--model-name", type=str, default="f5_tts", help="triton model_repo module name to request", ) parser.add_argument( "--output-audio", type=str, default="tests/client_http.wav", help="Path to save the output audio", ) return parser.parse_args() def prepare_request( waveform, reference_text, target_text, sample_rate=24000, audio_save_dir: str = "./", ): assert len(waveform.shape) == 1, "waveform should be 1D" lengths = np.array([[len(waveform)]], dtype=np.int32) waveform = waveform.reshape(1, -1).astype(np.float32) data = { "inputs": [ {"name": "reference_wav", "shape": waveform.shape, "datatype": "FP32", "data": waveform.tolist()}, { "name": "reference_wav_len", "shape": lengths.shape, "datatype": "INT32", "data": lengths.tolist(), }, {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]}, {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]}, ] } return data def load_audio(wav_path, target_sample_rate=24000): assert target_sample_rate == 24000, "hard coding in server" if isinstance(wav_path, dict): waveform = wav_path["array"] sample_rate = wav_path["sampling_rate"] else: waveform, sample_rate = sf.read(wav_path) if sample_rate != target_sample_rate: from scipy.signal import resample waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate))) return waveform, target_sample_rate if __name__ == "__main__": args = get_args() server_url = args.server_url if not server_url.startswith(("http://", "https://")): server_url = f"http://{server_url}" url = f"{server_url}/v2/models/{args.model_name}/infer" waveform, sr = load_audio(args.reference_audio) assert sr == 24000, "sample rate hardcoded in server" waveform = np.array(waveform, dtype=np.float32) data = prepare_request(waveform, args.reference_text, args.target_text) rsp = requests.post( url, headers={"Content-Type": "application/json"}, json=data, 
verify=False, params={"request_id": "0"} ) result = rsp.json() audio = result["outputs"][0]["data"] audio = np.array(audio, dtype=np.float32) os.makedirs(os.path.dirname(args.output_audio), exist_ok=True) sf.write(args.output_audio, audio, 24000, "PCM_16") ================================================ FILE: src/f5_tts/runtime/triton_trtllm/docker-compose.yml ================================================ services: tts: image: soar97/triton-f5-tts:24.12 shm_size: '1gb' ports: - "8000:8000" - "8001:8001" - "8002:8002" environment: - PYTHONIOENCODING=utf-8 - MODEL_ID=${MODEL_ID} deploy: resources: reservations: devices: - driver: nvidia device_ids: ['0'] capabilities: [gpu] command: > /bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL" ================================================ FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ================================================ import math import os import time from functools import wraps from typing import List, Optional import tensorrt as trt import tensorrt_llm import torch import torch.nn as nn import torch.nn.functional as F from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch from tensorrt_llm.logger import logger from tensorrt_llm.runtime.session import Session from torch.nn.utils.rnn import pad_sequence def remove_tensor_padding(input_tensor, input_tensor_lengths=None): # Audio tensor case: batch, seq_len, feature_len # position_ids case: batch, seq_len assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor" # Initialize a list to collect valid sequences valid_sequences = [] for i in range(input_tensor.shape[0]): valid_length = input_tensor_lengths[i] valid_sequences.append(input_tensor[i, :valid_length]) # Concatenate all valid sequences along the batch dimension output_tensor = 
torch.cat(valid_sequences, dim=0).contiguous() return output_tensor class TextEmbedding(nn.Module): def __init__( self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2, precompute_max_pos=4096 ): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token self.mask_padding = mask_padding self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False) self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) def forward(self, text, seq_len, drop_text=False): text = text + 1 text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens text = F.pad(text, (0, seq_len - text.shape[1]), value=0) if self.mask_padding: text_mask = text == 0 if drop_text: # cfg for text text = torch.zeros_like(text) text = self.text_embed(text) # b n -> b n d text = text + self.freqs_cis[:seq_len, :] if self.mask_padding: text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) for block in self.text_blocks: text = block(text) text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) else: text = self.text_blocks(text) return text class GRN(nn.Module): def __init__(self, dim): super().__init__() self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) self.beta = nn.Parameter(torch.zeros(1, 1, dim)) def forward(self, x): Gx = torch.norm(x, p=2, dim=1, keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * Nx) + self.beta + x class ConvNeXtV2Block(nn.Module): def __init__( self, dim: int, intermediate_dim: int, dilation: int = 1, ): super().__init__() padding = (dilation * (7 - 1)) // 2 self.dwconv = nn.Conv1d( dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation ) # depthwise conv self.norm = nn.LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented 
with linear layers self.act = nn.GELU() self.grn = GRN(intermediate_dim) self.pwconv2 = nn.Linear(intermediate_dim, dim) def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x x = x.transpose(1, 2) # b n d -> b d n x = self.dwconv(x) x = x.transpose(1, 2) # b d n -> b n d x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.grn(x) x = self.pwconv2(x) return residual + x def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py theta *= theta_rescale_factor ** (dim / (dim - 2)) freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore freqs_cos = torch.cos(freqs) # real part freqs_sin = torch.sin(freqs) # imaginary part return torch.cat([freqs_cos, freqs_sin], dim=-1) def get_text_embed_dict(ckpt_path, use_ema=True): ckpt_type = ckpt_path.split(".")[-1] if ckpt_type == "safetensors": from safetensors.torch import load_file checkpoint = load_file(ckpt_path) else: checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) if use_ema: if ckpt_type == "safetensors": checkpoint = {"ema_model_state_dict": checkpoint} checkpoint["model_state_dict"] = { k.replace("ema_model.", ""): v for k, v in checkpoint["ema_model_state_dict"].items() if k not in ["initted", "step"] } else: if ckpt_type == "safetensors": checkpoint = {"model_state_dict": checkpoint} model_params = checkpoint["model_state_dict"] text_embed_dict = {} for key in model_params.keys(): # transformer.text_embed.text_embed.weight -> text_embed.weight 
if "text_embed" in key: text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key] return text_embed_dict class F5TTS(object): def __init__( self, config, debug_mode=True, stream: Optional[torch.cuda.Stream] = None, tllm_model_dir: Optional[str] = None, model_path: Optional[str] = None, vocab_size: Optional[int] = None, ): self.dtype = config["pretrained_config"]["dtype"] rank = tensorrt_llm.mpi_rank() world_size = config["pretrained_config"]["mapping"]["world_size"] cp_size = config["pretrained_config"]["mapping"]["cp_size"] tp_size = config["pretrained_config"]["mapping"]["tp_size"] pp_size = config["pretrained_config"]["mapping"]["pp_size"] assert pp_size == 1 self.mapping = tensorrt_llm.Mapping( world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1 ) local_rank = rank % self.mapping.gpus_per_node self.device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(self.device) self.stream = stream if self.stream is None: self.stream = torch.cuda.Stream(self.device) torch.cuda.set_stream(self.stream) engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine") logger.info(f"Loading engine from {engine_file}") with open(engine_file, "rb") as f: engine_buffer = f.read() assert engine_buffer is not None self.session = Session.from_serialized_engine(engine_buffer) self.debug_mode = debug_mode self.inputs = {} self.outputs = {} self.buffer_allocated = False expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"] found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)] if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names): logger.error( f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}" ) logger.error( f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}" ) 
logger.error(f"Expected tensor names: {expected_tensor_names}") logger.error(f"Found tensor names: {found_tensor_names}") raise RuntimeError("Tensor names in engine are not the same as expected.") if self.debug_mode: self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names)) self.max_mel_len = 4096 self.text_embedding = TextEmbedding( text_num_embeds=vocab_size, text_dim=config["pretrained_config"]["text_dim"], mask_padding=config["pretrained_config"]["text_mask_padding"], conv_layers=config["pretrained_config"]["conv_layers"], precompute_max_pos=self.max_mel_len, ).to(self.device) self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True) self.n_mel_channels = config["pretrained_config"]["mel_dim"] self.head_dim = config["pretrained_config"]["dim_head"] self.base_rescale_factor = 1.0 self.interpolation_factor = 1.0 base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)) freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0) self.rope_cos = self.freqs.cos().half() self.rope_sin = self.freqs.sin().half() self.nfe_steps = 32 epss = { 5: [0, 2, 4, 8, 16, 32], 6: [0, 2, 4, 6, 8, 16, 32], 7: [0, 2, 4, 6, 8, 16, 24, 32], 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], } t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32) time_step = 1 - torch.cos(torch.pi * t / 2) delta_t = torch.diff(time_step) freq_embed_dim = 256 # Warning: hard coding 256 here time_expand = torch.zeros((1, self.nfe_steps, freq_embed_dim), dtype=torch.float32) half_dim = freq_embed_dim // 2 emb_factor = math.log(10000) / (half_dim - 1) emb_factor = 1000.0 * 
torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor) for i in range(self.nfe_steps): emb = time_step[i] * emb_factor time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1) self.time_expand = time_expand.to(self.device) self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device) def _tensor_dtype(self, name): # return torch dtype given tensor name for convenience dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name)) return dtype def _setup(self, batch_size, seq_len): for i in range(self.session.engine.num_io_tensors): name = self.session.engine.get_tensor_name(i) if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT: shape = list(self.session.engine.get_tensor_shape(name)) shape[0] = batch_size shape[1] = seq_len self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device) self.buffer_allocated = True def cuda_stream_guard(func): """Sync external stream and set current stream to the one bound to the session. 
Reset on exit.""" @wraps(func) def wrapper(self, *args, **kwargs): external_stream = torch.cuda.current_stream() if external_stream != self.stream: external_stream.synchronize() torch.cuda.set_stream(self.stream) ret = func(self, *args, **kwargs) if external_stream != self.stream: self.stream.synchronize() torch.cuda.set_stream(external_stream) return ret return wrapper @cuda_stream_guard def forward( self, noise: torch.Tensor, cond: torch.Tensor, time_expand: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor, input_lengths: torch.Tensor, delta_t: torch.Tensor, use_perf: bool = False, ): if use_perf: torch.cuda.nvtx.range_push("flow matching") cfg_strength = 2.0 batch_size = noise.shape[0] half_batch = batch_size // 2 noise_half = noise[:half_batch] # Store the initial half of noise input_type = str_dtype_to_torch(self.dtype) # Keep a copy of the initial tensors cond = cond.to(input_type) rope_cos = rope_cos.to(input_type) rope_sin = rope_sin.to(input_type) input_lengths = input_lengths.to(str_dtype_to_torch("int32")) # Instead of iteratively updating noise within a single model context, # we'll do a single forward pass for each iteration with fresh context setup for i in range(self.nfe_steps): # Re-setup the buffers for clean execution self._setup(batch_size, noise.shape[1]) if not self.buffer_allocated: raise RuntimeError("Buffer not allocated, please call setup first!") # Re-create combined noises for this iteration current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type) # Get time step for this iteration current_time = time_expand[:, i].to(input_type) # Create fresh input dictionary for this iteration current_inputs = { "noise": current_noise, "cond": cond, "time": current_time, "rope_cos": rope_cos, "rope_sin": rope_sin, "input_lengths": input_lengths, } # Update inputs and set shapes self.inputs.clear() # Clear previous inputs self.inputs.update(**current_inputs) self.session.set_shapes(self.inputs) if use_perf: 
torch.cuda.nvtx.range_push(f"execute {i}") ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream) assert ok, "Failed to execute model" # self.session.context.execute_async_v3(self.stream.cuda_stream) if use_perf: torch.cuda.nvtx.range_pop() # Process results t_scale = delta_t[i].unsqueeze(0).to(input_type) # Extract predictions pred_cond = self.outputs["denoised"][:half_batch] pred_uncond = self.outputs["denoised"][half_batch:] # Apply classifier-free guidance with safeguards guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength # Calculate update for noise noise_half = noise_half + guidance * t_scale if use_perf: torch.cuda.nvtx.range_pop() return noise_half def sample( self, text_pad_sequence: torch.Tensor, cond_pad_sequence: torch.Tensor, ref_mel_len_batch: torch.Tensor, estimated_reference_target_mel_len: List[int], remove_input_padding: bool = False, use_perf: bool = False, ): if use_perf: torch.cuda.nvtx.range_push("text embedding") batch = text_pad_sequence.shape[0] max_seq_len = cond_pad_sequence.shape[1] # get text_embed one by one to avoid misalignment text_and_drop_embedding_list = [] for i in range(batch): text_embedding_i = self.text_embedding( text_pad_sequence[i].unsqueeze(0).to(self.device), estimated_reference_target_mel_len[i], drop_text=False, ) text_embedding_drop_i = self.text_embedding( text_pad_sequence[i].unsqueeze(0).to(self.device), estimated_reference_target_mel_len[i], drop_text=True, ) text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]]) # pad separately computed text_embed to form batch with max_seq_len text_and_drop_embedding = pad_sequence( text_and_drop_embedding_list, batch_first=True, padding_value=0, ) text_embedding = text_and_drop_embedding[0::2] text_embedding_drop = text_and_drop_embedding[1::2] noise = torch.randn_like(cond_pad_sequence).to(self.device) rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1) rope_sin = self.rope_sin[:, 
:max_seq_len, :].float().repeat(batch, 1, 1) cat_mel_text = torch.cat( ( cond_pad_sequence, text_embedding, ), dim=-1, ) cat_mel_text_drop = torch.cat( ( torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device), text_embedding_drop, ), dim=-1, ) time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous() # Convert estimated_reference_target_mel_len to tensor input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32) # combine above along the batch dimension inputs = { "noise": torch.cat((noise, noise), dim=0).contiguous(), "cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(), "time_expand": time_expand, "rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(), "rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(), "input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(), "delta_t": self.delta_t, } if use_perf and remove_input_padding: torch.cuda.nvtx.range_push("remove input padding") if remove_input_padding: max_seq_len = inputs["cond"].shape[1] inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"]) inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"]) # for time_expand, convert from B,D to B,T,D by repeat inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1) inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"]) inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"]) inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"]) if use_perf and remove_input_padding: torch.cuda.nvtx.range_pop() for key in inputs: inputs[key] = inputs[key].to(self.device) if use_perf: torch.cuda.nvtx.range_pop() start_time = time.time() denoised = self.forward(**inputs, use_perf=use_perf) cost_time = time.time() - start_time if use_perf and remove_input_padding: 
torch.cuda.nvtx.range_push("remove input padding output") if remove_input_padding: denoised_list = [] start_idx = 0 for i in range(batch): denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]]) start_idx += inputs["input_lengths"][i] if use_perf and remove_input_padding: torch.cuda.nvtx.range_pop() return denoised_list, cost_time return denoised, cost_time ================================================ FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py ================================================ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of NVIDIA CORPORATION nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR # PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json import os import rjieba import torch import torchaudio import triton_python_backend_utils as pb_utils from f5_tts_trtllm import F5TTS from pypinyin import Style, lazy_pinyin from torch.nn.utils.rnn import pad_sequence from torch.utils.dlpack import from_dlpack, to_dlpack def get_tokenizer(vocab_file_path: str): """ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file - "char" for char-wise tokenizer, need .txt vocab_file - "byte" for utf-8 tokenizer - "custom" if you're directly passing in a path to the vocab.txt you want to use vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols - if use "char", derived from unfiltered character & symbol counts of custom dataset - if use "byte", set to 256 (unicode byte range) """ with open(vocab_file_path, "r", encoding="utf-8") as f: vocab_char_map = {} for i, char in enumerate(f): vocab_char_map[char[:-1]] = i vocab_size = len(vocab_char_map) return vocab_char_map, vocab_size def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): final_reference_target_texts_list = [] custom_trans = str.maketrans( {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} ) # add custom trans here, to address oov def is_chinese(c): return "\u3100" <= c <= "\u9fff" # common chinese characters for text in reference_target_texts_list: char_list = [] text = text.translate(custom_trans) for seg in rjieba.cut(text): 
seg_byte_len = len(bytes(seg, "UTF-8")) if seg_byte_len == len(seg): # if pure alphabets and symbols if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": char_list.append(" ") char_list.extend(seg) elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) for i, c in enumerate(seg): if is_chinese(c): char_list.append(" ") char_list.append(seg_[i]) else: # if mixed characters, alphabets and symbols for c in seg: if ord(c) < 256: char_list.extend(c) elif is_chinese(c): char_list.append(" ") char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) else: char_list.append(c) final_reference_target_texts_list.append(char_list) return final_reference_target_texts_list def list_str_to_idx( text: list[str] | list[list[str]], vocab_char_map: dict[str, int], # {char: idx} padding_value=-1, ): # noqa: F722 list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) return text class TritonPythonModel: def initialize(self, args): self.use_perf = True self.device = torch.device("cuda") self.target_audio_sample_rate = 24000 self.target_rms = 0.1 # least rms when inference, normalize to if lower self.n_fft = 1024 self.win_length = 1024 self.hop_length = 256 self.n_mel_channels = 100 self.max_mel_len = 4096 parameters = json.loads(args["model_config"])["parameters"] for key, value in parameters.items(): parameters[key] = value["string_value"] self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"]) self.reference_sample_rate = int(parameters["reference_audio_sample_rate"]) self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate) self.tllm_model_dir = parameters["tllm_model_dir"] config_file = os.path.join(self.tllm_model_dir, "config.json") with open(config_file) as f: 
config = json.load(f) self.model = F5TTS( config, debug_mode=False, tllm_model_dir=self.tllm_model_dir, model_path=parameters["model_path"], vocab_size=self.vocab_size, ) self.vocoder = parameters["vocoder"] assert self.vocoder in ["vocos", "bigvgan"] if self.vocoder == "vocos": self.mel_stft = torchaudio.transforms.MelSpectrogram( sample_rate=self.target_audio_sample_rate, n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length, n_mels=self.n_mel_channels, power=1, center=True, normalized=False, norm=None, ).to(self.device) self.compute_mel_fn = self.get_vocos_mel_spectrogram elif self.vocoder == "bigvgan": self.compute_mel_fn = self.get_bigvgan_mel_spectrogram def get_vocos_mel_spectrogram(self, waveform): mel = self.mel_stft(waveform) mel = mel.clamp(min=1e-5).log() return mel.transpose(1, 2) def forward_vocoder(self, mel): mel = mel.to(torch.float32).contiguous().cpu() input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)) inference_request = pb_utils.InferenceRequest( model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0] ) inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) else: waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform") waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() return waveform def execute(self, requests): ( reference_text_list, target_text_list, reference_target_texts_list, estimated_reference_target_mel_len, reference_mel_len, reference_rms_list, ) = [], [], [], [], [], [] mel_features_list = [] if self.use_perf: torch.cuda.nvtx.range_push("preprocess") for request in requests: wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav") wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = 
reference_text[0][0].decode("utf-8") reference_text_list.append(reference_text) target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode("utf-8") target_text_list.append(target_text) text = reference_text + target_text reference_target_texts_list.append(text) wav = from_dlpack(wav_tensor.to_dlpack()) wav_len = from_dlpack(wav_lens.to_dlpack()) wav_len = wav_len.squeeze() assert wav.shape[0] == 1, "Only support batch size 1 for now." wav = wav[:, :wav_len] ref_rms = torch.sqrt(torch.mean(torch.square(wav))) if ref_rms < self.target_rms: wav = wav * self.target_rms / ref_rms reference_rms_list.append(ref_rms) if self.reference_sample_rate != self.target_audio_sample_rate: wav = self.resampler(wav) wav = wav.to(self.device) if self.use_perf: torch.cuda.nvtx.range_push("compute_mel") mel_features = self.compute_mel_fn(wav) if self.use_perf: torch.cuda.nvtx.range_pop() mel_features_list.append(mel_features) reference_mel_len.append(mel_features.shape[1]) estimated_reference_target_mel_len.append( int( mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8"))) ) ) max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len) batch = len(requests) mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device) for i, mel in enumerate(mel_features_list): mel_features[i, : mel.shape[1], :] = mel reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device) pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True) text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map) if self.use_perf: torch.cuda.nvtx.range_pop() denoised, cost_time = self.model.sample( text_pad_sequence, mel_features, reference_mel_len_tensor, estimated_reference_target_mel_len, remove_input_padding=False, use_perf=self.use_perf, ) if self.use_perf: torch.cuda.nvtx.range_push("vocoder") 
responses = [] for i in range(batch): ref_mel_len = reference_mel_len[i] estimated_mel_len = estimated_reference_target_mel_len[i] denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2) audio = self.forward_vocoder(denoised_one_item) if reference_rms_list[i] < self.target_rms: audio = audio * reference_rms_list[i] / self.target_rms audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio]) responses.append(inference_response) if self.use_perf: torch.cuda.nvtx.range_pop() return responses ================================================ FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt ================================================ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Triton configuration for the "f5_tts" Python-backend model (model.py in version dir 1/).
name: "f5_tts"
backend: "python"
# Up to 4 requests per batch; model.py's execute() loops over every request it receives.
max_batch_size: 4
dynamic_batching {
    # Wait at most 1000 us (1 ms) to collect requests into a batch.
    max_queue_delay_microseconds: 1000
}
# model.py reads every value below as a string via value["string_value"].
# NOTE(review): the ${...} placeholders look like template variables, presumably filled in
# by scripts/fill_template.py before deployment -- confirm.
parameters [
  {
   # Token-per-line vocabulary loaded by get_tokenizer() in model.py.
   key: "vocab_file"
   value: { string_value: "${vocab}"}
  },
  {
   # Passed to the F5TTS runner as model_path.
   key: "model_path",
   value: {string_value:"${model}"}
  },
  {
   # TensorRT-LLM engine directory; model.py reads config.json from it.
   key: "tllm_model_dir",
   value: {string_value:"${trtllm}"}
  },
  {
   # Sample rate of incoming reference_wav; model.py resamples to 24000 Hz when it differs.
   key: "reference_audio_sample_rate",
   value: {string_value:"24000"}
  },
  {
   # "vocos" or "bigvgan"; selects the mel extractor in model.py.
   key: "vocoder",
   value: {string_value:"${vocoder}"}
  }
]

input [
  {
    name: "reference_wav"
    data_type: TYPE_FP32
    dims: [-1]
    # NOTE(review): declared optional, but execute() in model.py reads this tensor
    # unconditionally, so a request without it fails -- confirm intent.
    optional: True
  },
  {
    # Number of valid samples in reference_wav; execute() truncates the waveform to it.
    name: "reference_wav_len"
    data_type: TYPE_INT32
    dims: [1]
    optional: True
  },
  {
    # Transcript of the reference audio, decoded as UTF-8 by execute().
    name: "reference_text"
    data_type: TYPE_STRING
    dims: [1]
  },
  {
    # Text to synthesize, decoded as UTF-8 by execute().
    name: "target_text"
    data_type: TYPE_STRING
    dims: [1]
  }
]

output [
  {
    # Synthesized audio for target_text only (reference frames are excluded).
    name: "waveform"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]

================================================
FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep
================================================

================================================
FILE: src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt
================================================
# TensorRT vocoder invoked by the f5_tts model (forward_vocoder sends "mel", requests "waveform").
name: "vocoder"
backend: "tensorrt"
default_model_filename: "vocoder.plan"

max_batch_size: 4

input [
  {
    # (mel bins, frames); 100 bins matches n_mel_channels in the f5_tts model.py.
    name: "mel"
    data_type: TYPE_FP32
    dims: [ 100, -1 ]
  }
]

output [
  {
    name: "waveform"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]

dynamic_batching {
    preferred_batch_size: [1, 2, 4]
    max_queue_delay_microseconds: 1
}

instance_group [
    {
      count: 1
      kind: KIND_GPU
    }
]

================================================
FILE: src/f5_tts/runtime/triton_trtllm/patch/__init__.py
================================================
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .baichuan.model import BaichuanForCausalLM from .bert.model import ( BertForQuestionAnswering, BertForSequenceClassification, BertModel, RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaModel, ) from .bloom.model import BloomForCausalLM, BloomModel from .chatglm.config import ChatGLMConfig from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel from .cogvlm.config import CogVLMConfig from .cogvlm.model import CogVLMForCausalLM from .commandr.model import CohereForCausalLM from .dbrx.config import DbrxConfig from .dbrx.model import DbrxForCausalLM from .deepseek_v1.model import DeepseekForCausalLM from .deepseek_v2.model import DeepseekV2ForCausalLM from .dit.model import DiT from .eagle.model import EagleForCausalLM from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder from .f5tts.model import F5TTS from .falcon.config import FalconConfig from .falcon.model import FalconForCausalLM, FalconModel from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig from .gemma.model import GemmaForCausalLM from .gpt.config import GPTConfig from .gpt.model import GPTForCausalLM, GPTModel from .gptj.config import GPTJConfig from .gptj.model import GPTJForCausalLM, GPTJModel from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel from .grok.model import GrokForCausalLM from .llama.config import LLaMAConfig from .llama.model import LLaMAForCausalLM, LLaMAModel from .mamba.model import MambaForCausalLM from .medusa.config import MedusaConfig from .medusa.model import MedusaForCausalLm from .mllama.model import 
MLLaMAModel from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode from .mpt.model import MPTForCausalLM, MPTModel from .nemotron_nas.model import DeciLMForCausalLM from .opt.model import OPTForCausalLM, OPTModel from .phi.model import PhiForCausalLM, PhiModel from .phi3.model import Phi3ForCausalLM, Phi3Model from .qwen.model import QWenForCausalLM from .recurrentgemma.model import RecurrentGemmaForCausalLM from .redrafter.model import ReDrafterForCausalLM __all__ = [ "BertModel", "BertForQuestionAnswering", "BertForSequenceClassification", "RobertaModel", "RobertaForQuestionAnswering", "RobertaForSequenceClassification", "BloomModel", "BloomForCausalLM", "DiT", "DeepseekForCausalLM", "FalconConfig", "DeepseekV2ForCausalLM", "FalconForCausalLM", "FalconModel", "GPTConfig", "GPTModel", "GPTForCausalLM", "OPTForCausalLM", "OPTModel", "LLaMAConfig", "LLaMAForCausalLM", "LLaMAModel", "MedusaConfig", "MedusaForCausalLm", "ReDrafterForCausalLM", "GPTJConfig", "GPTJModel", "GPTJForCausalLM", "GPTNeoXModel", "GPTNeoXForCausalLM", "PhiModel", "PhiConfig", "Phi3Model", "Phi3Config", "PhiForCausalLM", "Phi3ForCausalLM", "ChatGLMConfig", "ChatGLMForCausalLM", "ChatGLMModel", "BaichuanForCausalLM", "QWenConfigQWenForCausalLM", "QWenModel", "EncoderModel", "DecoderModel", "PretrainedConfig", "PretrainedModel", "WhisperEncoder", "MambaForCausalLM", "MambaConfig", "MPTForCausalLM", "MPTModel", "SkyworkForCausalLM", "GemmaConfig", "GemmaForCausalLM", "DbrxConfig", "DbrxForCausalLM", "RecurrentGemmaForCausalLM", "CogVLMConfig", "CogVLMForCausalLM", "EagleForCausalLM", "SpeculativeDecodingMode", "CohereForCausalLM", "MLLaMAModel", "F5TTS", ] MODEL_MAP = { "GPT2LMHeadModel": GPTForCausalLM, "GPT2LMHeadCustomModel": GPTForCausalLM, "GPTBigCodeForCausalLM": GPTForCausalLM, "Starcoder2ForCausalLM": GPTForCausalLM, "FuyuForCausalLM": GPTForCausalLM, "Kosmos2ForConditionalGeneration": GPTForCausalLM, "JAISLMHeadModel": GPTForCausalLM, "GPTForCausalLM": 
GPTForCausalLM, "NemotronForCausalLM": GPTForCausalLM, "OPTForCausalLM": OPTForCausalLM, "BloomForCausalLM": BloomForCausalLM, "RWForCausalLM": FalconForCausalLM, "FalconForCausalLM": FalconForCausalLM, "PhiForCausalLM": PhiForCausalLM, "Phi3ForCausalLM": Phi3ForCausalLM, "Phi3VForCausalLM": Phi3ForCausalLM, "Phi3SmallForCausalLM": Phi3ForCausalLM, "PhiMoEForCausalLM": Phi3ForCausalLM, "MambaForCausalLM": MambaForCausalLM, "GPTNeoXForCausalLM": GPTNeoXForCausalLM, "GPTJForCausalLM": GPTJForCausalLM, "MPTForCausalLM": MPTForCausalLM, "GLMModel": ChatGLMForCausalLM, "ChatGLMModel": ChatGLMForCausalLM, "ChatGLMForCausalLM": ChatGLMForCausalLM, "LlamaForCausalLM": LLaMAForCausalLM, "ExaoneForCausalLM": LLaMAForCausalLM, "MistralForCausalLM": LLaMAForCausalLM, "MixtralForCausalLM": LLaMAForCausalLM, "ArcticForCausalLM": LLaMAForCausalLM, "Grok1ModelForCausalLM": GrokForCausalLM, "InternLMForCausalLM": LLaMAForCausalLM, "InternLM2ForCausalLM": LLaMAForCausalLM, "MedusaForCausalLM": MedusaForCausalLm, "ReDrafterForCausalLM": ReDrafterForCausalLM, "BaichuanForCausalLM": BaichuanForCausalLM, "BaiChuanForCausalLM": BaichuanForCausalLM, "SkyworkForCausalLM": LLaMAForCausalLM, GEMMA_ARCHITECTURE: GemmaForCausalLM, GEMMA2_ARCHITECTURE: GemmaForCausalLM, "QWenLMHeadModel": QWenForCausalLM, "QWenForCausalLM": QWenForCausalLM, "Qwen2ForCausalLM": QWenForCausalLM, "Qwen2MoeForCausalLM": QWenForCausalLM, "Qwen2ForSequenceClassification": QWenForCausalLM, "Qwen2VLForConditionalGeneration": QWenForCausalLM, "WhisperEncoder": WhisperEncoder, "EncoderModel": EncoderModel, "DecoderModel": DecoderModel, "DbrxForCausalLM": DbrxForCausalLM, "RecurrentGemmaForCausalLM": RecurrentGemmaForCausalLM, "CogVLMForCausalLM": CogVLMForCausalLM, "DiT": DiT, "DeepseekForCausalLM": DeepseekForCausalLM, "DeciLMForCausalLM": DeciLMForCausalLM, "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "EagleForCausalLM": EagleForCausalLM, "CohereForCausalLM": CohereForCausalLM, "MllamaForConditionalGeneration": 
MLLaMAModel, "BertForQuestionAnswering": BertForQuestionAnswering, "BertForSequenceClassification": BertForSequenceClassification, "BertModel": BertModel, "RobertaModel": RobertaModel, "RobertaForQuestionAnswering": RobertaForQuestionAnswering, "RobertaForSequenceClassification": RobertaForSequenceClassification, "F5TTS": F5TTS, } ================================================ FILE: src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py ================================================ from __future__ import annotations import os import sys from collections import OrderedDict import numpy as np import tensorrt as trt from tensorrt_llm._common import default_net from ..._utils import str_dtype_to_trt from ...functional import ( Tensor, concat, constant, expand, shape, slice, unsqueeze, ) from ...layers import Linear from ...module import Module, ModuleList from ...plugin import current_all_reduce_helper from ..modeling_utils import PretrainedConfig, PretrainedModel from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding current_file_path = os.path.abspath(__file__) parent_dir = os.path.dirname(current_file_path) sys.path.append(parent_dir) class InputEmbedding(Module): def __init__(self, mel_dim, text_dim, out_dim): super().__init__() self.proj = Linear(mel_dim * 2 + text_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) def forward(self, x, cond, mask=None): x = self.proj(concat([x, cond], dim=-1)) return self.conv_pos_embed(x, mask=mask) + x class F5TTS(PretrainedModel): def __init__(self, config: PretrainedConfig): super().__init__(config) self.dtype = str_dtype_to_trt(config.dtype) self.time_embed = TimestepEmbedding(config.hidden_size) self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size) self.dim = config.hidden_size self.depth = config.num_hidden_layers self.transformer_blocks = ModuleList( [ DiTBlock( dim=self.dim, heads=config.num_attention_heads, 
dim_head=config.dim_head, ff_mult=config.ff_mult, dropout=config.dropout, pe_attn_head=config.pe_attn_head, ) for _ in range(self.depth) ] ) self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation self.proj_out = Linear(config.hidden_size, config.mel_dim) def forward( self, noise, # nosied input audio cond, # masked cond audio time, # time step rope_cos, rope_sin, input_lengths, scale=1.0, ): if default_net().plugin_config.remove_input_padding: mask = None else: N = shape(noise, 1) B = shape(noise, 0) seq_len_2d = concat([1, N]) max_position_embeddings = 4096 # create position ids position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0)) tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d) tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # [B, N] tmp_input_lengths = unsqueeze(input_lengths, 1) # [B, 1] tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # [B, N] mask = tmp_position_ids < tmp_input_lengths # [B, N] mask = mask.cast("int32") t = self.time_embed(time) x = self.input_embed(noise, cond, mask=mask) for block in self.transformer_blocks: x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask) denoise = self.proj_out(self.norm_out(x, t)) denoise.mark_output("denoised", self.dtype) return denoise def prepare_inputs(self, **kwargs): max_batch_size = kwargs["max_batch_size"] batch_size_range = [2, 2, max_batch_size] mel_size = self.config.mel_dim max_seq_len = 3000 # 4096 num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size] concat_feature_dim = mel_size + self.config.text_dim freq_embed_dim = 256 # Warning: hard coding 256 here head_dim = self.config.dim_head mapping = self.config.mapping if mapping.tp_size > 1: current_all_reduce_helper().set_workspace_tensor(mapping, 1) if default_net().plugin_config.remove_input_padding: noise = Tensor( name="noise", dtype=self.dtype, 
shape=[-1, mel_size], dim_range=OrderedDict( [ ("num_frames", [num_frames_range]), ("n_mels", [mel_size]), ] ), ) cond = Tensor( name="cond", dtype=self.dtype, shape=[-1, concat_feature_dim], dim_range=OrderedDict( [ ("num_frames", [num_frames_range]), ("embeded_length", [concat_feature_dim]), ] ), ) time = Tensor( name="time", dtype=self.dtype, shape=[-1, freq_embed_dim], dim_range=OrderedDict( [ ("num_frames", [num_frames_range]), ("freq_dim", [freq_embed_dim]), ] ), ) rope_cos = Tensor( name="rope_cos", dtype=self.dtype, shape=[-1, head_dim], dim_range=OrderedDict( [ ("num_frames", [num_frames_range]), ("head_dim", [head_dim]), ] ), ) rope_sin = Tensor( name="rope_sin", dtype=self.dtype, shape=[-1, head_dim], dim_range=OrderedDict( [ ("num_frames", [num_frames_range]), ("head_dim", [head_dim]), ] ), ) else: noise = Tensor( name="noise", dtype=self.dtype, shape=[-1, -1, mel_size], dim_range=OrderedDict( [ ("batch_size", [batch_size_range]), ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), ("n_mels", [mel_size]), ] ), ) cond = Tensor( name="cond", dtype=self.dtype, shape=[-1, -1, concat_feature_dim], dim_range=OrderedDict( [ ("batch_size", [batch_size_range]), ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), ("embeded_length", [concat_feature_dim]), ] ), ) time = Tensor( name="time", dtype=self.dtype, shape=[-1, freq_embed_dim], dim_range=OrderedDict( [ ("batch_size", [batch_size_range]), ("freq_dim", [freq_embed_dim]), ] ), ) rope_cos = Tensor( name="rope_cos", dtype=self.dtype, shape=[-1, -1, head_dim], dim_range=OrderedDict( [ ("batch_size", [batch_size_range]), ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), ("head_dim", [head_dim]), ] ), ) rope_sin = Tensor( name="rope_sin", dtype=self.dtype, shape=[-1, -1, head_dim], dim_range=OrderedDict( [ ("batch_size", [batch_size_range]), ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), ("head_dim", [head_dim]), ] ), ) input_lengths = Tensor( name="input_lengths", 
dtype=trt.int32, shape=[-1], dim_range=OrderedDict([("batch_size", [batch_size_range])]), ) return { "noise": noise, "cond": cond, "time": time, "rope_cos": rope_cos, "rope_sin": rope_sin, "input_lengths": input_lengths, } ================================================ FILE: src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py ================================================ from __future__ import annotations import math from typing import Optional import numpy as np import torch import torch.nn.functional as F from tensorrt_llm._common import default_net from ..._utils import str_dtype_to_trt, trt_dtype_to_np from ...functional import ( Tensor, bert_attention, cast, chunk, concat, constant, expand_dims, expand_dims_like, expand_mask, gelu, matmul, permute, shape, silu, slice, softmax, squeeze, unsqueeze, view, ) from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear from ...module import Module class FeedForward(Module): def __init__(self, dim, dim_out=None, mult=4, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim self.project_in = Linear(dim, inner_dim) self.ff = Linear(inner_dim, dim_out) def forward(self, x): return self.ff(gelu(self.project_in(x))) class AdaLayerNormZero(Module): def __init__(self, dim): super().__init__() self.linear = Linear(dim, dim * 6) self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) def forward(self, x, emb=None): emb = self.linear(silu(emb)) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1) x = self.norm(x) ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) if default_net().plugin_config.remove_input_padding: x = x * (ones + scale_msa) + shift_msa else: x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1) return x, gate_msa, shift_mlp, scale_mlp, gate_mlp class AdaLayerNormZero_Final(Module): def __init__(self, dim): super().__init__() self.linear = Linear(dim, dim * 2) 
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) def forward(self, x, emb): emb = self.linear(silu(emb)) scale, shift = chunk(emb, 2, dim=1) ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) if default_net().plugin_config.remove_input_padding: x = self.norm(x) * (ones + scale) + shift else: x = self.norm(x) * unsqueeze((ones + scale), 1) x = x + unsqueeze(shift, 1) return x class ConvPositionEmbedding(Module): def __init__(self, dim, kernel_size=31, groups=16): super().__init__() assert kernel_size % 2 != 0 self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) self.mish = Mish() def forward(self, x, mask=None): if default_net().plugin_config.remove_input_padding: x = unsqueeze(x, 0) if mask is not None: mask = mask.view(concat([shape(mask, 0), 1, shape(mask, 1)])) # [B 1 N] mask = expand_dims_like(mask, x) # [B D N] mask = cast(mask, x.dtype) x = permute(x, [0, 2, 1]) # [B D N] if mask is not None: x = self.mish(self.conv1d2(self.mish(self.conv1d1(x * mask) * mask)) * mask) else: x = self.mish(self.conv1d2(self.mish(self.conv1d1(x)))) x = permute(x, [0, 2, 1]) # [B N D] if default_net().plugin_config.remove_input_padding: x = squeeze(x, 0) return x class Attention(Module): def __init__( self, processor: AttnProcessor, dim: int, heads: int = 16, dim_head: int = 64, dropout: float = 0.0, context_dim: Optional[int] = None, # if not None -> joint attention context_pre_only=None, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.processor = processor self.dim = dim # hidden_size self.heads = heads self.inner_dim = dim_head * heads self.dropout = dropout self.attention_head_size = dim_head self.context_dim = context_dim self.context_pre_only = context_pre_only self.tp_size = 1 self.num_attention_heads = heads // 
self.tp_size self.num_attention_kv_heads = heads // self.tp_size # 8 self.dtype = str_dtype_to_trt("float32") self.attention_hidden_size = self.attention_head_size * self.num_attention_heads self.to_q = ColumnLinear( dim, self.tp_size * self.num_attention_heads * self.attention_head_size, bias=True, dtype=self.dtype, tp_group=None, tp_size=self.tp_size, ) self.to_k = ColumnLinear( dim, self.tp_size * self.num_attention_heads * self.attention_head_size, bias=True, dtype=self.dtype, tp_group=None, tp_size=self.tp_size, ) self.to_v = ColumnLinear( dim, self.tp_size * self.num_attention_heads * self.attention_head_size, bias=True, dtype=self.dtype, tp_group=None, tp_size=self.tp_size, ) if self.context_dim is not None: self.to_k_c = Linear(context_dim, self.inner_dim) self.to_v_c = Linear(context_dim, self.inner_dim) if self.context_pre_only is not None: self.to_q_c = Linear(context_dim, self.inner_dim) self.to_out = RowLinear( self.tp_size * self.num_attention_heads * self.attention_head_size, dim, bias=True, dtype=self.dtype, tp_group=None, tp_size=self.tp_size, ) if self.context_pre_only is not None and not self.context_pre_only: self.to_out_c = Linear(self.inner_dim, dim) def forward( self, x, # noised input x rope_cos, rope_sin, input_lengths, mask=None, c=None, # context c scale=1.0, rope=None, c_rope=None, # rotary position embedding for c ) -> torch.Tensor: if c is not None: return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope) else: return self.processor( self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale ) def rotate_every_two_3dim(tensor: Tensor) -> Tensor: shape_tensor = concat( [shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())] ) if default_net().plugin_config.remove_input_padding: assert tensor.ndim() == 2 x1 = slice(tensor, [0, 0], shape_tensor, [1, 2]) x2 = slice(tensor, [0, 1], shape_tensor, [1, 2]) x1 = 
expand_dims(x1, 2) x2 = expand_dims(x2, 2) zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) x2 = zero - x2 x = concat([x2, x1], 2) out = view(x, concat([shape(x, 0), shape(x, 1) * 2])) else: assert tensor.ndim() == 3 x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2]) x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2]) x1 = expand_dims(x1, 3) x2 = expand_dims(x2, 3) zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) x2 = zero - x2 x = concat([x2, x1], 3) out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2])) return out def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head): full_dim = x.size(-1) head_dim = rope_cos.size(-1) # attn head dim, e.g. 64 if pe_attn_head is None: pe_attn_head = full_dim // head_dim rotated_dim = head_dim * pe_attn_head rotated_and_unrotated_list = [] if default_net().plugin_config.remove_input_padding: # for [N, D] input new_t_shape = concat([shape(x, 0), head_dim]) # (2, -1, 64) for i in range(pe_attn_head): x_slice_i = slice(x, [0, i * 64], new_t_shape, [1, 1]) x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin rotated_and_unrotated_list.append(x_rotated_i) new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim]) # (2, -1, 1024 - 64 * pe_attn_head) x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1]) rotated_and_unrotated_list.append(x_unrotated) else: # for [B, N, D] input new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim]) # (2, -1, 64) for i in range(pe_attn_head): x_slice_i = slice(x, [0, 0, i * 64], new_t_shape, [1, 1, 1]) x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin rotated_and_unrotated_list.append(x_rotated_i) new_t_unrotated_shape = concat( [shape(x, 0), shape(x, 1), full_dim - rotated_dim] ) # (2, -1, 1024 - 64 * pe_attn_head) x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 
1]) rotated_and_unrotated_list.append(x_unrotated) out = concat(rotated_and_unrotated_list, dim=-1) return out class AttnProcessor: def __init__( self, pe_attn_head: Optional[int] = None, # number of attention head to apply rope, None for all ): self.pe_attn_head = pe_attn_head def __call__( self, attn, x, # noised input x rope_cos, rope_sin, input_lengths, scale=1.0, rope=None, mask=None, ) -> torch.FloatTensor: query = attn.to_q(x) key = attn.to_k(x) value = attn.to_v(x) # k,v,q all (2,1226,1024) query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head) key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head) # attention inner_dim = key.shape[-1] norm_factor = math.sqrt(attn.attention_head_size) q_scaling = 1.0 / norm_factor if default_net().plugin_config.remove_input_padding: mask = None if default_net().plugin_config.bert_attention_plugin: qkv = concat([query, key, value], dim=-1) # TRT plugin mode assert input_lengths is not None if default_net().plugin_config.remove_input_padding: qkv = qkv.view(concat([-1, 3 * inner_dim])) max_input_length = constant( np.zeros( [ 2048, ], dtype=np.int32, ) ) else: max_input_length = None context = bert_attention( qkv, input_lengths, attn.num_attention_heads, attn.attention_head_size, q_scaling=q_scaling, max_input_length=max_input_length, ) else: assert not default_net().plugin_config.remove_input_padding def transpose_for_scores(x): new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size]) y = x.view(new_x_shape) y = y.transpose(1, 2) return y def transpose_for_scores_k(x): new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size]) y = x.view(new_x_shape) y = y.permute([0, 2, 3, 1]) return y query = transpose_for_scores(query) key = transpose_for_scores_k(key) value = transpose_for_scores(value) attention_scores = matmul(query, key, use_fp32_acc=False) if mask is not None: attention_mask = 
expand_mask(mask, shape(query, 2)) attention_mask = cast(attention_mask, attention_scores.dtype) attention_scores = attention_scores + attention_mask attention_probs = softmax(attention_scores, dim=-1) context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2) context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size])) context = attn.to_out(context) if mask is not None: mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1])) mask = expand_dims_like(mask, context) mask = cast(mask, context.dtype) context = context * mask return context # DiT Block class DiTBlock(Module): def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None): super().__init__() self.attn_norm = AdaLayerNormZero(dim) self.attn = Attention( processor=AttnProcessor(pe_attn_head=pe_attn_head), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, ) self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout) def forward( self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError, mask=None ): # x: noised input, t: time embedding # pre-norm & modulation for attention input norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) # attention # norm ----> (2,1226,1024) attn_output = self.attn( x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask ) # process attention output for input x if default_net().plugin_config.remove_input_padding: x = x + gate_msa * attn_output else: x = x + unsqueeze(gate_msa, 1) * attn_output ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) if default_net().plugin_config.remove_input_padding: norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp else: norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1) # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp ff_output = self.ff(norm) if 
default_net().plugin_config.remove_input_padding: x = x + gate_mlp * ff_output else: x = x + unsqueeze(gate_mlp, 1) * ff_output return x class TimestepEmbedding(Module): def __init__(self, dim, freq_embed_dim=256, dtype=None): super().__init__() # self.time_embed = SinusPositionEmbedding(freq_embed_dim) self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype) self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype) def forward(self, timestep): t_freq = self.mlp1(timestep) t_freq = silu(t_freq) t_emb = self.mlp2(t_freq) return t_emb ================================================ FILE: src/f5_tts/runtime/triton_trtllm/run.sh ================================================ stage=$1 stop_stage=$2 model=$3 # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small if [ -z "$model" ]; then model=F5TTS_v1_Base fi echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model" export CUDA_VISIBLE_DEVICES=0 CKPT_DIR=../../../../ckpts TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan MODEL_REPO=./model_repo if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then echo "Downloading F5-TTS from huggingface" huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR fi ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1) # default select latest update vocab_file=$CKPT_DIR/$model/vocab.txt if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then echo "Converting checkpoint" python3 scripts/convert_checkpoint.py \ --pytorch_ckpt $ckpt_file \ --output_dir $TRTLLM_CKPT_DIR --model_name $model python_package_path=/usr/local/lib/python3.12/dist-packages cp -r patch/* $python_package_path/tensorrt_llm/models trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \ --max_batch_size 8 \ --output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; 
then echo "Exporting vocos vocoder" python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then echo "Building triton server" rm -r $MODEL_REPO cp -r ./model_repo_f5_tts $MODEL_REPO python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then echo "Starting triton server" tritonserver --model-repository=$MODEL_REPO fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then echo "Testing triton server" num_task=1 split_name=wenetspeech4tts log_dir=./tests/client_grpc_${model}_concurrent_${num_task}_${split_name} rm -r $log_dir python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then echo "Testing http client" audio=../../infer/examples/basic/basic_ref_en.wav reference_text="Some call me nature, others call me mother nature." target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring." 
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" --output-audio "./tests/client_http_$model.wav" fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then echo "TRT-LLM: offline decoding benchmark test" batch_size=2 split_name=wenetspeech4tts backend_type=trt log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type} rm -r $log_dir torchrun --nproc_per_node=1 \ benchmark.py --output-dir $log_dir \ --batch-size $batch_size \ --enable-warmup \ --split-name $split_name \ --model-path $ckpt_file \ --vocab-file $vocab_file \ --vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \ --backend-type $backend_type \ --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1 fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then echo "Native Pytorch: offline decoding benchmark test" if ! python3 -c "import f5_tts" &> /dev/null; then pip install -e ../../../../ fi batch_size=1 # set attn_mask_enabled=True if batching in actual use case split_name=wenetspeech4tts backend_type=pytorch log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type} rm -r $log_dir torchrun --nproc_per_node=1 \ benchmark.py --output-dir $log_dir \ --batch-size $batch_size \ --split-name $split_name \ --enable-warmup \ --model-path $ckpt_file \ --vocab-file $vocab_file \ --backend-type $backend_type \ --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1 fi ================================================ FILE: src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py ================================================ # Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # MIT License # Copyright (c) 2020 Shimin Zhang # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import torch as th import torch.nn.functional as F from scipy.signal import check_COLA, get_window support_clp_op = None if th.__version__ >= "1.7.0": from torch.fft import rfft as fft support_clp_op = True else: from torch import rfft as fft class STFT(th.nn.Module): def __init__( self, win_len=1024, win_hop=512, fft_len=1024, enframe_mode="continue", win_type="hann", win_sqrt=False, pad_center=True, ): """ Implement of STFT using 1D convolution and 1D transpose convolutions. 
Implement of framing the signal in 2 ways, `break` and `continue`. `break` method is a kaldi-like framing. `continue` method is a librosa-like framing. More information about `perfect reconstruction`: 1. https://ww2.mathworks.cn/help/signal/ref/stft.html 2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html Args: win_len (int): Number of points in one frame. Defaults to 1024. win_hop (int): Number of framing stride. Defaults to 512. fft_len (int): Number of DFT points. Defaults to 1024. enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'. win_type (str, optional): The type of window to create. Defaults to 'hann'. win_sqrt (bool, optional): using square root window. Defaults to True. pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True. """ super(STFT, self).__init__() assert enframe_mode in ["break", "continue"] assert fft_len >= win_len self.win_len = win_len self.win_hop = win_hop self.fft_len = fft_len self.mode = enframe_mode self.win_type = win_type self.win_sqrt = win_sqrt self.pad_center = pad_center self.pad_amount = self.fft_len // 2 en_k, fft_k, ifft_k, ola_k = self.__init_kernel__() self.register_buffer("en_k", en_k) self.register_buffer("fft_k", fft_k) self.register_buffer("ifft_k", ifft_k) self.register_buffer("ola_k", ola_k) def __init_kernel__(self): """ Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel. ** enframe_kernel: Using conv1d layer and identity matrix. ** fft_kernel: Using linear layer for matrix multiplication. In fact, enframe_kernel and fft_kernel can be combined, But for the sake of readability, I took the two apart. ** ifft_kernel, pinv of fft_kernel. ** overlap-add kernel, just like enframe_kernel, but transposed. Returns: tuple: four kernels. 
""" enframed_kernel = th.eye(self.fft_len)[:, None, :] if support_clp_op: tmp = fft(th.eye(self.fft_len)) fft_kernel = th.stack([tmp.real, tmp.imag], dim=2) else: fft_kernel = fft(th.eye(self.fft_len), 1) if self.mode == "break": enframed_kernel = th.eye(self.win_len)[:, None, :] fft_kernel = fft_kernel[: self.win_len] fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1) ifft_kernel = th.pinverse(fft_kernel)[:, None, :] window = get_window(self.win_type, self.win_len) self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop) window = th.FloatTensor(window) if self.mode == "continue": left_pad = (self.fft_len - self.win_len) // 2 right_pad = left_pad + (self.fft_len - self.win_len) % 2 window = F.pad(window, (left_pad, right_pad)) if self.win_sqrt: self.padded_window = window window = th.sqrt(window) else: self.padded_window = window**2 fft_kernel = fft_kernel.T * window ifft_kernel = ifft_kernel * window ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :] if self.mode == "continue": ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len] return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel def is_perfect(self): """ Whether the parameters win_len, win_hop and win_sqrt obey constants overlap-add(COLA) Returns: bool: Return true if parameters obey COLA. """ return self.perfect_reconstruct and self.pad_center def transform(self, inputs, return_type="complex"): """Take input data (audio) to STFT domain. Args: inputs (tensor): Tensor of floats, with shape (num_batch, num_samples) return_type (str, optional): return (mag, phase) when `magphase`, return (real, imag) when `realimag` and complex(real, imag) when `complex`. Defaults to 'complex'. Returns: tuple: (mag, phase) when `magphase`, return (real, imag) when `realimag`. 
Defaults to 'complex', each elements with shape [num_batch, num_frequencies, num_frames] """ assert return_type in ["magphase", "realimag", "complex"] if inputs.dim() == 2: inputs = th.unsqueeze(inputs, 1) self.num_samples = inputs.size(-1) if self.pad_center: inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect") enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop) outputs = th.transpose(enframe_inputs, 1, 2) outputs = F.linear(outputs, self.fft_k) outputs = th.transpose(outputs, 1, 2) dim = self.fft_len // 2 + 1 real = outputs[:, :dim, :] imag = outputs[:, dim:, :] if return_type == "realimag": return real, imag elif return_type == "complex": assert support_clp_op return th.complex(real, imag) else: mags = th.sqrt(real**2 + imag**2) phase = th.atan2(imag, real) return mags, phase def inverse(self, input1, input2=None, input_type="magphase"): """Call the inverse STFT (iSTFT), given tensors produced by the `transform` function. Args: input1 (tensors): Magnitude/Real-part of STFT with shape [num_batch, num_frequencies, num_frames] input2 (tensors): Phase/Imag-part of STFT with shape [num_batch, num_frequencies, num_frames] input_type (str, optional): Mathematical meaning of input tensor's. Defaults to 'magphase'. Returns: tensors: Reconstructed audio given magnitude and phase. 
Of shape [num_batch, num_samples] """ assert input_type in ["magphase", "realimag"] if input_type == "realimag": real, imag = None, None if support_clp_op and th.is_complex(input1): real, imag = input1.real, input1.imag else: real, imag = input1, input2 else: real = input1 * th.cos(input2) imag = input1 * th.sin(input2) inputs = th.cat([real, imag], dim=1) outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop) t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1)) t = t.to(inputs.device) coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop) num_frames = input1.size(-1) num_samples = num_frames * self.win_hop rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples outputs = outputs[..., rm_start:rm_end] coff = coff[..., rm_start:rm_end] coffidx = th.where(coff > 1e-8) outputs[coffidx] = outputs[coffidx] / (coff[coffidx]) return outputs.squeeze(dim=1) def forward(self, inputs): """Take input data (audio) to STFT domain and then back to audio. Args: inputs (tensor): Tensor of floats, with shape [num_batch, num_samples] Returns: tensor: Reconstructed audio given magnitude and phase. 
Of shape [num_batch, num_samples] """ mag, phase = self.transform(inputs) rec_wav = self.inverse(mag, phase) return rec_wav ================================================ FILE: src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py ================================================ import argparse import json import os import re import time import traceback from concurrent.futures import ThreadPoolExecutor, as_completed import safetensors.torch import torch from tensorrt_llm import str_dtype_to_torch from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.convert_utils import split, split_matrix_tp def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank): split_v = split(v, tensor_parallel, rank, dim=1) return split_v.contiguous() def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank): split_v = split(v, tensor_parallel, rank, dim=0) return split_v.contiguous() def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--pytorch_ckpt", type=str, default="./ckpts/model_last.pt") parser.add_argument( "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint" ) parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size") parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size") parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size") parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"]) parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers") parser.add_argument( "--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel" ) parser.add_argument( "--model_name", type=str, default="F5TTS_Custom", choices=[ "F5TTS_v1_Base", "F5TTS_Base", "F5TTS_v1_Small", "F5TTS_Small", ], # if set, overwrite the below hyperparams ) 
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT") parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers") parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module") parser.add_argument("--dim_head", type=int, default=64, help="The dimension of attention head") parser.add_argument("--ff_mult", type=int, default=2, help="The FFN intermediate dimension multiplier") parser.add_argument("--text_dim", type=int, default=512, help="The output dimension of text encoder") parser.add_argument( "--text_mask_padding", type=lambda x: x.lower() == "true", choices=[True, False], default=True, help="Whether apply padding mask for conv layers in text encoder", ) parser.add_argument("--conv_layers", type=int, default=4, help="The number of conv layers of text encoder") parser.add_argument("--pe_attn_head", type=int, default=None, help="The number of attn head that apply pos emb") args = parser.parse_args() # overwrite if --model_name ordered if args.model_name == "F5TTS_v1_Base": args.hidden_size = 1024 args.depth = 22 args.num_heads = 16 args.dim_head = 64 args.ff_mult = 2 args.text_dim = 512 args.text_mask_padding = True args.conv_layers = 4 args.pe_attn_head = None elif args.model_name == "F5TTS_Base": args.hidden_size = 1024 args.depth = 22 args.num_heads = 16 args.dim_head = 64 args.ff_mult = 2 args.text_dim = 512 args.text_mask_padding = False args.conv_layers = 4 args.pe_attn_head = 1 elif args.model_name == "F5TTS_v1_Small": args.hidden_size = 768 args.depth = 18 args.num_heads = 12 args.dim_head = 64 args.ff_mult = 2 args.text_dim = 512 args.text_mask_padding = True args.conv_layers = 4 args.pe_attn_head = None elif args.model_name == "F5TTS_Small": args.hidden_size = 768 args.depth = 18 args.num_heads = 12 args.dim_head = 64 args.ff_mult = 2 args.text_dim = 512 args.text_mask_padding = False args.conv_layers = 4 args.pe_attn_head = 1 return args def 
convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32", use_ema=True): weights = {} tik = time.time() torch_dtype = str_dtype_to_torch(dtype) tensor_parallel = mapping.tp_size ckpt_path = args.pytorch_ckpt ckpt_type = ckpt_path.split(".")[-1] if ckpt_type == "safetensors": from safetensors.torch import load_file model_params = load_file(ckpt_path) else: ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"] prefix = "ema_model.transformer." if use_ema else "transformer." if any(k.startswith(prefix) for k in model_params.keys()): model_params = { key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items() if key.startswith(prefix) } pytorch_to_trtllm_name = { r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1", r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1", r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1", r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1", r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2", r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2", r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2", } def get_trtllm_name(pytorch_name): for pytorch_name_pattern, trtllm_name_replacement in pytorch_to_trtllm_name.items(): trtllm_name_if_matched = re.sub(pytorch_name_pattern, trtllm_name_replacement, pytorch_name) if trtllm_name_if_matched != pytorch_name: return trtllm_name_if_matched return pytorch_name weights = dict() for name, param in model_params.items(): if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight": weights[get_trtllm_name(name)] = 
param.contiguous().to(torch_dtype).unsqueeze(-1) else: weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype) assert len(weights) == len(model_params) # new_prefix = "f5_transformer." new_prefix = "" weights = {new_prefix + key: value for key, value in weights.items()} import math scale_factor = math.pow(64, -0.25) for k, v in weights.items(): if re.match("^transformer_blocks.*.attn.to_k.weight$", k): weights[k] *= scale_factor weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) elif re.match("^transformer_blocks.*.attn.to_k.bias$", k): weights[k] *= scale_factor weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) elif re.match("^transformer_blocks.*.attn.to_q.weight$", k): weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) weights[k] *= scale_factor elif re.match("^transformer_blocks.*.attn.to_q.bias$", k): weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) weights[k] *= scale_factor elif re.match("^transformer_blocks.*.attn.to_v.weight$", k): weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) elif re.match("^transformer_blocks.*.attn.to_v.bias$", k): weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) elif re.match("^transformer_blocks.*.attn.to_out.weight$", k): weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1) tok = time.time() t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) print(f"Weights loaded. 
Total time: {t}") return weights def save_config(args): if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) config = { "architecture": "F5TTS", # set the same as in ../patch/__init__.py "dtype": args.dtype, "hidden_size": args.hidden_size, "num_hidden_layers": args.depth, "num_attention_heads": args.num_heads, "dim_head": args.dim_head, "dropout": 0.0, # inference-only "ff_mult": args.ff_mult, "mel_dim": 100, "text_dim": args.text_dim, "text_mask_padding": args.text_mask_padding, "conv_layers": args.conv_layers, "pe_attn_head": args.pe_attn_head, "mapping": { "world_size": args.cp_size * args.tp_size * args.pp_size, "cp_size": args.cp_size, "tp_size": args.tp_size, "pp_size": args.pp_size, }, } if args.fp8_linear: config["quantization"] = { "quant_algo": "FP8", # TODO: add support for exclude modules. # "exclude_modules": "*final_layer*", } with open(os.path.join(args.output_dir, "config.json"), "w") as f: json.dump(config, f, indent=4) def covert_and_save(args, rank): if rank == 0: save_config(args) mapping = Mapping( world_size=args.cp_size * args.tp_size * args.pp_size, rank=rank, cp_size=args.cp_size, tp_size=args.tp_size, pp_size=args.pp_size, ) weights = convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype=args.dtype) safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors")) def execute(workers, func, args): if workers == 1: for rank, f in enumerate(func): f(args, rank) else: with ThreadPoolExecutor(max_workers=workers) as p: futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] exceptions = [] for future in as_completed(futures): try: future.result() except Exception as e: traceback.print_exc() exceptions.append(e) assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log." def main(): args = parse_arguments() world_size = args.cp_size * args.tp_size * args.pp_size assert args.pp_size == 1, "PP is not supported yet." 
tik = time.time() if args.pytorch_ckpt is None: return print("Start execute") execute(args.workers, [covert_and_save] * world_size, args) tok = time.time() t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) print(f"Total time of converting checkpoints: {t}") if __name__ == "__main__": main() ================================================ FILE: src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import argparse import torch import torch.nn as nn from conv_stft import STFT from huggingface_hub import hf_hub_download from vocos import Vocos opset_version = 17 def get_args(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--vocoder", type=str, default="vocos", choices=["vocos", "bigvgan"], help="Vocoder to export", ) parser.add_argument( "--output-path", type=str, default="./vocos_vocoder.onnx", help="Output path", ) return parser.parse_args() class ISTFTHead(nn.Module): def __init__(self, n_fft: int, hop_length: int): super().__init__() self.out = None self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft) def forward(self, x: torch.Tensor): x = self.out(x).transpose(1, 2) mag, p = x.chunk(2, dim=1) mag = torch.exp(mag) mag = torch.clip(mag, max=1e2) real = mag * torch.cos(p) imag = mag * torch.sin(p) audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag") return audio class VocosVocoder(nn.Module): def __init__(self, vocos_vocoder): super(VocosVocoder, self).__init__() self.vocos_vocoder = vocos_vocoder istft_head_out = self.vocos_vocoder.head.out n_fft = self.vocos_vocoder.head.istft.n_fft hop_length = self.vocos_vocoder.head.istft.hop_length istft_head_for_export = ISTFTHead(n_fft, hop_length) istft_head_for_export.out = istft_head_out self.vocos_vocoder.head = istft_head_for_export def forward(self, mel): waveform = self.vocos_vocoder.decode(mel) return waveform def export_VocosVocoder(vocos_vocoder, output_path, verbose): vocos_vocoder = VocosVocoder(vocos_vocoder).cuda() vocos_vocoder.eval() dummy_batch_size = 8 dummy_input_length = 500 dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda() with torch.no_grad(): dummy_waveform = vocos_vocoder(mel=dummy_mel) print(dummy_waveform.shape) dummy_input = dummy_mel torch.onnx.export( vocos_vocoder, dummy_input, output_path, opset_version=opset_version, do_constant_folding=True, 
input_names=["mel"], output_names=["waveform"], dynamic_axes={ "mel": {0: "batch_size", 2: "input_length"}, "waveform": {0: "batch_size", 1: "output_length"}, }, verbose=verbose, ) print("Exported to {}".format(output_path)) def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None): if vocoder_name == "vocos": # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) if is_local: print(f"Load vocos from local path {local_path}") config_path = f"{local_path}/config.yaml" model_path = f"{local_path}/pytorch_model.bin" else: print("Download Vocos from huggingface charactr/vocos-mel-24khz") repo_id = "charactr/vocos-mel-24khz" config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") vocoder = Vocos.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) vocoder.load_state_dict(state_dict) vocoder = vocoder.eval().to(device) elif vocoder_name == "bigvgan": raise NotImplementedError("BigVGAN is not supported yet") vocoder.remove_weight_norm() vocoder = vocoder.eval().to(device) return vocoder if __name__ == "__main__": args = get_args() vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None) if args.vocoder == "vocos": export_VocosVocoder(vocoder, args.output_path, verbose=False) ================================================ FILE: src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh ================================================ #!/bin/bash # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Build a TensorRT engine for the Vocos vocoder from its ONNX export.
# Usage: export_vocos_trt.sh <onnx_path> <engine_path>

# Manual installation of TensorRT, in case not using NVIDIA NGC:
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
# The trtexec location can be overridden by exporting TRTEXEC before running.
TRTEXEC="${TRTEXEC:-/usr/src/tensorrt/bin/trtexec}"

ONNX_PATH=$1
ENGINE_PATH=$2
if [ -z "${ONNX_PATH}" ] || [ -z "${ENGINE_PATH}" ]; then
    echo "Usage: $0 <onnx_path> <engine_path>" >&2
    exit 1
fi
echo "ONNX_PATH: $ONNX_PATH"
echo "ENGINE_PATH: $ENGINE_PATH"
# NOTE: not passed to trtexec below.
PRECISION="fp32"

MIN_BATCH_SIZE=1
OPT_BATCH_SIZE=1
MAX_BATCH_SIZE=8

MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000 # 4096

# Shapes of the "mel" input: batch x 100 mel bins x frames.
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}"

# Quote paths so locations containing spaces work.
"${TRTEXEC}" \
    --minShapes="mel:${MEL_MIN_SHAPE}" \
    --optShapes="mel:${MEL_OPT_SHAPE}" \
    --maxShapes="mel:${MEL_MAX_SHAPE}" \
    --onnx="${ONNX_PATH}" \
    --saveEngine="${ENGINE_PATH}"

================================================
FILE: src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py
================================================
#!
/usr/bin/env python3 from argparse import ArgumentParser from string import Template def main(file_path, substitutions, in_place, participant_ids): with open(file_path) as f: pbtxt = Template(f.read()) sub_dict = {"max_queue_size": 0} sub_dict["participant_ids"] = participant_ids for sub in substitutions.split(","): key, value = sub.split(":") sub_dict[key] = value pbtxt = pbtxt.safe_substitute(sub_dict) if in_place: with open(file_path, "w") as f: f.write(pbtxt) else: print(pbtxt) if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("file_path", help="path of the .pbtxt to modify") parser.add_argument( "substitutions", help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...", ) parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place") parser.add_argument("--participant_ids", help="Participant IDs for the model", default="") args = parser.parse_args() main(**vars(args)) ================================================ FILE: src/f5_tts/scripts/count_max_epoch.py ================================================ """ADAPTIVE BATCH SIZE""" print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in") print(" -> least padding, gather wavs with accumulated frames in a batch\n") # data total_hours = 95282 mel_hop_length = 256 mel_sampling_rate = 24000 # target wanted_max_updates = 1200000 # train params gpus = 8 frames_per_gpu = 38400 # 8 * 38400 = 307200 grad_accum = 1 # intermediate mini_batch_frames = frames_per_gpu * grad_accum * gpus mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600 updates_per_epoch = total_hours / mini_batch_hours # steps_per_epoch = updates_per_epoch * grad_accum # result epochs = wanted_max_updates / updates_per_epoch print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})") print(f"progress_bar should show approx. 
0/{updates_per_epoch:.0f} updates") # print(f" or approx. 0/{steps_per_epoch:.0f} steps") # others print(f"total {total_hours:.0f} hours") print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch") ================================================ FILE: src/f5_tts/scripts/count_max_epoch_precise.py ================================================ import math from torch.utils.data import SequentialSampler from f5_tts.model.dataset import DynamicBatchSampler, load_dataset train_dataset = load_dataset("Emilia_ZH_EN", "pinyin") sampler = SequentialSampler(train_dataset) gpus = 8 batch_size_per_gpu = 38400 max_samples_per_gpu = 64 max_updates = 1250000 batch_sampler = DynamicBatchSampler( sampler, batch_size_per_gpu, max_samples=max_samples_per_gpu, random_seed=666, drop_residual=False, ) updates_per_epoch = int(len(batch_sampler) / gpus) print( f"One epoch has {updates_per_epoch} updates if gpus={gpus}, with " f"batch_size_per_gpu={batch_size_per_gpu} (frames) & " f"max_samples_per_gpu={max_samples_per_gpu}." 
) print(f"If gpus={gpus}, for max_updates={max_updates} should set epoch={math.ceil(max_updates / updates_per_epoch)}.") ================================================ FILE: src/f5_tts/scripts/count_params_gflops.py ================================================ import os import sys sys.path.append(os.getcwd()) import thop import torch from f5_tts.model import CFM, DiT """ ~155M """ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4) # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4) # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2) # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4) # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True) # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2) """ ~335M """ # FLOPs: 622.1 G, Params: 333.2 M # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4) # FLOPs: 363.4 G, Params: 335.8 M transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) model = CFM(transformer=transformer) target_sample_rate = 24000 n_mel_channels = 100 hop_length = 256 duration = 20 frame_length = int(duration * target_sample_rate / hop_length) text_length = 150 flops, params = thop.profile( model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)) ) print(f"FLOPs: {flops / 1e9} G") print(f"Params: {params / 1e6} M") ================================================ FILE: src/f5_tts/socket_client.py ================================================ import asyncio import logging import socket import time import numpy as np import pyaudio logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998): client_socket = socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port))) start_time = time.time() first_chunk_time = None async def play_audio_stream(): nonlocal first_chunk_time p = pyaudio.PyAudio() stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048) try: while True: data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192) if not data: break if data == b"END": logger.info("End of audio received.") break audio_array = np.frombuffer(data, dtype=np.float32) stream.write(audio_array.tobytes()) if first_chunk_time is None: first_chunk_time = time.time() finally: stream.stop_stream() stream.close() p.terminate() logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds") try: data_to_send = f"{text}".encode("utf-8") await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send) await play_audio_stream() except Exception as e: logger.error(f"Error in listen_to_F5TTS: {e}") finally: client_socket.close() if __name__ == "__main__": text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. 
Let's break down the components" asyncio.run(listen_to_F5TTS(text_to_send)) ================================================ FILE: src/f5_tts/socket_server.py ================================================ import argparse import gc import logging import queue import socket import struct import threading import traceback import wave from importlib.resources import files import numpy as np import torch import torchaudio from huggingface_hub import hf_hub_download from hydra.utils import get_class from omegaconf import OmegaConf from f5_tts.infer.utils_infer import ( chunk_text, infer_batch_process, load_model, load_vocoder, preprocess_ref_audio_text, ) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class AudioFileWriterThread(threading.Thread): """Threaded file writer to avoid blocking the TTS streaming process.""" def __init__(self, output_file, sampling_rate): super().__init__() self.output_file = output_file self.sampling_rate = sampling_rate self.queue = queue.Queue() self.stop_event = threading.Event() self.audio_data = [] def run(self): """Process queued audio data and write it to a file.""" logger.info("AudioFileWriterThread started.") with wave.open(self.output_file, "wb") as wf: wf.setnchannels(1) wf.setsampwidth(2) wf.setframerate(self.sampling_rate) while not self.stop_event.is_set() or not self.queue.empty(): try: chunk = self.queue.get(timeout=0.1) if chunk is not None: chunk = np.int16(chunk * 32767) self.audio_data.append(chunk) wf.writeframes(chunk.tobytes()) except queue.Empty: continue def add_chunk(self, chunk): """Add a new chunk to the queue.""" self.queue.put(chunk) def stop(self): """Stop writing and ensure all queued data is written.""" self.stop_event.set() self.join() logger.info("Audio writing completed.") class TTSStreamingProcessor: def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32): self.device = device or ( "cuda" if torch.cuda.is_available() else "xpu" 
if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") self.model_arc = model_cfg.model.arch self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate self.model = self.load_ema_model(ckpt_file, vocab_file, dtype) self.vocoder = self.load_vocoder_model() self.update_reference(ref_audio, ref_text) self._warm_up() self.file_writer_thread = None self.first_package = True def load_ema_model(self, ckpt_file, vocab_file, dtype): return load_model( self.model_cls, self.model_arc, ckpt_path=ckpt_file, mel_spec_type=self.mel_spec_type, vocab_file=vocab_file, ode_method="euler", use_ema=True, device=self.device, ).to(self.device, dtype=dtype) def load_vocoder_model(self): return load_vocoder(vocoder_name=self.mel_spec_type, is_local=False, local_path=None, device=self.device) def update_reference(self, ref_audio, ref_text): self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text) self.audio, self.sr = torchaudio.load(self.ref_audio) ref_audio_duration = self.audio.shape[-1] / self.sr ref_text_byte_len = len(self.ref_text.encode("utf-8")) self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration)) self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2) self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4) def _warm_up(self): logger.info("Warming up the model...") gen_text = "Warm-up text for the model." 
for _ in infer_batch_process( (self.audio, self.sr), self.ref_text, [gen_text], self.model, self.vocoder, progress=None, device=self.device, streaming=True, ): pass logger.info("Warm-up completed.") def generate_stream(self, text, conn): text_batches = chunk_text(text, max_chars=self.max_chars) if self.first_package: text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:] text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:] self.first_package = False audio_stream = infer_batch_process( (self.audio, self.sr), self.ref_text, text_batches, self.model, self.vocoder, progress=None, device=self.device, streaming=True, chunk_size=2048, ) # Reset the file writer thread if self.file_writer_thread is not None: self.file_writer_thread.stop() self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate) self.file_writer_thread.start() for audio_chunk, _ in audio_stream: if len(audio_chunk) > 0: logger.info(f"Generated audio chunk of size: {len(audio_chunk)}") # Send audio chunk via socket conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk)) # Write to file asynchronously self.file_writer_thread.add_chunk(audio_chunk) logger.info("Finished sending audio stream.") conn.sendall(b"END") # Send end signal # Ensure all audio data is written before exiting self.file_writer_thread.stop() def handle_client(conn, processor): try: with conn: conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) while True: data = conn.recv(1024) if not data: processor.first_package = True break data_str = data.decode("utf-8").strip() logger.info(f"Received text: {data_str}") try: processor.generate_stream(data_str, conn) except Exception as inner_e: logger.error(f"Error during processing: {inner_e}") traceback.print_exc() break except Exception as e: logger.error(f"Error handling client: {e}") traceback.print_exc() def start_server(host, port, processor): with socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) as s: s.bind((host, port)) s.listen() logger.info(f"Server started on {host}:{port}") while True: conn, addr = s.accept() logger.info(f"Connected by {addr}") handle_client(conn, processor) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", default=9998) parser.add_argument( "--model", default="F5TTS_v1_Base", help="The model name, e.g. F5TTS_v1_Base", ) parser.add_argument( "--ckpt_file", default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")), help="Path to the model checkpoint file", ) parser.add_argument( "--vocab_file", default="", help="Path to the vocab file if customized", ) parser.add_argument( "--ref_audio", default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")), help="Reference audio to provide model with speaker characteristics", ) parser.add_argument( "--ref_text", default="", help="Reference audio subtitle, leave empty to auto-transcribe", ) parser.add_argument("--device", default=None, help="Device to run the model on") parser.add_argument("--dtype", default=torch.float32, help="Data type to use for model inference") args = parser.parse_args() try: # Initialize the processor with the model and vocoder processor = TTSStreamingProcessor( model=args.model, ckpt_file=args.ckpt_file, vocab_file=args.vocab_file, ref_audio=args.ref_audio, ref_text=args.ref_text, device=args.device, dtype=args.dtype, ) # Start the server start_server(args.host, args.port, processor) except KeyboardInterrupt: gc.collect() ================================================ FILE: src/f5_tts/train/README.md ================================================ # Training Check your FFmpeg installation: ```bash ffmpeg -version ``` If not found, install it first (or skip assuming you know of other backends available). 
## Prepare Dataset

Example data processing scripts are provided; you may tailor your own, along with a Dataset class in `src/f5_tts/model/dataset.py`.

### 1. Dataset-specific preparation scripts
Download the corresponding dataset first, and fill in its path in the script.

```bash
# Prepare the Emilia dataset
python src/f5_tts/train/datasets/prepare_emilia.py

# Prepare the Wenetspeech4TTS dataset
python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py

# Prepare the LibriTTS dataset
python src/f5_tts/train/datasets/prepare_libritts.py

# Prepare the LJSpeech dataset
python src/f5_tts/train/datasets/prepare_ljspeech.py
```

### 2. Create a custom dataset from a CSV

Prepare a CSV with two columns and the required header `audio_file|text`. Audio paths must be absolute. For guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).

```bash
python src/f5_tts/train/datasets/prepare_csv_wavs.py /path/to/metadata.csv /path/to/output
```

By default the pretrained `vocab.txt` is reused (fine-tuning). Add `--pretrain` to build a new vocabulary from your own data instead.

## Training & Finetuning

Once your datasets are prepared, you can start the training process.

### 1. Training script used for the pretrained model

```bash
# set up accelerate config, e.g. use multi-gpu ddp, fp16
# will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config

# .yaml files are under src/f5_tts/configs directory
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml

# it is possible to override accelerate and hydra config
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
```

### 2. Finetuning practice
Discussion board for finetuning: [#57](https://github.com/SWivid/F5-TTS/discussions/57).

For Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).

If you want to finetune a variant version, e.g.
*F5TTS_v1_Base_no_zero_init*, manually download the pretrained checkpoint from the model weight repository and fill in its path in the web interface.

If you use TensorBoard as the logger, install it first with `pip install tensorboard`.

The `use_ema = True` setting might be harmful for early-stage finetuned checkpoints. These have gone through only a few updates, so their EMA weights are still dominated by the pretrained ones. Try turning it off with the finetune Gradio option or `load_model(..., use_ema=False)`, and see whether it gives better results.

### 3. W&B Logging

The `wandb/` directory will be created under the path from which you run the training/finetuning scripts.

By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).

To turn on wandb logging, you can either:

1. Manually log in with `wandb login`: learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
2. Log in automatically by setting an environment variable: get an API key at https://wandb.ai/authorize and set the environment variable as follows:

On Mac & Linux:

```
export WANDB_API_KEY=<YOUR WANDB API KEY>
```

On Windows:

```
set WANDB_API_KEY=<YOUR WANDB API KEY>
```
Moreover, if you cannot access W&B and want to log metrics offline, you can set the environment variable as follows:

```
export WANDB_MODE=offline
```

================================================
FILE: src/f5_tts/train/datasets/prepare_csv_wavs.py
================================================
"""
Usage:
    python prepare_csv_wavs.py /path/to/metadata.csv /output/dataset/path [--pretrain] [--workers N]

CSV format (header required, "|" delimiter):
    audio_file|text
    /path/to/wavs/audio_0001.wav|Yo! Hello? Hello?
    /path/to/wavs/audio_0002.wav|Hi, how are you doing today? I want to go shopping and buy me some lemons.

Notes:
- audio_file must be an absolute path.
""" import concurrent.futures import multiprocessing import os import shutil import signal import subprocess import sys from contextlib import contextmanager sys.path.append(os.getcwd()) import argparse import csv import json from importlib.resources import files from pathlib import Path import soundfile as sf import torchaudio from datasets.arrow_writer import ArrowWriter from tqdm import tqdm from f5_tts.model.utils import convert_char_to_pinyin PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt") # Configuration constants BATCH_SIZE = 100 # Batch size for text conversion MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1) # Leave one CPU free THREAD_NAME_PREFIX = "AudioProcessor" CHUNK_SIZE = 100 # Number of files to process per worker batch executor = None # Global executor for cleanup def is_csv_wavs_format(input_path): fpath = Path(input_path).expanduser() return fpath.is_file() and fpath.suffix.lower() == ".csv" @contextmanager def graceful_exit(): """Context manager for graceful shutdown on signals""" def signal_handler(signum, frame): print("\nReceived signal to terminate. Cleaning up...") if executor is not None: print("Shutting down executor...") executor.shutdown(wait=False, cancel_futures=True) sys.exit(1) # Set up signal handlers signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: yield finally: if executor is not None: executor.shutdown(wait=False) def process_audio_file(audio_path, text, polyphone): """Process a single audio file by checking its existence and extracting duration.""" if not Path(audio_path).exists(): print(f"audio {audio_path} not found, skipping") return None try: audio_duration = get_audio_duration(audio_path) if audio_duration <= 0: raise ValueError(f"Duration {audio_duration} is non-positive.") return (audio_path, text, audio_duration) except Exception as e: print(f"Warning: Failed to process {audio_path} due to error: {e}. 
Skipping corrupt file.") return None def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE): """Convert a list of texts to pinyin in batches.""" converted_texts = [] for i in tqdm( range(0, len(texts), batch_size), total=(len(texts) + batch_size - 1) // batch_size, desc="Converting texts to pinyin", ): batch = texts[i : i + batch_size] converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone) converted_texts.extend(converted_batch) return converted_texts def prepare_csv_wavs_dir(input_path, num_workers=None): global executor if not is_csv_wavs_format(input_path): raise ValueError(f"input must be a .csv file: {input_path}") audio_path_text_pairs = read_audio_text_pairs(Path(input_path).expanduser().as_posix()) polyphone = True total_files = len(audio_path_text_pairs) if total_files == 0: raise RuntimeError("No valid rows found in CSV.") # Use provided worker count or calculate optimal number worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files) print(f"\nProcessing {total_files} audio files using {worker_count} workers...") with graceful_exit(): # Initialize thread pool with optimized settings with concurrent.futures.ThreadPoolExecutor( max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX ) as exec: executor = exec results = [] # Process files in chunks for better efficiency for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE): chunk = audio_path_text_pairs[i : i + CHUNK_SIZE] # Submit futures in order chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk] # Iterate over futures in the original submission order to preserve ordering for future in tqdm( chunk_futures, total=len(chunk), desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}", ): try: result = future.result() if result is not None: results.append(result) except Exception as e: print(f"Error processing file: {e}") executor = None # Filter out failed 
results processed = [res for res in results if res is not None] if not processed: raise RuntimeError("No valid audio files were processed!") # Batch process text conversion raw_texts = [item[1] for item in processed] converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE) # Prepare final results sub_result = [] durations = [] vocab_set = set() for (audio_path, _, duration), conv_text in zip(processed, converted_texts): sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration}) durations.append(duration) vocab_set.update(list(conv_text)) return sub_result, durations, vocab_set def get_audio_duration(audio_path, timeout=5): """Get the duration of an audio file in seconds with fallbacks.""" try: return sf.info(audio_path).duration except Exception as e: print(f"Warning: soundfile failed for {audio_path} with error: {e}. Falling back to ffprobe.") try: cmd = [ "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", audio_path, ] result = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout ) duration_str = result.stdout.strip() if duration_str: return float(duration_str) raise ValueError("Empty duration string from ffprobe.") except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e: print(f"Warning: ffprobe failed for {audio_path} with error: {e}. 
Falling back to torchaudio.info.") try: info = torchaudio.info(audio_path) if info.sample_rate > 0: return info.num_frames / info.sample_rate raise ValueError("Invalid sample_rate from torchaudio.info.") except Exception as e: raise RuntimeError(f"failed to get duration for {audio_path}: {e}") def read_audio_text_pairs(csv_file_path): audio_text_pairs = [] csv_path = Path(csv_file_path).expanduser().absolute() with open(csv_path.as_posix(), mode="r", newline="", encoding="utf-8-sig") as csvfile: reader = csv.reader(csvfile, delimiter="|") header = next(reader, None) if header is None: return audio_text_pairs if len(header) < 2 or header[0].strip() != "audio_file" or header[1].strip() != "text": raise ValueError("CSV header must be: audio_file|text") for row_idx, row in enumerate(reader, start=2): if len(row) < 2: continue audio_file = row[0].strip() text = row[1].strip() if not audio_file: continue audio_path = Path(audio_file).expanduser() if not audio_path.is_absolute(): raise ValueError(f"audio_file must be an absolute path (row {row_idx}): {audio_file}") audio_text_pairs.append((audio_path.as_posix(), text)) return audio_text_pairs def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune): out_dir = Path(out_dir) out_dir.mkdir(exist_ok=True, parents=True) print(f"\nSaving to {out_dir} ...") raw_arrow_path = out_dir / "raw.arrow" with ArrowWriter(path=raw_arrow_path.as_posix()) as writer: for line in tqdm(result, desc="Writing to raw.arrow ..."): writer.write(line) writer.finalize() # Save durations to JSON dur_json_path = out_dir / "duration.json" with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) # Handle vocab file - write only once based on finetune flag voca_out_path = out_dir / "vocab.txt" if is_finetune: file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix() shutil.copy2(file_vocab_finetune, voca_out_path) else: with open(voca_out_path.as_posix(), "w") 
as f: for vocab in sorted(text_vocab_set): f.write(vocab + "\n") dataset_name = out_dir.stem print(f"\nFor {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None): if is_finetune: assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}" sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers) save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune) def get_args(): parser = argparse.ArgumentParser(description="Prepare and save dataset.") parser.add_argument( "inp_dir", type=str, help="Input CSV with header 'audio_file|text' and absolute wav paths.", ) parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.") parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune") parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})") return parser.parse_args() def cli(): try: args = get_args() prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers) except KeyboardInterrupt: print("\nOperation cancelled by user. Cleaning up...") if executor is not None: executor.shutdown(wait=False, cancel_futures=True) sys.exit(1) if __name__ == "__main__": cli() ================================================ FILE: src/f5_tts/train/datasets/prepare_emilia.py ================================================ # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07 # if use updated new version, i.e. 
WebDataset, feel free to modify / draft your own script # generate audio text map for Emilia ZH & EN # evaluate for vocab size import os import sys sys.path.append(os.getcwd()) import json from concurrent.futures import ProcessPoolExecutor from importlib.resources import files from pathlib import Path from datasets.arrow_writer import ArrowWriter from tqdm import tqdm from f5_tts.model.utils import convert_char_to_pinyin, repetition_found out_zh = { "ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328", } zh_filters = ["い", "て"] # seems synthesized audios, or heavily code-switched out_en = { "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375", "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", 
"EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995", } en_filters = ["ا", "い", "て"] def deal_with_audio_dir(audio_dir): audio_jsonl = audio_dir.with_suffix(".jsonl") sub_result, durations = [], [] vocab_set = set() bad_case_zh = 0 bad_case_en = 0 with open(audio_jsonl, "r") as f: lines = f.readlines() for line in tqdm(lines, desc=f"{audio_jsonl.stem}"): obj = json.loads(line) text = obj["text"] if obj["language"] == "zh": if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text): bad_case_zh += 1 continue else: text = text.translate( str.maketrans({",": ",", "!": "!", "?": "?"}) ) # not "。" cuz much code-switched if obj["language"] == "en": if ( obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4) ): bad_case_en += 1 continue if tokenizer == "pinyin": text = convert_char_to_pinyin([text], polyphone=polyphone)[0] duration = obj["duration"] sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration}) durations.append(duration) vocab_set.update(list(text)) return sub_result, durations, vocab_set, bad_case_zh, bad_case_en def main(): assert tokenizer in ["pinyin", "char"] result = [] duration_list = [] text_vocab_set = set() total_bad_case_zh = 0 total_bad_case_en = 0 # process raw data executor = ProcessPoolExecutor(max_workers=max_workers) futures = [] for lang in langs: dataset_path = Path(os.path.join(dataset_dir, lang)) [ futures.append(executor.submit(deal_with_audio_dir, audio_dir)) for audio_dir in dataset_path.iterdir() if audio_dir.is_dir() ] for futures in tqdm(futures, total=len(futures)): sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result() result.extend(sub_result) duration_list.extend(durations) text_vocab_set.update(vocab_set) total_bad_case_zh += bad_case_zh total_bad_case_en += bad_case_en executor.shutdown() # save 
preprocessed dataset to disk if not os.path.exists(f"{save_dir}"): os.makedirs(f"{save_dir}") print(f"\nSaving to {save_dir} ...") # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom # dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer: for line in tqdm(result, desc="Writing to raw.arrow ..."): writer.write(line) writer.finalize() # dup a json separately saving duration in case for DynamicBatchSampler ease with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) # vocab map, i.e. tokenizer # add alphabets and symbols (optional, if plan to ft on de/fr etc.) # if tokenizer == "pinyin": # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) with open(f"{save_dir}/vocab.txt", "w") as f: for vocab in sorted(text_vocab_set): f.write(vocab + "\n") print(f"\nFor {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}") if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n") if __name__ == "__main__": max_workers = 32 tokenizer = "pinyin" # "pinyin" | "char" polyphone = True langs = ["ZH", "EN"] dataset_dir = "/Emilia_Dataset/raw" dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}" save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") main() # Emilia ZH & EN # samples count 37837916 (after removal) # pinyin vocab size 2543 (polyphone) # total duration 95281.87 (hours) # bad zh asr cnt 230435 (samples) # bad eh asr cnt 37217 (samples) # vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. 
way of polyphoneme) # please be careful if using pretrained model, make sure the vocab.txt is same ================================================ FILE: src/f5_tts/train/datasets/prepare_emilia_v2.py ================================================ # put in src/f5_tts/train/datasets/prepare_emilia_v2.py # prepares Emilia dataset with the new format w/ Emilia-YODAS import json import os from concurrent.futures import ProcessPoolExecutor from importlib.resources import files from pathlib import Path from datasets.arrow_writer import ArrowWriter from tqdm import tqdm from f5_tts.model.utils import repetition_found # Define filters for exclusion out_en = set() en_filters = ["ا", "い", "て"] def process_audio_directory(audio_dir): sub_result, durations, vocab_set = [], [], set() bad_case_en = 0 for file in audio_dir.iterdir(): if file.suffix == ".json": with open(file, "r") as f: obj = json.load(f) text = obj["text"] if any(f in text for f in en_filters) or repetition_found(text, length=4): bad_case_en += 1 continue duration = obj["duration"] audio_file = file.with_suffix(".mp3") if audio_file.exists(): sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration}) durations.append(duration) vocab_set.update(list(text)) return sub_result, durations, vocab_set, bad_case_en def main(): assert tokenizer in ["pinyin", "char"] result, duration_list, text_vocab_set = [], [], set() total_bad_case_en = 0 executor = ProcessPoolExecutor(max_workers=max_workers) futures = [] dataset_path = Path(dataset_dir) for sub_dir in dataset_path.iterdir(): if sub_dir.is_dir(): futures.append(executor.submit(process_audio_directory, sub_dir)) for future in tqdm(futures, total=len(futures)): sub_result, durations, vocab_set, bad_case_en = future.result() result.extend(sub_result) duration_list.extend(durations) text_vocab_set.update(vocab_set) total_bad_case_en += bad_case_en executor.shutdown() if not os.path.exists(f"{save_dir}"): os.makedirs(f"{save_dir}") with 
ArrowWriter(path=f"{save_dir}/raw.arrow") as writer: for line in tqdm(result, desc="Writing to raw.arrow ..."): writer.write(line) writer.finalize() with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) with open(f"{save_dir}/vocab.txt", "w") as f: for vocab in sorted(text_vocab_set): f.write(vocab + "\n") print(f"For {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") print(f"Bad en transcription case: {total_bad_case_en}\n") if __name__ == "__main__": max_workers = 32 tokenizer = "char" dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN" dataset_name = f"Emilia_EN_{tokenizer}" # save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}") save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" print(f"Prepare for {dataset_name}, will save to {save_dir}\n") main() ================================================ FILE: src/f5_tts/train/datasets/prepare_libritts.py ================================================ import os import sys sys.path.append(os.getcwd()) import json from concurrent.futures import ProcessPoolExecutor from importlib.resources import files from pathlib import Path import soundfile as sf from datasets.arrow_writer import ArrowWriter from tqdm import tqdm def deal_with_audio_dir(audio_dir): sub_result, durations = [], [] vocab_set = set() audio_lists = list(audio_dir.rglob("*.wav")) for line in audio_lists: text_path = line.with_suffix(".normalized.txt") text = open(text_path, "r").read().strip() duration = sf.info(line).duration if duration < 0.4 or duration > 30: continue sub_result.append({"audio_path": str(line), "text": text, "duration": duration}) durations.append(duration) vocab_set.update(list(text)) return sub_result, durations, vocab_set def main(): result = [] duration_list = [] text_vocab_set = 
set() # process raw data executor = ProcessPoolExecutor(max_workers=max_workers) futures = [] for subset in tqdm(SUB_SET): dataset_path = Path(os.path.join(dataset_dir, subset)) [ futures.append(executor.submit(deal_with_audio_dir, audio_dir)) for audio_dir in dataset_path.iterdir() if audio_dir.is_dir() ] for future in tqdm(futures, total=len(futures)): sub_result, durations, vocab_set = future.result() result.extend(sub_result) duration_list.extend(durations) text_vocab_set.update(vocab_set) executor.shutdown() # save preprocessed dataset to disk if not os.path.exists(f"{save_dir}"): os.makedirs(f"{save_dir}") print(f"\nSaving to {save_dir} ...") with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer: for line in tqdm(result, desc="Writing to raw.arrow ..."): writer.write(line) writer.finalize() # dup a json separately saving duration in case for DynamicBatchSampler ease with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) # vocab map, i.e. 
tokenizer with open(f"{save_dir}/vocab.txt", "w") as f: for vocab in sorted(text_vocab_set): f.write(vocab + "\n") print(f"\nFor {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") if __name__ == "__main__": max_workers = 36 tokenizer = "char" # "pinyin" | "char" SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"] dataset_dir = "/LibriTTS" dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "") save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") main() # For LibriTTS_100_360_500_char, sample count: 354218 # For LibriTTS_100_360_500_char, vocab size is: 78 # For LibriTTS_100_360_500_char, total 554.09 hours ================================================ FILE: src/f5_tts/train/datasets/prepare_ljspeech.py ================================================ import os import sys sys.path.append(os.getcwd()) import json from importlib.resources import files from pathlib import Path import soundfile as sf from datasets.arrow_writer import ArrowWriter from tqdm import tqdm def main(): result = [] duration_list = [] text_vocab_set = set() with open(meta_info, "r") as f: lines = f.readlines() for line in tqdm(lines): uttr, text, norm_text = line.split("|") norm_text = norm_text.strip() wav_path = Path(dataset_dir) / "wavs" / f"{uttr}.wav" duration = sf.info(wav_path).duration if duration < 0.4 or duration > 30: continue result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration}) duration_list.append(duration) text_vocab_set.update(list(norm_text)) # save preprocessed dataset to disk if not os.path.exists(f"{save_dir}"): os.makedirs(f"{save_dir}") print(f"\nSaving to {save_dir} ...") with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer: for line in 
tqdm(result, desc="Writing to raw.arrow ..."): writer.write(line) writer.finalize() # dup a json separately saving duration in case for DynamicBatchSampler ease with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) # vocab map, i.e. tokenizer # add alphabets and symbols (optional, if plan to ft on de/fr etc.) with open(f"{save_dir}/vocab.txt", "w") as f: for vocab in sorted(text_vocab_set): f.write(vocab + "\n") print(f"\nFor {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") if __name__ == "__main__": tokenizer = "char" # "pinyin" | "char" dataset_dir = "/LJSpeech-1.1" dataset_name = f"LJSpeech_{tokenizer}" meta_info = os.path.join(dataset_dir, "metadata.csv") save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n") main() ================================================ FILE: src/f5_tts/train/datasets/prepare_wenetspeech4tts.py ================================================ # generate audio text map for WenetSpeech4TTS # evaluate for vocab size import os import sys sys.path.append(os.getcwd()) import json from concurrent.futures import ProcessPoolExecutor from importlib.resources import files import torchaudio from datasets import Dataset from tqdm import tqdm from f5_tts.model.utils import convert_char_to_pinyin def deal_with_sub_path_files(dataset_path, sub_path): print(f"Dealing with: {sub_path}") text_dir = os.path.join(dataset_path, sub_path, "txts") audio_dir = os.path.join(dataset_path, sub_path, "wavs") text_files = os.listdir(text_dir) audio_paths, texts, durations = [], [], [] for text_file in tqdm(text_files): with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file: first_line = file.readline().split("\t") audio_nm = first_line[0] 
audio_path = os.path.join(audio_dir, audio_nm + ".wav") text = first_line[1].strip() audio_paths.append(audio_path) if tokenizer == "pinyin": texts.extend(convert_char_to_pinyin([text], polyphone=polyphone)) elif tokenizer == "char": texts.append(text) audio, sample_rate = torchaudio.load(audio_path) durations.append(audio.shape[-1] / sample_rate) return audio_paths, texts, durations def main(): assert tokenizer in ["pinyin", "char"] audio_path_list, text_list, duration_list = [], [], [] executor = ProcessPoolExecutor(max_workers=max_workers) futures = [] for dataset_path in dataset_paths: sub_items = os.listdir(dataset_path) sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))] for sub_path in sub_paths: futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path)) for future in tqdm(futures, total=len(futures)): audio_paths, texts, durations = future.result() audio_path_list.extend(audio_paths) text_list.extend(texts) duration_list.extend(durations) executor.shutdown() if not os.path.exists("data"): os.makedirs("data") print(f"\nSaving to {save_dir} ...") dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") # arrow format with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f: json.dump( {"duration": duration_list}, f, ensure_ascii=False ) # dup a json separately saving duration in case for DynamicBatchSampler ease print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...") text_vocab_set = set() for text in tqdm(text_list): text_vocab_set.update(list(text)) # add alphabets and symbols (optional, if plan to ft on de/fr etc.) 
if tokenizer == "pinyin": text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)]) with open(f"{save_dir}/vocab.txt", "w") as f: for vocab in sorted(text_vocab_set): f.write(vocab + "\n") print(f"\nFor {dataset_name}, sample count: {len(text_list)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n") if __name__ == "__main__": max_workers = 32 tokenizer = "pinyin" # "pinyin" | "char" polyphone = True dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic dataset_name = ( ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1] + "_" + tokenizer ) dataset_paths = [ "/WenetSpeech4TTS/Basic", "/WenetSpeech4TTS/Standard", "/WenetSpeech4TTS/Premium", ][-dataset_choice:] save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}" print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n") main() # Results (if adding alphabets with accents and symbols): # WenetSpeech4TTS Basic Standard Premium # samples count 3932473 1941220 407494 # pinyin vocab size 1349 1348 1344 (no polyphone) # - - 1459 (polyphone) # char vocab size 5264 5219 5042 # vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. 
way of polyphoneme) # please be careful if using pretrained model, make sure the vocab.txt is same ================================================ FILE: src/f5_tts/train/finetune_cli.py ================================================ import argparse import os import shutil from importlib.resources import files from cached_path import cached_path from f5_tts.model import CFM, DiT, Trainer, UNetT from f5_tts.model.dataset import load_dataset from f5_tts.model.utils import get_tokenizer # -------------------------- Dataset Settings --------------------------- # target_sample_rate = 24000 n_mel_channels = 100 hop_length = 256 win_length = 1024 n_fft = 1024 mel_spec_type = "vocos" # 'vocos' or 'bigvgan' # -------------------------- Argument Parsing --------------------------- # def parse_args(): parser = argparse.ArgumentParser(description="Train CFM Model") parser.add_argument( "--exp_name", type=str, default="F5TTS_v1_Base", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], help="Experiment name", ) parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use") parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training") parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU") parser.add_argument( "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type" ) parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch") parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps") parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping") parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs") parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates") parser.add_argument("--save_per_updates", type=int, default=50000, 
help="Save checkpoint every N updates") parser.add_argument( "--keep_last_n_checkpoints", type=int, default=-1, help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints", ) parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates") parser.add_argument("--finetune", action="store_true", help="Use Finetune") parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint") parser.add_argument( "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type" ) parser.add_argument( "--tokenizer_path", type=str, default=None, help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')", ) parser.add_argument( "--log_samples", action="store_true", help="Log inferenced samples per ckpt save updates", ) parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger") parser.add_argument( "--bnb_optimizer", action="store_true", help="Use 8-bit Adam optimizer from bitsandbytes", ) return parser.parse_args() # -------------------------- Training Settings -------------------------- # def main(): args = parse_args() checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}")) # Model parameters based on experiment name if args.exp_name == "F5TTS_v1_Base": wandb_resume_id = None model_cls = DiT model_cfg = dict( dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, ) if args.finetune: if args.pretrain is None: ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")) else: ckpt_path = args.pretrain elif args.exp_name == "F5TTS_Base": wandb_resume_id = None model_cls = DiT model_cfg = dict( dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1, ) if args.finetune: if args.pretrain is None: ckpt_path = 
str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) else: ckpt_path = args.pretrain elif args.exp_name == "E2TTS_Base": wandb_resume_id = None model_cls = UNetT model_cfg = dict( dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1, ) if args.finetune: if args.pretrain is None: ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) else: ckpt_path = args.pretrain if args.finetune: if not os.path.isdir(checkpoint_path): os.makedirs(checkpoint_path, exist_ok=True) file_checkpoint = os.path.basename(ckpt_path) if not file_checkpoint.startswith("pretrained_"): # Change: Add 'pretrained_' prefix to copied model file_checkpoint = "pretrained_" + file_checkpoint file_checkpoint = os.path.join(checkpoint_path, file_checkpoint) if not os.path.isfile(file_checkpoint): shutil.copy2(ckpt_path, file_checkpoint) print("copy checkpoint for finetune") # Use the tokenizer and tokenizer_path provided in the command line arguments tokenizer = args.tokenizer if tokenizer == "custom": if not args.tokenizer_path: raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.") tokenizer_path = args.tokenizer_path else: tokenizer_path = args.dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) print("\nvocab : ", vocab_size) print("\nvocoder : ", mel_spec_type) mel_spec_kwargs = dict( n_fft=n_fft, hop_length=hop_length, win_length=win_length, n_mel_channels=n_mel_channels, target_sample_rate=target_sample_rate, mel_spec_type=mel_spec_type, ) model = CFM( transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=mel_spec_kwargs, vocab_char_map=vocab_char_map, ) trainer = Trainer( model, args.epochs, args.learning_rate, num_warmup_updates=args.num_warmup_updates, save_per_updates=args.save_per_updates, keep_last_n_checkpoints=args.keep_last_n_checkpoints, checkpoint_path=checkpoint_path, batch_size_per_gpu=args.batch_size_per_gpu, 
batch_size_type=args.batch_size_type, max_samples=args.max_samples, grad_accumulation_steps=args.grad_accumulation_steps, max_grad_norm=args.max_grad_norm, logger=args.logger, wandb_project=args.dataset_name, wandb_run_name=args.exp_name, wandb_resume_id=wandb_resume_id, log_samples=args.log_samples, last_per_updates=args.last_per_updates, bnb_optimizer=args.bnb_optimizer, ) train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) trainer.train( train_dataset, resumable_with_seed=666, # seed for shuffling dataset ) if __name__ == "__main__": main() ================================================ FILE: src/f5_tts/train/finetune_gradio.py ================================================ import gc import json import os import platform import queue import random import re import shutil import signal import subprocess import sys import tempfile import threading import time from glob import glob from importlib.resources import files import click import gradio as gr import librosa import numpy as np import psutil import torch import torchaudio from cached_path import cached_path from datasets import Dataset as Dataset_ from datasets.arrow_writer import ArrowWriter from safetensors.torch import load_file, save_file from scipy.io import wavfile from f5_tts.api import F5TTS from f5_tts.infer.utils_infer import transcribe from f5_tts.model.utils import convert_char_to_pinyin training_process = None system = platform.system() python_executable = sys.executable or "python" tts_api = None last_checkpoint = "" last_device = "" last_ema = None path_data = str(files("f5_tts").joinpath("../../data")) path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts")) file_train = str(files("f5_tts").joinpath("train/finetune_cli.py")) device = ( "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) # Save settings from a JSON file def save_settings( project_name, 
exp_name, learning_rate, batch_size_per_gpu, batch_size_type, max_samples, grad_accumulation_steps, max_grad_norm, epochs, num_warmup_updates, save_per_updates, keep_last_n_checkpoints, last_per_updates, finetune, file_checkpoint_train, tokenizer_type, tokenizer_file, mixed_precision, logger, ch_8bit_adam, ): path_project = os.path.join(path_project_ckpts, project_name) os.makedirs(path_project, exist_ok=True) file_setting = os.path.join(path_project, "setting.json") settings = { "exp_name": exp_name, "learning_rate": learning_rate, "batch_size_per_gpu": batch_size_per_gpu, "batch_size_type": batch_size_type, "max_samples": max_samples, "grad_accumulation_steps": grad_accumulation_steps, "max_grad_norm": max_grad_norm, "epochs": epochs, "num_warmup_updates": num_warmup_updates, "save_per_updates": save_per_updates, "keep_last_n_checkpoints": keep_last_n_checkpoints, "last_per_updates": last_per_updates, "finetune": finetune, "file_checkpoint_train": file_checkpoint_train, "tokenizer_type": tokenizer_type, "tokenizer_file": tokenizer_file, "mixed_precision": mixed_precision, "logger": logger, "bnb_optimizer": ch_8bit_adam, } with open(file_setting, "w") as f: json.dump(settings, f, indent=4) return "Settings saved!" 
# Load settings from a JSON file def load_settings(project_name): project_name = project_name.replace("_pinyin", "").replace("_char", "") path_project = os.path.join(path_project_ckpts, project_name) file_setting = os.path.join(path_project, "setting.json") # Default settings default_settings = { "exp_name": "F5TTS_v1_Base", "learning_rate": 1e-5, "batch_size_per_gpu": 3200, "batch_size_type": "frame", "max_samples": 64, "grad_accumulation_steps": 1, "max_grad_norm": 1.0, "epochs": 100, "num_warmup_updates": 100, "save_per_updates": 500, "keep_last_n_checkpoints": -1, "last_per_updates": 100, "finetune": True, "file_checkpoint_train": "", "tokenizer_type": "pinyin", "tokenizer_file": "", "mixed_precision": "fp16", "logger": "none", "bnb_optimizer": False, } if device == "mps": default_settings["mixed_precision"] = "none" # Load settings from file if it exists if os.path.isfile(file_setting): with open(file_setting, "r") as f: file_settings = json.load(f) default_settings.update(file_settings) # Return as a tuple in the correct order return ( default_settings["exp_name"], default_settings["learning_rate"], default_settings["batch_size_per_gpu"], default_settings["batch_size_type"], default_settings["max_samples"], default_settings["grad_accumulation_steps"], default_settings["max_grad_norm"], default_settings["epochs"], default_settings["num_warmup_updates"], default_settings["save_per_updates"], default_settings["keep_last_n_checkpoints"], default_settings["last_per_updates"], default_settings["finetune"], default_settings["file_checkpoint_train"], default_settings["tokenizer_type"], default_settings["tokenizer_file"], default_settings["mixed_precision"], default_settings["logger"], default_settings["bnb_optimizer"], ) # Load metadata def get_audio_duration(audio_path): """Calculate the duration mono of an audio file.""" audio, sample_rate = torchaudio.load(audio_path) return audio.shape[1] / sample_rate class Slicer: # 
https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py def __init__( self, sr: int, threshold: float = -40.0, min_length: int = 20000, # 20 seconds min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 2000, ): if not min_length >= min_interval >= hop_size: raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size") if not max_sil_kept >= hop_size: raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size") min_interval = sr * min_interval / 1000 self.threshold = 10 ** (threshold / 20.0) self.hop_size = round(sr * hop_size / 1000) self.win_size = min(round(min_interval), 4 * self.hop_size) self.min_length = round(sr * min_length / 1000 / self.hop_size) self.min_interval = round(min_interval / self.hop_size) self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) def _apply_slice(self, waveform, begin, end): if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)] else: return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)] # @timeit def slice(self, waveform): if len(waveform.shape) > 1: samples = waveform.mean(axis=0) else: samples = waveform if samples.shape[0] <= self.min_length: return [waveform] rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) sil_tags = [] silence_start = None clip_start = 0 for i, rms in enumerate(rms_list): # Keep looping while frame is silent. if rms < self.threshold: # Record start of silent frames. if silence_start is None: silence_start = i continue # Keep looping while frame is not silent and silence start has not been recorded. 
if silence_start is None: continue # Clear recorded silence start if interval is not enough or clip is too short is_leading_silence = silence_start == 0 and i > self.max_sil_kept need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length if not is_leading_silence and not need_slice_middle: silence_start = None continue # Need slicing. Record the range of silent frames to be removed. if i - silence_start <= self.max_sil_kept: pos = rms_list[silence_start : i + 1].argmin() + silence_start if silence_start == 0: sil_tags.append((0, pos)) else: sil_tags.append((pos, pos)) clip_start = pos elif i - silence_start <= self.max_sil_kept * 2: pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin() pos += i - self.max_sil_kept pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept if silence_start == 0: sil_tags.append((0, pos_r)) clip_start = pos_r else: sil_tags.append((min(pos_l, pos), max(pos_r, pos))) clip_start = max(pos_r, pos) else: pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept if silence_start == 0: sil_tags.append((0, pos_r)) else: sil_tags.append((pos_l, pos_r)) clip_start = pos_r silence_start = None # Deal with trailing silence. 
total_frames = rms_list.shape[0] if silence_start is not None and total_frames - silence_start >= self.min_interval: silence_end = min(total_frames, silence_start + self.max_sil_kept) pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start sil_tags.append((pos, total_frames + 1)) # Apply and return slices: [chunk, start, end] if len(sil_tags) == 0: return [[waveform, 0, int(total_frames * self.hop_size)]] else: chunks = [] if sil_tags[0][0] > 0: chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)]) for i in range(len(sil_tags) - 1): chunks.append( [ self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]), int(sil_tags[i][1] * self.hop_size), int(sil_tags[i + 1][0] * self.hop_size), ] ) if sil_tags[-1][1] < total_frames: chunks.append( [ self._apply_slice(waveform, sil_tags[-1][1], total_frames), int(sil_tags[-1][1] * self.hop_size), int(total_frames * self.hop_size), ] ) return chunks # terminal def terminate_process_tree(pid, including_parent=True): try: parent = psutil.Process(pid) except psutil.NoSuchProcess: # Process already terminated return children = parent.children(recursive=True) for child in children: try: os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL except OSError: pass if including_parent: try: os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL except OSError: pass def terminate_process(pid): if system == "Windows": cmd = f"taskkill /t /f /pid {pid}" os.system(cmd) else: terminate_process_tree(pid) def start_training( dataset_name, exp_name, learning_rate, batch_size_per_gpu, batch_size_type, max_samples, grad_accumulation_steps, max_grad_norm, epochs, num_warmup_updates, save_per_updates, keep_last_n_checkpoints, last_per_updates, finetune, file_checkpoint_train, tokenizer_type, tokenizer_file, mixed_precision, stream, logger, ch_8bit_adam, ): global training_process, tts_api, stop_signal if tts_api is not None: if tts_api is not None: del tts_api gc.collect() 
torch.cuda.empty_cache() tts_api = None path_project = os.path.join(path_data, dataset_name) if not os.path.isdir(path_project): yield ( f"There is not project with name {dataset_name}", gr.update(interactive=True), gr.update(interactive=False), ) return file_raw = os.path.join(path_project, "raw.arrow") if not os.path.isfile(file_raw): yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False) return # Check if a training process is already running if training_process is not None: return "Train run already!", gr.update(interactive=False), gr.update(interactive=True) yield "start train", gr.update(interactive=False), gr.update(interactive=False) # Command to run the training script with the specified arguments if tokenizer_file == "": if dataset_name.endswith("_pinyin"): tokenizer_type = "pinyin" elif dataset_name.endswith("_char"): tokenizer_type = "char" else: tokenizer_type = "custom" dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "") if mixed_precision != "none": fp16 = f"--mixed_precision={mixed_precision}" else: fp16 = "" cmd = ( f'accelerate launch {fp16} "{file_train}" --exp_name {exp_name}' f" --learning_rate {learning_rate}" f" --batch_size_per_gpu {batch_size_per_gpu}" f" --batch_size_type {batch_size_type}" f" --max_samples {max_samples}" f" --grad_accumulation_steps {grad_accumulation_steps}" f" --max_grad_norm {max_grad_norm}" f" --epochs {epochs}" f" --num_warmup_updates {num_warmup_updates}" f" --save_per_updates {save_per_updates}" f" --keep_last_n_checkpoints {keep_last_n_checkpoints}" f" --last_per_updates {last_per_updates}" f" --dataset_name {dataset_name}" ) if finetune: cmd += " --finetune" if file_checkpoint_train != "": cmd += f' --pretrain "{file_checkpoint_train}"' if tokenizer_file != "": cmd += f" --tokenizer_path {tokenizer_file}" cmd += f" --tokenizer {tokenizer_type}" if logger != "none": cmd += f" --logger {logger}" cmd += " --log_samples" if ch_8bit_adam: cmd += " 
--bnb_optimizer" print("run command : \n" + cmd + "\n") save_settings( dataset_name, exp_name, learning_rate, batch_size_per_gpu, batch_size_type, max_samples, grad_accumulation_steps, max_grad_norm, epochs, num_warmup_updates, save_per_updates, keep_last_n_checkpoints, last_per_updates, finetune, file_checkpoint_train, tokenizer_type, tokenizer_file, mixed_precision, logger, ch_8bit_adam, ) try: if not stream: # Start the training process training_process = subprocess.Popen(cmd, shell=True) time.sleep(5) yield "train start", gr.update(interactive=False), gr.update(interactive=True) # Wait for the training process to finish training_process.wait() else: def stream_output(pipe, output_queue): try: for line in iter(pipe.readline, ""): output_queue.put(line) except Exception as e: output_queue.put(f"Error reading pipe: {str(e)}") finally: pipe.close() env = os.environ.copy() env["PYTHONUNBUFFERED"] = "1" training_process = subprocess.Popen( cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env ) yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True) stdout_queue = queue.Queue() stderr_queue = queue.Queue() stdout_thread = threading.Thread(target=stream_output, args=(training_process.stdout, stdout_queue)) stderr_thread = threading.Thread(target=stream_output, args=(training_process.stderr, stderr_queue)) stdout_thread.daemon = True stderr_thread.daemon = True stdout_thread.start() stderr_thread.start() stop_signal = False while True: if stop_signal: training_process.terminate() time.sleep(0.5) if training_process.poll() is None: training_process.kill() yield "Training stopped by user.", gr.update(interactive=True), gr.update(interactive=False) break process_status = training_process.poll() # Handle stdout try: while True: output = stdout_queue.get_nowait() print(output, end="") match = re.search( r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), update=(\d+)", output ) if match: 
current_epoch = match.group(1) total_epochs = match.group(2) percent_complete = match.group(3) elapsed_time = match.group(4) loss = match.group(5) current_update = match.group(6) message = ( f"Epoch: {current_epoch}/{total_epochs}, " f"Progress: {percent_complete}%, " f"Elapsed Time: {elapsed_time}, " f"Loss: {loss}, " f"Update: {current_update}" ) yield message, gr.update(interactive=False), gr.update(interactive=True) elif output.strip(): yield output, gr.update(interactive=False), gr.update(interactive=True) except queue.Empty: pass # Handle stderr try: while True: error_output = stderr_queue.get_nowait() print(error_output, end="") if error_output.strip(): yield f"{error_output.strip()}", gr.update(interactive=False), gr.update(interactive=True) except queue.Empty: pass if process_status is not None and stdout_queue.empty() and stderr_queue.empty(): if process_status != 0: yield ( f"Process crashed with exit code {process_status}!", gr.update(interactive=False), gr.update(interactive=True), ) else: yield ( "Training complete or paused ...", gr.update(interactive=False), gr.update(interactive=True), ) break # Small sleep to prevent CPU thrashing time.sleep(0.1) # Clean up training_process.stdout.close() training_process.stderr.close() training_process.wait() time.sleep(1) if training_process is None: text_info = "Train stopped !" else: text_info = "Train complete at end !" 
except Exception as e: # Catch all exceptions # Ensure that we reset the training process variable in case of an error text_info = f"An error occurred: {str(e)}" training_process = None yield text_info, gr.update(interactive=True), gr.update(interactive=False) def stop_training(): global training_process, stop_signal if training_process is None: return "Train not running !", gr.update(interactive=True), gr.update(interactive=False) terminate_process_tree(training_process.pid) # training_process = None stop_signal = True return "Train stopped !", gr.update(interactive=True), gr.update(interactive=False) def get_list_projects(): project_list = [] for folder in os.listdir(path_data): path_folder = os.path.join(path_data, folder) if not os.path.isdir(path_folder): continue folder = folder.lower() if folder == "emilia_zh_en_pinyin": continue project_list.append(folder) projects_selelect = None if not project_list else project_list[-1] return project_list, projects_selelect def create_data_project(name, tokenizer_type): name += "_" + tokenizer_type os.makedirs(os.path.join(path_data, name), exist_ok=True) os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True) project_list, projects_selelect = get_list_projects() return gr.update(choices=project_list, value=name) def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()): path_project = os.path.join(path_data, name_project) path_dataset = os.path.join(path_project, "dataset") path_project_wavs = os.path.join(path_project, "wavs") file_metadata = os.path.join(path_project, "metadata.csv") if not user: if audio_files is None: return "You need to load an audio file." 
if os.path.isdir(path_project_wavs): shutil.rmtree(path_project_wavs) if os.path.isfile(file_metadata): os.remove(file_metadata) os.makedirs(path_project_wavs, exist_ok=True) if user: file_audios = [ file for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac") for file in glob(os.path.join(path_dataset, format)) ] if file_audios == []: return "No audio file was found in the dataset." else: file_audios = audio_files alpha = 0.5 _max = 1.0 slicer = Slicer(24000) num = 0 error_num = 0 data = "" for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))): audio, _ = librosa.load(file_audio, sr=24000, mono=True) list_slicer = slicer.slice(audio) for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"): name_segment = os.path.join(f"segment_{num}") file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav") tmp_max = np.abs(chunk).max() if tmp_max > 1: chunk /= tmp_max chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16)) try: text = transcribe(file_segment, language) text = text.strip() data += f"{name_segment}|{text}\n" num += 1 except: # noqa: E722 error_num += 1 with open(file_metadata, "w", encoding="utf-8-sig") as f: f.write(data) if error_num != []: error_text = f"\nerror files : {error_num}" else: error_text = "" return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}" def format_seconds_to_hms(seconds): hours = int(seconds / 3600) minutes = int((seconds % 3600) / 60) seconds = seconds % 60 return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds)) def get_correct_audio_path( audio_input, base_path="wavs", supported_formats=("wav", "mp3", "aac", "flac", "m4a", "alac", "ogg", "aiff", "wma", "amr"), ): file_audio = None # Helper function to check if file has a supported extension def has_supported_extension(file_name): return any(file_name.endswith(f".{ext}") 
for ext in supported_formats) # Case 1: If it's a full path with a valid extension, use it directly if os.path.isabs(audio_input) and has_supported_extension(audio_input): file_audio = audio_input # Case 2: If it has a supported extension but is not a full path elif has_supported_extension(audio_input) and not os.path.isabs(audio_input): file_audio = os.path.join(base_path, audio_input) # Case 3: If only the name is given (no extension and not a full path) elif not has_supported_extension(audio_input) and not os.path.isabs(audio_input): for ext in supported_formats: potential_file = os.path.join(base_path, f"{audio_input}.{ext}") if os.path.exists(potential_file): file_audio = potential_file break else: file_audio = os.path.join(base_path, f"{audio_input}.{supported_formats[0]}") return file_audio def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()): path_project = os.path.join(path_data, name_project) path_project_wavs = os.path.join(path_project, "wavs") file_metadata = os.path.join(path_project, "metadata.csv") file_raw = os.path.join(path_project, "raw.arrow") file_duration = os.path.join(path_project, "duration.json") file_vocab = os.path.join(path_project, "vocab.txt") if not os.path.isfile(file_metadata): return "The file was not found in " + file_metadata, "" with open(file_metadata, "r", encoding="utf-8-sig") as f: data = f.read() audio_path_list = [] text_list = [] duration_list = [] count = data.split("\n") lenght = 0 result = [] error_files = [] text_vocab_set = set() for line in progress.tqdm(data.split("\n"), total=count): sp_line = line.split("|") if len(sp_line) != 2: continue name_audio, text = sp_line[:2] file_audio = get_correct_audio_path(name_audio, path_project_wavs) if not os.path.isfile(file_audio): error_files.append([file_audio, "error path"]) continue try: duration = get_audio_duration(file_audio) except Exception as e: error_files.append([file_audio, "duration"]) print(f"Error processing {file_audio}: {e}") continue if 
duration < 1 or duration > 30: if duration > 30: error_files.append([file_audio, "duration > 30 sec"]) if duration < 1: error_files.append([file_audio, "duration < 1 sec "]) continue if len(text) < 3: error_files.append([file_audio, "very short text length 3"]) continue text = text.strip() text = convert_char_to_pinyin([text], polyphone=True)[0] audio_path_list.append(file_audio) duration_list.append(duration) text_list.append(text) result.append({"audio_path": file_audio, "text": text, "duration": duration}) if ch_tokenizer: text_vocab_set.update(list(text)) lenght += duration if duration_list == []: return f"Error: No audio files found in the specified path : {path_project_wavs}", "" min_second = round(min(duration_list), 2) max_second = round(max(duration_list), 2) with ArrowWriter(path=file_raw) as writer: for line in progress.tqdm(result, total=len(result), desc="prepare data"): writer.write(line) writer.finalize() with open(file_duration, "w") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) new_vocal = "" if not ch_tokenizer: if not os.path.isfile(file_vocab): file_vocab_finetune = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt") if not os.path.isfile(file_vocab_finetune): return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", "" shutil.copy2(file_vocab_finetune, file_vocab) with open(file_vocab, "r", encoding="utf-8-sig") as f: vocab_char_map = {} for i, char in enumerate(f): vocab_char_map[char[:-1]] = i vocab_size = len(vocab_char_map) else: with open(file_vocab, "w", encoding="utf-8-sig") as f: for vocab in sorted(text_vocab_set): f.write(vocab + "\n") new_vocal += vocab + "\n" vocab_size = len(text_vocab_set) if error_files != []: error_text = "\n".join([" = ".join(item) for item in error_files]) else: error_text = "" return ( f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\nvocab : 
{vocab_size}\n{error_text}", new_vocal, ) def check_user(value): return gr.update(visible=not value), gr.update(visible=value) def calculate_train( name_project, epochs, learning_rate, batch_size_per_gpu, batch_size_type, max_samples, num_warmup_updates, finetune, ): path_project = os.path.join(path_data, name_project) file_duration = os.path.join(path_project, "duration.json") hop_length = 256 sampling_rate = 24000 if not os.path.isfile(file_duration): return ( epochs, learning_rate, batch_size_per_gpu, max_samples, num_warmup_updates, "project not found !", ) with open(file_duration, "r") as file: data = json.load(file) duration_list = data["duration"] max_sample_length = max(duration_list) * sampling_rate / hop_length total_samples = len(duration_list) total_duration = sum(duration_list) if torch.cuda.is_available(): gpu_count = torch.cuda.device_count() total_memory = 0 for i in range(gpu_count): gpu_properties = torch.cuda.get_device_properties(i) total_memory += gpu_properties.total_memory / (1024**3) # in GB elif torch.xpu.is_available(): gpu_count = torch.xpu.device_count() total_memory = 0 for i in range(gpu_count): gpu_properties = torch.xpu.get_device_properties(i) total_memory += gpu_properties.total_memory / (1024**3) elif torch.backends.mps.is_available(): gpu_count = 1 total_memory = psutil.virtual_memory().available / (1024**3) avg_gpu_memory = total_memory / gpu_count # rough estimate of batch size if batch_size_type == "frame": batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length)) elif batch_size_type == "sample": batch_size_per_gpu = int(200 / (total_duration / total_samples)) if total_samples < 64: max_samples = int(total_samples * 0.25) num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05)) # take 1.2M updates as the maximum max_updates = 1200000 if batch_size_type == "frame": mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate updates_per_epoch = total_duration / 
mini_batch_duration elif batch_size_type == "sample": updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count epochs = int(max_updates / updates_per_epoch) if finetune: learning_rate = 1e-5 else: learning_rate = 7.5e-5 return ( epochs, learning_rate, batch_size_per_gpu, max_samples, num_warmup_updates, total_samples, ) def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str: try: checkpoint = torch.load(checkpoint_path, weights_only=True) print("Original Checkpoint Keys:", checkpoint.keys()) to_retain = "ema_model_state_dict" if save_ema else "model_state_dict" try: model_state_dict_to_retain = checkpoint[to_retain] except KeyError: return f"{to_retain} not found in the checkpoint." if safetensors: new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors") save_file(model_state_dict_to_retain, new_checkpoint_path) else: new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt") new_checkpoint = {"ema_model_state_dict": model_state_dict_to_retain} torch.save(new_checkpoint, new_checkpoint_path) return f"New checkpoint saved at: {new_checkpoint_path}" except Exception as e: return f"An error occurred: {e}" def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42): seed = 666 random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False if ckpt_path.endswith(".safetensors"): ckpt = load_file(ckpt_path, device="cpu") ckpt = {"ema_model_state_dict": ckpt} elif ckpt_path.endswith(".pt"): ckpt = torch.load(ckpt_path, map_location="cpu") ema_sd = ckpt.get("ema_model_state_dict", {}) embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight" old_embed_ema = ema_sd[embed_key_ema] vocab_old = old_embed_ema.size(0) embed_dim = old_embed_ema.size(1) vocab_new = vocab_old + num_new_tokens def 
expand_embeddings(old_embeddings): new_embeddings = torch.zeros((vocab_new, embed_dim)) new_embeddings[:vocab_old] = old_embeddings new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim)) return new_embeddings ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema]) if new_ckpt_path.endswith(".safetensors"): save_file(ema_sd, new_ckpt_path) elif new_ckpt_path.endswith(".pt"): torch.save(ckpt, new_ckpt_path) return vocab_new def vocab_count(text): return str(len(text.split(","))) def vocab_extend(project_name, symbols, model_type): if symbols == "": return "Symbols empty!" name_project = project_name path_project = os.path.join(path_data, name_project) file_vocab_project = os.path.join(path_project, "vocab.txt") file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt") if not os.path.isfile(file_vocab): return f"the file {file_vocab} not found !" symbols = symbols.split(",") if symbols == []: return "Symbols to extend not found." with open(file_vocab, "r", encoding="utf-8-sig") as f: data = f.read() vocab = data.split("\n") vocab_check = set(vocab) miss_symbols = [] for item in symbols: item = item.replace(" ", "") if item in vocab_check: continue miss_symbols.append(item) if miss_symbols == []: return "Symbols are okay no need to extend." 
size_vocab = len(vocab) vocab.pop() for item in miss_symbols: vocab.append(item) vocab.append("") with open(file_vocab_project, "w", encoding="utf-8") as f: f.write("\n".join(vocab)) if model_type == "F5TTS_v1_Base": ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")) elif model_type == "F5TTS_Base": ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) elif model_type == "E2TTS_Base": ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) vocab_size_new = len(miss_symbols) dataset_name = name_project.replace("_pinyin", "").replace("_char", "") new_ckpt_path = os.path.join(path_project_ckpts, dataset_name) os.makedirs(new_ckpt_path, exist_ok=True) # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path)) size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new) vocab_new = "\n".join(miss_symbols) return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}" def vocab_check(project_name, tokenizer_type): name_project = project_name path_project = os.path.join(path_data, name_project) file_metadata = os.path.join(path_project, "metadata.csv") file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt") if not os.path.isfile(file_vocab): return f"the file {file_vocab} not found !", "" with open(file_vocab, "r", encoding="utf-8-sig") as f: data = f.read() vocab = data.split("\n") vocab = set(vocab) if not os.path.isfile(file_metadata): return f"the file {file_metadata} not found !", "" with open(file_metadata, "r", encoding="utf-8-sig") as f: data = f.read() miss_symbols = [] miss_symbols_keep = {} for item in data.split("\n"): sp = item.split("|") if len(sp) != 2: continue text = sp[1].strip() if tokenizer_type == "pinyin": text = convert_char_to_pinyin([text], 
polyphone=True)[0] for t in text: if t not in vocab and t not in miss_symbols_keep: miss_symbols.append(t) miss_symbols_keep[t] = t if miss_symbols == []: vocab_miss = "" info = "You can train using your language !" else: vocab_miss = ",".join(miss_symbols) info = f"The following {len(miss_symbols)} symbols are missing in your language\n\n" return info, vocab_miss def get_random_sample_prepare(project_name): name_project = project_name path_project = os.path.join(path_data, name_project) file_arrow = os.path.join(path_project, "raw.arrow") if not os.path.isfile(file_arrow): return "", None dataset = Dataset_.from_file(file_arrow) random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0]) text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]" audio_path = random_sample["audio_path"][0] return text, audio_path def get_random_sample_transcribe(project_name): name_project = project_name path_project = os.path.join(path_data, name_project) file_metadata = os.path.join(path_project, "metadata.csv") if not os.path.isfile(file_metadata): return "", None data = "" with open(file_metadata, "r", encoding="utf-8-sig") as f: data = f.read() list_data = [] for item in data.split("\n"): sp = item.split("|") if len(sp) != 2: continue # fixed audio when it is absolute file_audio = get_correct_audio_path(sp[0], os.path.join(path_project, "wavs")) list_data.append([file_audio, sp[1]]) if list_data == []: return "", None random_item = random.choice(list_data) return random_item[1], random_item[0] def get_random_sample_infer(project_name): text, audio = get_random_sample_transcribe(project_name) return ( text, text, audio, ) def infer( project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema, speed, seed, remove_silence ): global last_checkpoint, last_device, tts_api, last_ema if not os.path.isfile(file_checkpoint): return None, "checkpoint not found!" 
if training_process is not None: device_test = "cpu" else: device_test = None if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema or tts_api is None: if last_checkpoint != file_checkpoint: last_checkpoint = file_checkpoint if last_device != device_test: last_device = device_test if last_ema != use_ema: last_ema = use_ema vocab_file = os.path.join(path_data, project, "vocab.txt") tts_api = F5TTS( model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema ) print("update >> ", device_test, file_checkpoint, use_ema) if seed == -1: # -1 used for random seed = None with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: tts_api.infer( ref_file=ref_audio, ref_text=ref_text.strip(), gen_text=gen_text.strip(), nfe_step=nfe_step, speed=speed, remove_silence=remove_silence, file_wave=f.name, seed=seed, ) return f.name, tts_api.device, str(tts_api.seed) def check_finetune(finetune): return gr.update(interactive=finetune), gr.update(interactive=finetune), gr.update(interactive=finetune) def get_checkpoints_project(project_name, is_gradio=True): if project_name is None: return [], "" project_name = project_name.replace("_pinyin", "").replace("_char", "") if os.path.isdir(path_project_ckpts): files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt")) # Separate pretrained and regular checkpoints pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)] regular_checkpoints = [ f for f in files_checkpoints if "pretrained_" not in os.path.basename(f) and "model_last.pt" not in os.path.basename(f) ] last_checkpoint = [f for f in files_checkpoints if "model_last.pt" in os.path.basename(f)] # Sort regular checkpoints by number regular_checkpoints = sorted( regular_checkpoints, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]) ) # Combine in order: pretrained, regular, last files_checkpoints = pretrained_checkpoints + 
regular_checkpoints + last_checkpoint else: files_checkpoints = [] selelect_checkpoint = None if not files_checkpoints else files_checkpoints[0] if is_gradio: return gr.update(choices=files_checkpoints, value=selelect_checkpoint) return files_checkpoints, selelect_checkpoint def get_audio_project(project_name, is_gradio=True): if project_name is None: return [], "" project_name = project_name.replace("_pinyin", "").replace("_char", "") if os.path.isdir(path_project_ckpts): files_audios = glob(os.path.join(path_project_ckpts, project_name, "samples", "*.wav")) files_audios = sorted(files_audios, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])) files_audios = [item.replace("_gen.wav", "") for item in files_audios if item.endswith("_gen.wav")] else: files_audios = [] selelect_checkpoint = None if not files_audios else files_audios[0] if is_gradio: return gr.update(choices=files_audios, value=selelect_checkpoint) return files_audios, selelect_checkpoint def get_gpu_stats(): gpu_stats = "" if torch.cuda.is_available(): gpu_count = torch.cuda.device_count() for i in range(gpu_count): gpu_name = torch.cuda.get_device_name(i) gpu_properties = torch.cuda.get_device_properties(i) total_memory = gpu_properties.total_memory / (1024**3) # in GB allocated_memory = torch.cuda.memory_allocated(i) / (1024**2) # in MB reserved_memory = torch.cuda.memory_reserved(i) / (1024**2) # in MB gpu_stats += ( f"GPU {i} Name: {gpu_name}\n" f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n" f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n" f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n" ) elif torch.xpu.is_available(): gpu_count = torch.xpu.device_count() for i in range(gpu_count): gpu_name = torch.xpu.get_device_name(i) gpu_properties = torch.xpu.get_device_properties(i) total_memory = gpu_properties.total_memory / (1024**3) # in GB allocated_memory = torch.xpu.memory_allocated(i) / (1024**2) # in MB reserved_memory = 
torch.xpu.memory_reserved(i) / (1024**2) # in MB gpu_stats += ( f"GPU {i} Name: {gpu_name}\n" f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n" f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n" f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n" ) elif torch.backends.mps.is_available(): gpu_count = 1 gpu_stats += "MPS GPU\n" total_memory = psutil.virtual_memory().total / ( 1024**3 ) # Total system memory (MPS doesn't have its own memory) allocated_memory = 0 reserved_memory = 0 gpu_stats += ( f"Total system memory: {total_memory:.2f} GB\n" f"Allocated GPU memory (MPS): {allocated_memory:.2f} MB\n" f"Reserved GPU memory (MPS): {reserved_memory:.2f} MB\n" ) else: gpu_stats = "No GPU available" return gpu_stats def get_cpu_stats(): cpu_usage = psutil.cpu_percent(interval=1) memory_info = psutil.virtual_memory() memory_used = memory_info.used / (1024**2) memory_total = memory_info.total / (1024**2) memory_percent = memory_info.percent pid = os.getpid() process = psutil.Process(pid) nice_value = process.nice() cpu_stats = ( f"CPU Usage: {cpu_usage:.2f}%\n" f"System Memory: {memory_used:.2f} MB used / {memory_total:.2f} MB total ({memory_percent}% used)\n" f"Process Priority (Nice value): {nice_value}" ) return cpu_stats def get_combined_stats(): gpu_stats = get_gpu_stats() cpu_stats = get_cpu_stats() combined_stats = f"### GPU Stats\n{gpu_stats}\n\n### CPU Stats\n{cpu_stats}" return combined_stats def get_audio_select(file_sample): select_audio_ref = file_sample select_audio_gen = file_sample if file_sample is not None: select_audio_ref += "_ref.wav" select_audio_gen += "_gen.wav" return select_audio_ref, select_audio_gen with gr.Blocks() as app: gr.Markdown( """ # F5 TTS Automatic Finetune This is a local web UI for F5 TTS finetuning support. 
This app supports the following TTS models: * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching) * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS) The pretrained checkpoints support English and Chinese. For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143) """ ) with gr.Row(): projects, projects_selelect = get_list_projects() tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char", "custom"], value="pinyin") project_name = gr.Textbox(label="Project Name", value="my_speak") bt_create = gr.Button("Create a New Project") with gr.Row(): cm_project = gr.Dropdown( choices=projects, value=projects_selelect, label="Project", allow_custom_value=True, scale=6 ) ch_refresh_project = gr.Button("Refresh", scale=1) bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project]) with gr.Tabs(): with gr.TabItem("Transcribe Data"): gr.Markdown("""```plaintext Skip this step if you have your dataset, metadata.csv, and a folder wavs with all the audio files. ```""") ch_manual = gr.Checkbox(label="Audio from Path", value=False) mark_info_transcribe = gr.Markdown( """```plaintext Place your 'wavs' folder and 'metadata.csv' file in the '{your_project_name}' directory. my_speak/ │ └── dataset/ ├── audio1.wav └── audio2.wav ... 
```""", visible=False, ) audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple") txt_lang = gr.Textbox(label="Language", value="English") bt_transcribe = bt_create = gr.Button("Transcribe") txt_info_transcribe = gr.Textbox(label="Info", value="") bt_transcribe.click( fn=transcribe_all, inputs=[cm_project, audio_speaker, txt_lang, ch_manual], outputs=[txt_info_transcribe], ) ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe]) random_sample_transcribe = gr.Button("Random Sample") with gr.Row(): random_text_transcribe = gr.Textbox(label="Text") random_audio_transcribe = gr.Audio(label="Audio", type="filepath") random_sample_transcribe.click( fn=get_random_sample_transcribe, inputs=[cm_project], outputs=[random_text_transcribe, random_audio_transcribe], ) with gr.TabItem("Vocab Check"): gr.Markdown("""```plaintext Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. For fine-tuning a new language. ```""") check_button = gr.Button("Check Vocab") txt_info_check = gr.Textbox(label="Info", value="") gr.Markdown("""```plaintext Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder. 
```""") exp_name_extend = gr.Radio( label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" ) with gr.Row(): txt_extend = gr.Textbox( label="Symbols", value="", placeholder="To add new symbols, make sure to use ',' for each symbol", scale=6, ) txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1) extend_button = gr.Button("Extend") txt_info_extend = gr.Textbox(label="Info", value="") txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol]) check_button.click( fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend] ) extend_button.click( fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend] ) with gr.TabItem("Prepare Data"): gr.Markdown("""```plaintext Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt ```""") gr.Markdown( """```plaintext Place all your "wavs" folder and your "metadata.csv" file in your project name directory. Supported audio formats: "wav", "mp3", "aac", "flac", "m4a", "alac", "ogg", "aiff", "wma", "amr" Example wav format: my_speak/ │ ├── wavs/ │ ├── audio1.wav │ └── audio2.wav | ... │ └── metadata.csv File format metadata.csv: audio1|text1 or audio1.wav|text1 or your_path/audio1.wav|text1 audio2|text1 or audio2.wav|text1 or your_path/audio2.wav|text1 ... 
```""" ) ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False) bt_prepare = bt_create = gr.Button("Prepare") txt_info_prepare = gr.Textbox(label="Info", value="") txt_vocab_prepare = gr.Textbox(label="Vocab", value="") bt_prepare.click( fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare] ) random_sample_prepare = gr.Button("Random Sample") with gr.Row(): random_text_prepare = gr.Textbox(label="Tokenizer") random_audio_prepare = gr.Audio(label="Audio", type="filepath") random_sample_prepare.click( fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare] ) with gr.TabItem("Train Model"): gr.Markdown("""```plaintext The auto-setting is still experimental. Set a large value of epoch if not sure; and keep last N checkpoints if limited disk space. If you encounter a memory error, try reducing the batch size per GPU to a smaller number. ```""") with gr.Row(): exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"]) tokenizer_file = gr.Textbox(label="Tokenizer File") file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint") with gr.Row(): ch_finetune = bt_create = gr.Checkbox(label="Finetune") lb_samples = gr.Label(label="Samples") bt_calculate = bt_create = gr.Button("Auto Settings") with gr.Row(): epochs = gr.Number(label="Epochs") learning_rate = gr.Number(label="Learning Rate", step=0.5e-5) max_grad_norm = gr.Number(label="Max Gradient Norm") num_warmup_updates = gr.Number(label="Warmup Updates") with gr.Row(): batch_size_type = gr.Radio( label="Batch Size Type", choices=["frame", "sample"], info="frame is calculated as seconds * sampling_rate / hop_length", ) batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples") grad_accumulation_steps = gr.Number( label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value" ) max_samples = 
gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch") with gr.Row(): save_per_updates = gr.Number( label="Save per Updates", info="Save intermediate checkpoints every N updates", minimum=10, ) keep_last_n_checkpoints = gr.Number( label="Keep Last N Checkpoints", step=1, precision=0, info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N", minimum=-1, ) last_per_updates = gr.Number( label="Last per Updates", info="Save latest checkpoint with suffix _last.pt every N updates", minimum=10, ) gr.Radio(label="") # placeholder with gr.Row(): ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer") mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"]) cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"]) with gr.Column(): start_button = gr.Button("Start Training") stop_button = gr.Button("Stop Training", interactive=False) if projects_selelect is not None: ( exp_name_value, learning_rate_value, batch_size_per_gpu_value, batch_size_type_value, max_samples_value, grad_accumulation_steps_value, max_grad_norm_value, epochs_value, num_warmup_updates_value, save_per_updates_value, keep_last_n_checkpoints_value, last_per_updates_value, finetune_value, file_checkpoint_train_value, tokenizer_type_value, tokenizer_file_value, mixed_precision_value, logger_value, bnb_optimizer_value, ) = load_settings(projects_selelect) # Assigning values to the respective components exp_name.value = exp_name_value learning_rate.value = learning_rate_value batch_size_per_gpu.value = batch_size_per_gpu_value batch_size_type.value = batch_size_type_value max_samples.value = max_samples_value grad_accumulation_steps.value = grad_accumulation_steps_value max_grad_norm.value = max_grad_norm_value epochs.value = epochs_value num_warmup_updates.value = num_warmup_updates_value save_per_updates.value = save_per_updates_value keep_last_n_checkpoints.value = keep_last_n_checkpoints_value 
last_per_updates.value = last_per_updates_value ch_finetune.value = finetune_value file_checkpoint_train.value = file_checkpoint_train_value tokenizer_type.value = tokenizer_type_value tokenizer_file.value = tokenizer_file_value mixed_precision.value = mixed_precision_value cd_logger.value = logger_value ch_8bit_adam.value = bnb_optimizer_value ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True) txt_info_train = gr.Textbox(label="Info", value="") list_audios, select_audio = get_audio_project(projects_selelect, False) select_audio_ref = select_audio select_audio_gen = select_audio if select_audio is not None: select_audio_ref += "_ref.wav" select_audio_gen += "_gen.wav" with gr.Row(): ch_list_audio = gr.Dropdown( choices=list_audios, value=select_audio, label="Audios", allow_custom_value=True, scale=6, interactive=True, ) bt_stream_audio = gr.Button("Refresh", scale=1) bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio]) cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio]) with gr.Row(): audio_ref_stream = gr.Audio(label="Original", type="filepath", value=select_audio_ref) audio_gen_stream = gr.Audio(label="Generate", type="filepath", value=select_audio_gen) ch_list_audio.change( fn=get_audio_select, inputs=[ch_list_audio], outputs=[audio_ref_stream, audio_gen_stream], ) start_button.click( fn=start_training, inputs=[ cm_project, exp_name, learning_rate, batch_size_per_gpu, batch_size_type, max_samples, grad_accumulation_steps, max_grad_norm, epochs, num_warmup_updates, save_per_updates, keep_last_n_checkpoints, last_per_updates, ch_finetune, file_checkpoint_train, tokenizer_type, tokenizer_file, mixed_precision, ch_stream, cd_logger, ch_8bit_adam, ], outputs=[txt_info_train, start_button, stop_button], ) stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button]) bt_calculate.click( fn=calculate_train, inputs=[ cm_project, epochs, learning_rate, 
batch_size_per_gpu, batch_size_type, max_samples, num_warmup_updates, ch_finetune, ], outputs=[ epochs, learning_rate, batch_size_per_gpu, max_samples, num_warmup_updates, lb_samples, ], ) ch_finetune.change( check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type] ) def setup_load_settings(): output_components = [ exp_name, learning_rate, batch_size_per_gpu, batch_size_type, max_samples, grad_accumulation_steps, max_grad_norm, epochs, num_warmup_updates, save_per_updates, keep_last_n_checkpoints, last_per_updates, ch_finetune, file_checkpoint_train, tokenizer_type, tokenizer_file, mixed_precision, cd_logger, ch_8bit_adam, ] return output_components outputs = setup_load_settings() cm_project.change( fn=load_settings, inputs=[cm_project], outputs=outputs, ) ch_refresh_project.click( fn=load_settings, inputs=[cm_project], outputs=outputs, ) with gr.TabItem("Test Model"): gr.Markdown("""```plaintext Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random. 
```""") exp_name = gr.Radio( label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" ) list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False) with gr.Row(): nfe_step = gr.Number(label="NFE Step", value=32) speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1) seed = gr.Number(label="Random Seed", value=-1, minimum=-1) remove_silence = gr.Checkbox(label="Remove Silence") with gr.Row(): ch_use_ema = gr.Checkbox( label="Use EMA", value=True, info="Turn off at early stage might offer better results" ) cm_checkpoint = gr.Dropdown( choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True ) bt_checkpoint_refresh = gr.Button("Refresh") random_sample_infer = gr.Button("Random Sample") ref_text = gr.Textbox(label="Reference Text") ref_audio = gr.Audio(label="Reference Audio", type="filepath") gen_text = gr.Textbox(label="Text to Generate") random_sample_infer.click( fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio] ) with gr.Row(): txt_info_gpu = gr.Textbox("", label="Inference on Device :") seed_info = gr.Textbox(label="Used Random Seed :") check_button_infer = gr.Button("Inference") gen_audio = gr.Audio(label="Generated Audio", type="filepath") check_button_infer.click( fn=infer, inputs=[ cm_project, cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema, speed, seed, remove_silence, ], outputs=[gen_audio, txt_info_gpu, seed_info], ) bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) with gr.TabItem("Prune Checkpoint"): gr.Markdown("""```plaintext Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining. 
```""") txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:") txt_path_checkpoint_small = gr.Textbox(label="Path to Output:") with gr.Row(): ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True) ch_safetensors = gr.Checkbox(label="Save with safetensors format", value=True) txt_info_reduse = gr.Textbox(label="Info", value="") reduse_button = gr.Button("Prune") reduse_button.click( fn=prune_checkpoint, inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors], outputs=[txt_info_reduse], ) with gr.TabItem("System Info"): output_box = gr.Textbox(label="GPU and CPU Information", lines=20) def update_stats(): return get_combined_stats() update_button = gr.Button("Update Stats") update_button.click(fn=update_stats, outputs=output_box) def auto_update(): yield gr.update(value=update_stats()) gr.update(fn=auto_update, inputs=[], outputs=output_box) @click.command() @click.option("--port", "-p", default=None, type=int, help="Port to run the app on") @click.option("--host", "-H", default=None, help="Host to run the app on") @click.option( "--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link", ) @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access") def main(port, host, share, api): global app print("Starting app...") app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api) if __name__ == "__main__": main() ================================================ FILE: src/f5_tts/train/train.py ================================================ # training script. 
import os
from importlib.resources import files

import hydra
from omegaconf import OmegaConf

from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer


# Import-time side effect: switch the working directory to the project root, two
# directories above the installed ``f5_tts`` package (the repo root for a local
# editable install).
os.chdir(str(files("f5_tts").joinpath("../..")))


@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(model_cfg):
    """Train a CFM text-to-speech model described by a Hydra config.

    Configs are looked up in the packaged ``f5_tts/configs`` directory. No default
    config name is set here, so one presumably has to be selected on the command
    line (e.g. ``--config-name``) -- confirm against the project docs.

    ``model_cfg`` must provide the ``model``, ``datasets``, ``optim`` and ``ckpts``
    sections read below. The backbone class is resolved by name from
    ``f5_tts.model`` (``model.backbone``) and instantiated with ``model.arch`` plus
    the tokenizer's vocabulary size and the mel channel count.
    """
    model_section = model_cfg.model
    backbone_cls = hydra.utils.get_class(f"f5_tts.model.{model_section.backbone}")
    arch_kwargs = model_section.arch
    tokenizer = model_section.tokenizer
    mel_spec_cfg = model_section.mel_spec
    mel_spec_type = mel_spec_cfg.mel_spec_type

    # Experiment-tracking settings; each falls back to a default when absent from ``ckpts``.
    ckpts_cfg = model_cfg.ckpts
    wandb_project = ckpts_cfg.get("wandb_project", "CFM-TTS")
    default_run_name = f"{model_section.name}_{mel_spec_type}_{tokenizer}_{model_cfg.datasets.name}"
    wandb_run_name = ckpts_cfg.get("wandb_run_name", default_run_name)
    wandb_resume_id = ckpts_cfg.get("wandb_resume_id", None)

    # A "custom" tokenizer is given ``model.tokenizer_path``; every other tokenizer
    # type is given the dataset name in place of a path.
    if tokenizer == "custom":
        tokenizer_path = model_section.tokenizer_path
    else:
        tokenizer_path = model_cfg.datasets.name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    # Backbone network wrapped by the conditional flow-matching model.
    backbone = backbone_cls(**arch_kwargs, text_num_embeds=vocab_size, mel_dim=mel_spec_cfg.n_mel_channels)
    model = CFM(transformer=backbone, mel_spec_kwargs=mel_spec_cfg, vocab_char_map=vocab_char_map)

    optim_cfg = model_cfg.optim
    data_cfg = model_cfg.datasets
    trainer = Trainer(
        model,
        epochs=optim_cfg.epochs,
        learning_rate=optim_cfg.learning_rate,
        num_warmup_updates=optim_cfg.num_warmup_updates,
        save_per_updates=ckpts_cfg.save_per_updates,
        keep_last_n_checkpoints=ckpts_cfg.keep_last_n_checkpoints,
        # ``save_dir`` is joined onto the project root (two levels above the package).
        checkpoint_path=str(files("f5_tts").joinpath(f"../../{ckpts_cfg.save_dir}")),
        batch_size_per_gpu=data_cfg.batch_size_per_gpu,
        batch_size_type=data_cfg.batch_size_type,
        max_samples=data_cfg.max_samples,
        grad_accumulation_steps=optim_cfg.grad_accumulation_steps,
        max_grad_norm=optim_cfg.max_grad_norm,
        logger=ckpts_cfg.logger,
        wandb_project=wandb_project,
        wandb_run_name=wandb_run_name,
        wandb_resume_id=wandb_resume_id,
        last_per_updates=ckpts_cfg.last_per_updates,
        log_samples=ckpts_cfg.log_samples,
        bnb_optimizer=optim_cfg.bnb_optimizer,
        mel_spec_type=mel_spec_type,
        is_local_vocoder=model_section.vocoder.is_local,
        local_vocoder_path=model_section.vocoder.local_path,
        # Whole config, interpolations resolved, as plain Python containers.
        model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
    )

    train_dataset = load_dataset(data_cfg.name, tokenizer, mel_spec_kwargs=mel_spec_cfg)
    trainer.train(
        train_dataset,
        num_workers=data_cfg.num_workers,
        resumable_with_seed=666,  # fixed seed for shuffling the dataset
    )


if __name__ == "__main__":
    main()