Repository: fishaudio/fish-speech
Branch: main
Commit: 49985a34a704
Files: 153
Total size: 676.6 KB
Directory structure:
gitextract_ft5t8lt3/
├── .dockerignore
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.yml
│ │ ├── config.yml
│ │ └── feature_request.yml
│ ├── pull_request_template.md
│ └── workflows/
│ ├── build-docker-image.yml
│ ├── docs.yml
│ └── stale.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .project-root
├── .readthedocs.yaml
├── API_FLAGS.txt
├── LICENSE
├── README.md
├── awesome_webui/
│ ├── .gitignore
│ ├── README.md
│ ├── eslint.config.js
│ ├── index.html
│ ├── package.json
│ ├── src/
│ │ ├── App.tsx
│ │ ├── components/
│ │ │ └── ui/
│ │ │ ├── alert.tsx
│ │ │ ├── badge.tsx
│ │ │ ├── button.tsx
│ │ │ ├── card.tsx
│ │ │ ├── collapsible.tsx
│ │ │ ├── dialog.tsx
│ │ │ ├── label.tsx
│ │ │ ├── scroll-area.tsx
│ │ │ ├── separator.tsx
│ │ │ ├── slider.tsx
│ │ │ ├── switch.tsx
│ │ │ ├── textarea.tsx
│ │ │ └── toggle-group.tsx
│ │ ├── index.css
│ │ └── main.tsx
│ ├── tsconfig.app.json
│ ├── tsconfig.json
│ ├── tsconfig.node.json
│ └── vite.config.ts
├── compose.base.yml
├── compose.yml
├── docker/
│ └── Dockerfile
├── dockerfile.dev
├── docs/
│ ├── CNAME
│ ├── README.ar.md
│ ├── README.ja.md
│ ├── README.ko.md
│ ├── README.pt-BR.md
│ ├── README.zh.md
│ ├── ar/
│ │ ├── finetune.md
│ │ ├── index.md
│ │ ├── inference.md
│ │ └── install.md
│ ├── en/
│ │ ├── finetune.md
│ │ ├── index.md
│ │ ├── inference.md
│ │ ├── install.md
│ │ └── server.md
│ ├── ja/
│ │ ├── finetune.md
│ │ ├── index.md
│ │ ├── inference.md
│ │ └── install.md
│ ├── ko/
│ │ ├── finetune.md
│ │ ├── index.md
│ │ ├── inference.md
│ │ └── install.md
│ ├── pt/
│ │ ├── finetune.md
│ │ ├── index.md
│ │ ├── inference.md
│ │ └── install.md
│ ├── requirements.txt
│ ├── stylesheets/
│ │ └── extra.css
│ └── zh/
│ ├── finetune.md
│ ├── index.md
│ ├── inference.md
│ └── install.md
├── entrypoint.sh
├── fish_speech/
│ ├── callbacks/
│ │ ├── __init__.py
│ │ └── grad_norm.py
│ ├── configs/
│ │ ├── base.yaml
│ │ ├── lora/
│ │ │ └── r_8_alpha_16.yaml
│ │ ├── modded_dac_vq.yaml
│ │ └── text2semantic_finetune.yaml
│ ├── content_sequence.py
│ ├── conversation.py
│ ├── datasets/
│ │ ├── concat_repeat.py
│ │ ├── protos/
│ │ │ ├── text-data.proto
│ │ │ ├── text_data_pb2.py
│ │ │ └── text_data_stream.py
│ │ ├── semantic.py
│ │ └── vqgan.py
│ ├── i18n/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── core.py
│ │ ├── locale/
│ │ │ ├── en_US.json
│ │ │ ├── es_ES.json
│ │ │ ├── ja_JP.json
│ │ │ ├── ko_KR.json
│ │ │ ├── pt_BR.json
│ │ │ └── zh_CN.json
│ │ └── scan.py
│ ├── inference_engine/
│ │ ├── __init__.py
│ │ ├── reference_loader.py
│ │ ├── utils.py
│ │ └── vq_manager.py
│ ├── models/
│ │ ├── dac/
│ │ │ ├── __init__.py
│ │ │ ├── inference.py
│ │ │ ├── modded_dac.py
│ │ │ └── rvq.py
│ │ └── text2semantic/
│ │ ├── __init__.py
│ │ ├── inference.py
│ │ ├── lit_module.py
│ │ ├── llama.py
│ │ └── lora.py
│ ├── scheduler.py
│ ├── text/
│ │ ├── __init__.py
│ │ └── clean.py
│ ├── tokenizer.py
│ ├── train.py
│ └── utils/
│ ├── __init__.py
│ ├── braceexpand.py
│ ├── context.py
│ ├── file.py
│ ├── instantiators.py
│ ├── logger.py
│ ├── logging_utils.py
│ ├── rich_utils.py
│ ├── schema.py
│ ├── spectrogram.py
│ └── utils.py
├── inference.ipynb
├── mkdocs.yml
├── pyproject.toml
├── pyrightconfig.json
└── tools/
├── api_client.py
├── api_server.py
├── llama/
│ ├── build_dataset.py
│ ├── eval_in_context.py
│ ├── merge_lora.py
│ └── quantize.py
├── run_webui.py
├── server/
│ ├── api_utils.py
│ ├── exception_handler.py
│ ├── inference.py
│ ├── model_manager.py
│ ├── model_utils.py
│ └── views.py
├── vqgan/
│ ├── create_train_split.py
│ └── extract_vq.py
└── webui/
├── __init__.py
├── inference.py
└── variables.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .dockerignore
================================================
# .dockerignore
# Git and version control
.git
.gitignore
.gitattributes
.gitmodules
# IDE and editor files
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
Thumbs.db
# Python cache and build artifacts
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# Virtual environments
venv/
env/
ENV/
.venv/
.env/
# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
.nox/
coverage.xml
*.cover
.hypothesis/
# Jupyter Notebook
.ipynb_checkpoints
*.ipynb
# Logs
*.log
logs/
# Temporary files
tmp/
temp/
*.tmp
*.temp
# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Docker files (the builder still receives the Dockerfile even when it is ignored here)
docker/
Dockerfile*
docker-compose*.yml
.dockerignore
# Checkpoints and models (should be mounted)
checkpoints/
models/
*.pth
*.ckpt
*.safetensors
*.bin
# Reference voices (should be mounted)
references/
# Generated audio files
*.wav
*.mp3
*.flac
*.ogg
generated_audio.wav
fake.wav
fake.npy
# Cache directories
.cache/
cache/
.uv_cache/
# Development files
.env
.env.local
.env.development
.env.test
.env.production
# Test files
test_*.py
*_test.py
tests/
# CI/CD
.github/
.gitlab-ci.yml
.travis.yml
.circleci/
azure-pipelines.yml
# Monitoring and profiling
.prof
*.prof
# Backup files
*.bak
*.backup
*.old
# Large data files
*.csv
*.jsonl
*.parquet
*.h5
*.hdf5
# Audio processing temporary files
*.tmp.wav
*.temp.wav
# OLD:
# .github
# results
# data
# *.filelist
# /data_server/target
# checkpoints
# .venv
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yml
================================================
name: "🕷️ Bug report"
description: |
Please follow this template carefully to ensure we can address your issue quickly.
Make sure to provide as much detail as possible, including logs and screenshots.
labels:
- bug
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To ensure timely help, please confirm the following:"
options:
- label: This template is only for bug reports. For questions, please visit [Discussions](https://github.com/fishaudio/fish-speech/discussions).
required: true
- label: I have thoroughly reviewed the project documentation (installation, training, inference) but couldn't find information to solve my problem. [English](https://speech.fish.audio/) [中文](https://speech.fish.audio/zh/) [日本語](https://speech.fish.audio/ja/) [Portuguese (Brazil)](https://speech.fish.audio/pt/)
required: true
- label: I have searched for existing issues, including closed ones. [Search issues](https://github.com/fishaudio/fish-speech/issues)
required: true
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/fishaudio/fish-speech/issues/515)).
required: true
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
required: true
- label: "Please do not modify this template, and fill in all required fields."
required: true
- type: dropdown
attributes:
label: Cloud or Self Hosted
multiple: true
options:
- Cloud
- Self Hosted (Docker)
- Self Hosted (Source)
validations:
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: e.g., macOS 13.5, Python 3.10, torch==2.4.1, Gradio 4.44.0
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: |
Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
placeholder: |
1. Run the command `python -m tools.api_client -t "xxxxx"`
2. Observe the console output error: `ModuleNotFoundError: No module named 'pyaudio'` (screenshots or logs are very helpful)
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened.
validations:
required: false
================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
blank_issues_enabled: false
contact_links:
- name: "\U0001F4E7 Discussions"
url: https://github.com/fishaudio/fish-speech/discussions
about: General discussions and request help from the community
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.yml
================================================
name: "⭐ Feature or enhancement request"
description: Propose something new.
labels:
- enhancement
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have thoroughly reviewed the project documentation (installation, training, inference) but couldn't find any relevant information that meets my needs. [English](https://speech.fish.audio/) [中文](https://speech.fish.audio/zh/) [日本語](https://speech.fish.audio/ja/) [Portuguese (Brazil)](https://speech.fish.audio/pt/)
required: true
- label: I have searched for existing issues, including closed ones. [Search issues](https://github.com/fishaudio/fish-speech/issues)
required: true
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/fishaudio/fish-speech/issues/515)).
required: true
- label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
required: true
- label: "Please do not modify this template, and fill in all the required fields. :)"
required: true
- type: textarea
attributes:
label: 1. Is this request related to a challenge you're experiencing? Tell us your story.
description: |
Describe the specific problem or scenario you’re facing in detail. For example:
*"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because...."*
placeholder: Please describe the situation in as much detail as possible.
validations:
required: true
- type: textarea
attributes:
label: 2. What is your suggested solution?
description: |
Provide a clear description of the feature or enhancement you'd like to propose.
How would this feature solve your issue or improve the project?
placeholder: Describe your idea or proposed solution here.
validations:
required: true
- type: textarea
attributes:
label: 3. Additional context or comments
description: |
Any other relevant information, links, documents, or screenshots that provide clarity.
Use this section for anything not covered above.
placeholder: Add any extra details here.
validations:
required: false
- type: checkboxes
attributes:
label: 4. Can you help us with this feature?
description: |
Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration.
options:
- label: I am interested in contributing to this feature.
required: false
- type: markdown
attributes:
value: |
**Note:** Please submit only one request per issue to keep discussions focused and manageable.
================================================
FILE: .github/pull_request_template.md
================================================
**Does this PR add a new feature or fix a bug?**
Add feature / Fix BUG.
**Is this pull request related to any issue? If yes, please link the issue.**
#xxx
================================================
FILE: .github/workflows/build-docker-image.yml
================================================
name: Build Docker Images
on:
push:
branches:
- main
tags:
- "v*"
jobs:
build:
runs-on: ubuntu-latest-16c64g
strategy:
matrix:
target: [webui, server]
backend: [cuda, cpu]
steps:
- uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Get Version
run: |
if [[ $GITHUB_REF == refs/tags/v* ]]; then
version=$(basename ${GITHUB_REF})
else
version=nightly
fi
echo "version=${version}" >> $GITHUB_ENV
echo "Current version: ${version}"
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_PAT }}
- name: Set platform for CPU builds
id: platform
run: |
if [ "${{ matrix.backend }}" = "cpu" ]; then
echo "platforms=linux/amd64,linux/arm64" >> $GITHUB_OUTPUT
else
echo "platforms=linux/amd64" >> $GITHUB_OUTPUT
fi
- name: Build and Push ${{ matrix.target }}-${{ matrix.backend }} Image
uses: docker/build-push-action@v6
with:
context: .
file: docker/Dockerfile
platforms: ${{ steps.platform.outputs.platforms }}
push: true
target: ${{ matrix.target }}
build-args: |
BACKEND=${{ matrix.backend }}
UV_EXTRA=${{ matrix.backend == 'cuda' && 'cu126' || 'cpu' }}
tags: |
fishaudio/fish-speech:${{ matrix.target }}-${{ matrix.backend }}-${{ env.version }}
fishaudio/fish-speech:${{ matrix.target }}-${{ matrix.backend }}
${{ (matrix.target == 'webui' && matrix.backend == 'cuda') && format('fishaudio/fish-speech:{0}', env.version) || '' }}
${{ (matrix.target == 'webui' && matrix.backend == 'cuda') && 'fishaudio/fish-speech:latest' || '' }}
outputs: type=image,oci-mediatypes=true,compression=zstd,compression-level=3,force-compression=true
cache-from: type=registry,ref=fishaudio/fish-speech:${{ matrix.target }}-${{ matrix.backend }}
cache-to: type=inline
update-readme:
runs-on: ubuntu-latest
needs: build
if: github.ref == 'refs/heads/main'
steps:
- name: Push README to Dockerhub
uses: peter-evans/dockerhub-description@v4
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_PAT }}
repository: fishaudio/fish-speech
================================================
FILE: .github/workflows/docs.yml
================================================
name: docs
on:
push:
branches:
- main
paths:
- 'docs/**'
- 'mkdocs.yml'
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git Credentials
run: |
git config user.name github-actions[bot]
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
- uses: actions/setup-python@v5
with:
python-version: 3.x
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v4
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
- run: pip install -r docs/requirements.txt
- run: mkdocs gh-deploy --force
================================================
FILE: .github/workflows/stale.yml
================================================
name: Close inactive issues
on:
schedule:
- cron: "0 0 * * *"
jobs:
close-issues:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v9
with:
days-before-issue-stale: 30
days-before-issue-close: 14
stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
days-before-pr-stale: 30
days-before-pr-close: 30
stale-pr-label: "stale"
stale-pr-message: "This PR is stale because it has been open for 30 days with no activity."
close-pr-message: "This PR was closed because it has been inactive for 30 days since being marked as stale."
repo-token: ${{ secrets.GITHUB_TOKEN }}
================================================
FILE: .gitignore
================================================
# =============================================================================
# Fish Speech - .gitignore
# =============================================================================
# Operating System Files
# -----------------------
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# IDEs and Editors
# ----------------
.vscode/
.idea/
*.swp
*.swo
*~
# Python
# ------
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# Virtual Environments
# --------------------
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
/fishenv/
# Project Dependencies
# --------------------
.pdm-python
/fish_speech.egg-info
# Data and Model Files
# --------------------
data/
results/
checkpoints/
references/
demo-audios/
example/
filelists/
*.filelist
# Audio Files
# -----------
*.wav
*.mp3
*.flac
*.ogg
*.m4a
# Data Files
# ----------
*.npy
*.npz
*.pkl
*.pickle
*.lab
/fish_speech/text/cmudict_cache.pickle
# Cache and Temporary Files
# --------------------------
/.cache/
/.gradio/
/.locale/
.pgx.*
*log
*.log
site/
# External Tools
# --------------
ffmpeg.exe
ffprobe.exe
/faster_whisper/
# Server Related
# --------------
/data_server/target/
# Test Files
# ----------
/*.test.sh
asr-label*
================================================
FILE: .pre-commit-config.yaml
================================================
ci:
autoupdate_schedule: monthly
repos:
- repo: https://github.com/pycqa/isort
rev: 8.0.1
hooks:
- id: isort
args: [--profile=black]
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.1.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: end-of-file-fixer
- id: check-yaml
- id: check-json
- id: mixed-line-ending
args: ["--fix=lf"]
- id: check-added-large-files
args: ["--maxkb=5000"]
================================================
FILE: .project-root
================================================
================================================
FILE: .readthedocs.yaml
================================================
# Read the Docs configuration file for MkDocs projects
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Set the version of Python and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.12"
mkdocs:
configuration: mkdocs.yml
# Optionally declare the Python requirements required to build your docs
python:
install:
- requirements: docs/requirements.txt
================================================
FILE: API_FLAGS.txt
================================================
# --infer
--api
--listen 0.0.0.0:8080 \
--llama-checkpoint-path "checkpoints/openaudio-s1-mini" \
--decoder-checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth" \
--decoder-config-name modded_dac_vq
================================================
FILE: LICENSE
================================================
# FISH AUDIO RESEARCH LICENSE AGREEMENT
**Last Updated: March 7, 2026**
## I. INTRODUCTION
This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Fish Audio Materials or Derivative Works thereof for any Research, Non-Commercial, or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
This Agreement is intended to allow research and non-commercial uses of the Materials free of charge. Any Commercial use of the Materials requires a separate license from Fish Audio.
By clicking "I Accept" or by using, distributing, or accessing any portion or element of the Fish Audio Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
## II. RESEARCH & NON-COMMERCIAL USE LICENSE
Subject to the terms of this Agreement, Fish Audio grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Fish Audio's intellectual property or other rights owned by Fish Audio embodied in the Fish Audio Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Fish Audio Materials for any Research or Non-Commercial Purpose.
"Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others.
"Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
## III. COMMERCIAL USE
**Any use of the Fish Audio Materials or Derivative Works for a Commercial Purpose requires a separate written license agreement from Fish Audio.** No commercial rights are granted under this Agreement.
"Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for or directed toward commercial advantage or monetary compensation to You or others, including but not limited to: (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, (ii) Your business's or organization's internal operations, and (iii) any use in connection with a product or service for which You charge a fee or generate revenue, whether directly or indirectly.
To obtain a commercial license, please contact Fish Audio at:
- **Website:** [https://fish.audio](https://fish.audio)
- **Email:** business@fish.audio
## IV. GENERAL TERMS
Your Research and Non-Commercial License under this Agreement is subject to the following terms.
### a. Distribution & Attribution
If You distribute or make available the Fish Audio Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This model is licensed under the Fish Audio Research License, Copyright © 39 AI, INC. All Rights Reserved.", and (iii) prominently display "Built with Fish Audio" on a related website, user interface, blogpost, about page, or product documentation.
If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Fish Audio Materials and state in the "Notice" text file that You changed the Fish Audio Materials and how it was modified.
### b. Use Restrictions
Your use of the Fish Audio Materials and Derivative Works, including any output or results of the Fish Audio Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to Fish Audio's Acceptable Use Policy, which is hereby incorporated by reference.
Furthermore, You will not use the Fish Audio Materials or Derivative Works, or any output or results of the Fish Audio Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
### c. Intellectual Property
**(i) Trademark License.** No trademark licenses are granted under this Agreement, and in connection with the Fish Audio Materials or Derivative Works, You may not use any name or mark owned by or associated with Fish Audio or any of its Affiliates, except as required under Section IV(a) herein.
**(ii) Ownership of Derivative Works.** As between You and Fish Audio, You are the owner of Derivative Works You create, subject to Fish Audio's ownership of the Fish Audio Materials and any Derivative Works made by or for Fish Audio.
**(iii) Ownership of Outputs.** As between You and Fish Audio, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
**(iv) Disputes.** If You or Your Affiliate(s) institute litigation or other proceedings against Fish Audio (including a cross-claim or counterclaim in a lawsuit) alleging that the Fish Audio Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Fish Audio from and against any claim by any third party arising out of or related to Your use or distribution of the Fish Audio Materials or Derivative Works in violation of this Agreement.
**(v) Feedback.** From time to time, You may provide Fish Audio with verbal and/or written suggestions, comments or other feedback related to Fish Audio's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Fish Audio with Feedback, but to the extent that You do, You hereby grant Fish Audio a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
### d. Disclaimer of Warranty
UNLESS REQUIRED BY APPLICABLE LAW, THE FISH AUDIO MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE FISH AUDIO MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE FISH AUDIO MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
### e. Limitation of Liability
IN NO EVENT WILL FISH AUDIO OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF FISH AUDIO OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
### f. Term and Termination
The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Fish Audio Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Fish Audio may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Fish Audio Materials or Derivative Works. Sections IV(d), (e), and (g) shall survive the termination of this Agreement.
### g. Governing Law
This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
## V. DEFINITIONS
**"Affiliate(s)"** means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
**"Agreement"** means this Fish Audio Research License Agreement.
**"Derivative Work(s)"** means (a) any derivative work of the Fish Audio Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including "fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
**"Documentation"** means any specifications, manuals, documentation, and other written information provided by Fish Audio related to the Software or Models.
**"Fish Audio"** or **"we"** means 39 AI, INC. and its Affiliates.
**"Model(s)"** means, collectively, Fish Audio's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing.
**"Software"** means Fish Audio's proprietary software made available under this Agreement now or in the future.
**"Fish Audio Materials"** means, collectively, Fish Audio's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
**"Trade Control Laws"** means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
================================================
FILE: README.md
================================================
Fish Speech
**English** | [简体中文](docs/README.zh.md) | [Portuguese](docs/README.pt-BR.md) | [日本語](docs/README.ja.md) | [한국어](docs/README.ko.md) | [العربية](docs/README.ar.md)
> [!IMPORTANT]
> **License Notice**
> This codebase and its associated model weights are released under **[FISH AUDIO RESEARCH LICENSE](LICENSE)**. Please refer to [LICENSE](LICENSE) for more details. We will take action against any violation of the license.
> [!WARNING]
> **Legal Disclaimer**
> We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws.
## Quick Start
### For Humans
Here are the official documents for Fish Audio S2; follow the instructions to get started easily.
- [Installation](https://speech.fish.audio/install/)
- [Command Line Inference](https://speech.fish.audio/inference/#command-line-inference)
- [WebUI Inference](https://speech.fish.audio/inference/#webui-inference)
- [Server Inference](https://speech.fish.audio/server/)
- [Docker Setup](https://speech.fish.audio/install/#docker-setup)
> [!IMPORTANT]
> **For SGLang server, please read [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md).**
### For LLM Agents
```
Install and configure Fish-Audio S2 by following the instructions here: https://speech.fish.audio/install/
```
## Fish Audio S2 Pro
**State-of-the-art multilingual text-to-speech (TTS) system, redefining the boundaries of voice generation.**
Fish Audio S2 Pro is the most advanced multimodal model developed by [Fish Audio](https://fish.audio/). Trained on over **10 million hours** of audio data covering more than **80 languages**, S2 Pro combines a **Dual-Autoregressive (Dual-AR)** architecture with reinforcement learning (RL) alignment to generate speech that is exceptionally natural, realistic, and emotionally rich, leading the competition among both open-source and closed-source systems.
The core strength of S2 Pro lies in its support for **sub-word level** fine-grained control of prosody and emotion using natural language tags (e.g., `[whisper]`, `[excited]`, `[angry]`), while natively supporting multi-speaker and multi-turn conversation generation.
Visit the [Fish Audio website](https://fish.audio/) for a live playground, or read our [technical report](https://arxiv.org/abs/2603.08823) and [blog post](https://fish.audio/blog/fish-audio-open-sources-s2/) for more details.
### Model Variants
| Model | Size | Availability | Description |
|------|------|-------------|-------------|
| S2-Pro | 4B parameters | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | Full-featured flagship model with maximum quality and stability |
More details about the model can be found in the [technical report](https://arxiv.org/abs/2603.08823).
## Benchmark Results
| Benchmark | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER (Chinese) | **0.54%** (best overall) |
| Seed-TTS Eval — WER (English) | **0.99%** (best overall) |
| Audio Turing Test (with instruction) | **0.515** posterior mean |
| EmergentTTS-Eval — Win Rate | **81.88%** (highest overall) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — Quality | **4.51 / 5.0** |
| Multilingual (MiniMax Testset) — Best WER | **11 of 24** languages |
| Multilingual (MiniMax Testset) — Best SIM | **17 of 24** languages |
On Seed-TTS Eval, S2 achieves the lowest WER among all evaluated models including closed-source systems: Qwen3-TTS (0.77/1.24), MiniMax Speech-02 (0.99/1.90), Seed-TTS (1.12/2.25). On the Audio Turing Test, 0.515 surpasses Seed-TTS (0.417) by 24% and MiniMax-Speech (0.387) by 33%. On EmergentTTS-Eval, S2 achieves particularly strong results in paralinguistics (91.61% win rate), questions (84.41%), and syntactic complexity (83.39%).
## Highlights
### Fine-Grained Inline Control via Natural Language
S2 Pro brings unprecedented "soul" to speech. Using simple `[tag]` syntax, you can precisely embed emotional instructions at any position in the text.
- **15,000+ Unique Tags Supported**: Not limited to fixed presets; S2 supports **free-form text descriptions**. Try `[whisper in small voice]`, `[professional broadcast tone]`, or `[pitch up]`.
- **Rich Emotion Library**:
`[pause]` `[emphasis]` `[laughing]` `[inhale]` `[chuckle]` `[tsk]` `[singing]` `[excited]` `[laughing tone]` `[interrupting]` `[chuckling]` `[excited tone]` `[volume up]` `[echo]` `[angry]` `[low volume]` `[sigh]` `[low voice]` `[whisper]` `[screaming]` `[shouting]` `[loud]` `[surprised]` `[short pause]` `[exhale]` `[delight]` `[panting]` `[audience laughter]` `[with strong accent]` `[volume down]` `[clearing throat]` `[sad]` `[moaning]` `[shocked]`
### Innovative Dual-Autoregressive (Dual-AR) Architecture
S2 Pro adopts a master-slave Dual-AR architecture consisting of a decoder-only transformer and an RVQ audio codec (10 codebooks, ~21 Hz):
- **Slow AR (4B parameters)**: Operates along the time axis, predicting the primary semantic codebook.
- **Fast AR (400M parameters)**: Generates the remaining 9 residual codebooks at each time step, reconstructing fine-grained acoustic detail.
This asymmetric design achieves peak audio fidelity while significantly boosting inference speed.
### Reinforcement Learning (RL) Alignment
S2 Pro utilizes **Group Relative Policy Optimization (GRPO)** for post-training alignment. The same models used for data cleaning and annotation are reused directly as reward models, which resolves the distribution mismatch between pre-training data and post-training objectives.
- **Multi-Dimensional Reward Signals**: Jointly evaluates semantic accuracy, instruction adherence, acoustic preference scoring, and timbre similarity, so that generated speech sounds natural to human listeners.
### Extreme Streaming Performance (Powered by SGLang)
As the Dual-AR architecture is structurally isomorphic to standard LLMs, S2 Pro natively supports all SGLang inference acceleration features, including Continuous Batching, Paged KV Cache, CUDA Graph, and RadixAttention-based Prefix Caching.
**Performance on a single NVIDIA H200 GPU:**
- **Real-Time Factor (RTF)**: 0.195
- **Time-to-First-Audio (TTFA)**: ~100 ms
- **Extreme Throughput**: 3,000+ acoustic tokens/s while maintaining RTF < 0.5
### Robust Multilingual Support
S2 Pro supports over 80 languages without requiring phonemes or language-specific preprocessing:
- **Tier 1**: Japanese (ja), English (en), Chinese (zh)
- **Tier 2**: Korean (ko), Spanish (es), Portuguese (pt), Arabic (ar), Russian (ru), French (fr), German (de)
- **Global Coverage**: sv, it, tr, no, nl, cy, eu, ca, da, gl, ta, hu, fi, pl, et, hi, la, ur, th, vi, jw, bn, yo, xsl, cs, sw, nn, he, ms, uk, id, kk, bg, lv, my, tl, sk, ne, fa, af, el, bo, hr, ro, sn, mi, yi, am, be, km, is, az, sd, br, sq, ps, mn, ht, ml, sr, sa, te, ka, bs, pa, lt, kn, si, hy, mr, as, gu, fo, etc.
### Native Multi-Speaker Generation
Fish Audio S2 allows users to upload reference audio containing multiple speakers, and the model processes each speaker's features via the `<|speaker:i|>` token. You can then control the model's performance via speaker ID tokens, enabling a single generation to include multiple speakers. There is no longer a need to upload separate reference audio for each individual speaker.
### Multi-Turn Generation
Thanks to its extended context window, the model can use information from previous turns to make subsequently generated speech more expressive, which makes dialogue sound more natural.
### Rapid Voice Cloning
Fish Audio S2 supports accurate voice cloning using short reference samples (typically 10-30 seconds). The model captures timbre, speaking style, and emotional tendencies, producing realistic and consistent cloned voices without additional fine-tuning.
For SGLang Server usage, please refer to the [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md).
---
## Credits
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## Tech Report
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: awesome_webui/.gitignore
================================================
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
================================================
FILE: awesome_webui/README.md
================================================
# React + TypeScript + Vite
This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
Currently, two official plugins are available:
- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) (or [oxc](https://oxc.rs) when used in [rolldown-vite](https://vite.dev/guide/rolldown)) for Fast Refresh
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
## React Compiler
The React Compiler is currently not compatible with SWC. See [this issue](https://github.com/vitejs/vite-plugin-react/issues/428) for tracking the progress.
## Expanding the ESLint configuration
If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
```js
export default defineConfig([
globalIgnores(['dist']),
{
files: ['**/*.{ts,tsx}'],
extends: [
// Other configs...
// Remove tseslint.configs.recommended and replace with this
tseslint.configs.recommendedTypeChecked,
// Alternatively, use this for stricter rules
tseslint.configs.strictTypeChecked,
// Optionally, add this for stylistic rules
tseslint.configs.stylisticTypeChecked,
// Other configs...
],
languageOptions: {
parserOptions: {
project: ['./tsconfig.node.json', './tsconfig.app.json'],
tsconfigRootDir: import.meta.dirname,
},
// other options...
},
},
])
```
You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
```js
// eslint.config.js
import reactX from 'eslint-plugin-react-x'
import reactDom from 'eslint-plugin-react-dom'
export default defineConfig([
globalIgnores(['dist']),
{
files: ['**/*.{ts,tsx}'],
extends: [
// Other configs...
// Enable lint rules for React
reactX.configs['recommended-typescript'],
// Enable lint rules for React DOM
reactDom.configs.recommended,
],
languageOptions: {
parserOptions: {
project: ['./tsconfig.node.json', './tsconfig.app.json'],
tsconfigRootDir: import.meta.dirname,
},
// other options...
},
},
])
```
================================================
FILE: awesome_webui/eslint.config.js
================================================
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import tseslint from 'typescript-eslint'
import { defineConfig, globalIgnores } from 'eslint/config'
// Flat ESLint config for the web UI.
// Applies the recommended JavaScript, typescript-eslint, React Hooks and
// React Refresh (Vite preset) rule sets to all TypeScript/TSX sources, and
// ignores the `dist` build output entirely.
export default defineConfig([
  globalIgnores(['dist']),
  {
    files: ['**/*.{ts,tsx}'],
    extends: [
      js.configs.recommended,
      tseslint.configs.recommended,
      reactHooks.configs.flat.recommended,
      reactRefresh.configs.vite,
    ],
    languageOptions: {
      ecmaVersion: 2020,
      // Predefine browser globals (window, document, ...) for client-side code.
      globals: globals.browser,
    },
  },
])
================================================
FILE: awesome_webui/index.html
================================================
Awesome WebUI
================================================
FILE: awesome_webui/package.json
================================================
{
"name": "awesome_webui",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc -b && vite build",
"lint": "eslint .",
"preview": "vite preview"
},
"dependencies": {
"@radix-ui/react-collapsible": "^1.1.12",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-label": "^2.1.8",
"@radix-ui/react-scroll-area": "^1.2.10",
"@radix-ui/react-separator": "^1.1.8",
"@radix-ui/react-slider": "^1.3.6",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-switch": "^1.2.6",
"@radix-ui/react-toggle-group": "^1.1.11",
"@tailwindcss/vite": "^4.2.1",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"lucide-react": "^0.577.0",
"react": "^19.2.0",
"react-dom": "^19.2.0",
"tailwind-merge": "^3.5.0",
"tailwindcss": "^4.2.1"
},
"devDependencies": {
"@eslint/js": "^9.39.1",
"@types/node": "^24.10.1",
"@types/react": "^19.2.7",
"@types/react-dom": "^19.2.3",
"@vitejs/plugin-react-swc": "^4.2.2",
"eslint": "^9.39.1",
"eslint-plugin-react-hooks": "^7.0.1",
"eslint-plugin-react-refresh": "^0.4.24",
"globals": "^16.5.0",
"typescript": "~5.9.3",
"typescript-eslint": "^8.48.0",
"vite": "^7.3.1"
}
}
================================================
FILE: awesome_webui/src/App.tsx
================================================
import { useEffect, useRef, useState } from 'react'
import {
AudioLines,
ChevronDown,
CircleAlert,
Copy,
Download,
FileText,
Info,
LoaderCircle,
Plus,
Settings2,
Upload,
} from 'lucide-react'
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert'
import { Badge } from '@/components/ui/badge'
import { Button } from '@/components/ui/button'
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from '@/components/ui/card'
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from '@/components/ui/collapsible'
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog'
import { Label } from '@/components/ui/label'
import { ScrollArea } from '@/components/ui/scroll-area'
import { Separator } from '@/components/ui/separator'
import { Slider } from '@/components/ui/slider'
import { Switch } from '@/components/ui/switch'
import { Textarea } from '@/components/ui/textarea'
import { ToggleGroup, ToggleGroupItem } from '@/components/ui/toggle-group'
// Output format requested from the server (`format` in the request body); also
// selects the MIME type used for playback and the download file extension.
type AudioFormat = 'mp3' | 'wav' | 'pcm' | 'opus'
// Value sent as `latency` in the request body.
type LatencyMode = 'normal' | 'balanced'
const defaultInputText = `[excited, joyful tone] We're going to DISNEY WORLD! [squeal of delight] I've been saving for [emphasis] three years [breathless] and finally, FINALLY we can go! The look on your face right now is worth every extra shift I worked!
[angry] After everything we've been through [break] I can't believe you would [emphasize] betray me like this. I gave you EVERYTHING! And now I'm left with nothing but memories and broken promises!`
// Generation settings edited by the user; every field is copied into the
// request body by the payload builders (camelCase here, snake_case on the wire).
type ControlsState = {
  chunkLength: number // sent as `chunk_length`
  maxNewTokens: number // sent as `max_new_tokens`
  temperature: number
  topP: number // sent as `top_p`
  repetitionPenalty: number // sent as `repetition_penalty`
  normalize: boolean
  format: AudioFormat
  latency: LatencyMode
}
// Statistics displayed while a response is being received.
type Metrics = {
  textLength: number // inputText.length (UTF-16 code units)
  ttftMs: number // milliseconds until the first response chunk arrived
  receivedKb: number // bytes received so far, rounded to whole KiB
}
// Banner message; 'error' is rendered as "Error", 'info' as "Notice".
type StatusState = {
  tone: 'error' | 'info'
  message: string
}
// One uploaded reference clip belonging to a speaker group.
type ReferenceItem = {
  id: number
  name: string // original file name of the upload
  audio: ArrayBuffer // raw file bytes; base64-encoded into the request
  text: string // reference text entered by the user (presumably the clip's transcript)
  previewUrl: string // object URL for in-page playback; revoked when the clip is removed
}
// A speaker and its reference clips. The group's position in the list is the
// speaker index shown in the UI ("Speaker 0", "Speaker 1", ...).
type SpeakerGroup = {
  id: number
  references: ReferenceItem[]
}
// Reference being created or edited in the reference-text dialog.
type PendingReference = {
  mode: 'create' | 'edit'
  speakerId: number
  referenceId?: number // set only in 'edit' mode
  name: string
  audio?: ArrayBuffer // set only in 'create' mode (freshly uploaded bytes)
  text: string
}
// Generation settings used when the page first loads.
const initialControls: ControlsState = {
  chunkLength: 1000,
  maxNewTokens: 2048,
  temperature: 0.9,
  topP: 0.9,
  repetitionPenalty: 1.05,
  normalize: false,
  format: 'mp3',
  latency: 'normal',
}
const formatMimeMap: Record = {
mp3: 'audio/mpeg',
wav: 'audio/wav',
pcm: 'audio/pcm',
opus: 'audio/opus',
}
// Most recent id returned by createId. Keeping it lets calls made within the
// same millisecond still receive distinct ids.
let lastIssuedId = 0

/**
 * Return a new numeric id for speaker groups and reference clips.
 *
 * Ids are derived from the current timestamp but are strictly increasing, so
 * two ids are never equal. They are used as React keys and as the lookup key
 * when removing or editing items.
 */
function createId() {
  const now = Date.now()
  lastIssuedId = now > lastIssuedId ? now : lastIssuedId + 1
  return lastIssuedId
}
// Number of bytes converted per String.fromCharCode call. It is small enough to stay
// well under engine argument-count limits.
const BASE64_CHUNK_SIZE = 0x8000

/**
 * Encode raw bytes as a base64 string (used to embed reference audio in the request).
 *
 * The binary string is built in fixed-size chunks rather than one character at a
 * time, which is much faster for multi-megabyte audio files.
 *
 * @param buffer - bytes to encode
 * @returns the base64 encoding of `buffer` ('' for an empty buffer)
 */
function arrayBufferToBase64(buffer: ArrayBuffer): string {
  const bytes = new Uint8Array(buffer)
  const parts: string[] = []
  for (let offset = 0; offset < bytes.length; offset += BASE64_CHUNK_SIZE) {
    parts.push(String.fromCharCode(...bytes.subarray(offset, offset + BASE64_CHUNK_SIZE)))
  }
  return btoa(parts.join(''))
}
/** Build a new speaker group with a fresh id and no reference clips yet. */
function createSpeakerGroup(): SpeakerGroup {
  const group: SpeakerGroup = { id: createId(), references: [] }
  return group
}
// The speaker group present when the app first renders (it starts expanded).
const initialSpeakerGroup = createSpeakerGroup()
/**
 * Flatten every speaker group's clips into the request's `references` array,
 * preserving group order and clip order.
 *
 * @param speakerGroups - groups whose clips are emitted
 * @param includeBinaryAudio - when true, each clip's bytes are base64-encoded;
 *   when false (request preview), `audio` is ''
 *
 * NOTE(review): the payload does not carry speaker ids. Presumably the server
 * relies on ordering or speaker tags in the text. Confirm.
 */
function buildReferencesPayload(
  speakerGroups: SpeakerGroup[],
  includeBinaryAudio: boolean,
) {
  const references: { text: string; audio: string }[] = []
  for (const speakerGroup of speakerGroups) {
    for (const reference of speakerGroup.references) {
      references.push({
        text: reference.text,
        audio: includeBinaryAudio ? arrayBufferToBase64(reference.audio) : '',
      })
    }
  }
  return references
}
/**
 * Build the JSON body for `POST /v1/tts` from the current UI state.
 *
 * The preview and the real request both use this builder, so the preview always
 * matches what is sent. The only difference between them is whether reference
 * audio is embedded as base64 or replaced by ''.
 */
function buildTtsPayload(
  inputText: string,
  controls: ControlsState,
  speakerGroups: SpeakerGroup[],
  includeBinaryAudio: boolean,
) {
  return {
    text: inputText,
    chunk_length: controls.chunkLength,
    max_new_tokens: controls.maxNewTokens,
    format: controls.format,
    latency: controls.latency,
    normalize: controls.normalize,
    references: buildReferencesPayload(speakerGroups, includeBinaryAudio),
    temperature: controls.temperature,
    top_p: controls.topP,
    repetition_penalty: controls.repetitionPenalty,
  }
}

/** Request body for display/copying: identical to the real request but without audio bytes. */
function buildPreviewPayload(
  inputText: string,
  controls: ControlsState,
  speakerGroups: SpeakerGroup[],
) {
  return buildTtsPayload(inputText, controls, speakerGroups, false)
}

/** Request body actually sent to the server, with reference audio base64-encoded. */
function buildRequestPayload(
  inputText: string,
  controls: ControlsState,
  speakerGroups: SpeakerGroup[],
) {
  return buildTtsPayload(inputText, controls, speakerGroups, true)
}
/**
 * Derive a download file name stem (no extension) from the synthesized text.
 *
 * The steps are:
 * - Remove characters reserved in file names on common platforms (`\ / : * ? " < > |`).
 * - Collapse whitespace runs into '-' and drop remaining control characters.
 * - Keep the first 24 code points. Slicing by code point means characters such as
 *   emoji are never split into lone surrogates.
 *
 * Returns 'tts' when nothing usable remains.
 */
function createFileName(inputText: string) {
  const stem = inputText
    .replace(/[\\/:*?"<>|]/g, '')
    .trim()
    .replace(/\s+/g, '-')
    .replace(/\p{Cc}/gu, '')
  const safePrefix = Array.from(stem).slice(0, 24).join('')
  return safePrefix || 'tts'
}
/** Human-readable text for a caught value: the message of an Error, otherwise a generic fallback. */
function getErrorMessage(error: unknown) {
  if (!(error instanceof Error)) {
    return 'Unknown error'
  }
  return error.message
}
/**
 * Resolve once `sourceBuffer` is idle.
 *
 * If no append/remove is in progress, the returned promise is already resolved.
 * Otherwise it resolves on the next 'updateend' event. The listener removes itself
 * after firing once.
 */
function waitForSourceBuffer(sourceBuffer: SourceBuffer) {
  if (sourceBuffer.updating) {
    return new Promise<void>((resolve) => {
      sourceBuffer.addEventListener('updateend', () => resolve(), { once: true })
    })
  }
  return Promise.resolve()
}
/**
 * Whether audio in `format` can be played while it downloads. This requires Media
 * Source Extensions and support for the format's MIME type.
 */
function canUseStreamingPlayback(format: AudioFormat) {
  if (typeof window.MediaSource === 'undefined') {
    return false
  }
  return MediaSource.isTypeSupported(formatMimeMap[format])
}
// Props for SettingSlider: a labelled single-value slider.
type SettingSliderProps = {
  label: string
  value: number
  min: number
  max: number
  step?: number // defaults to 1
  // Called with the slider's first value whenever it is a number.
  onValueChange: (value: number) => void
  // Optional formatter for the value displayed next to the label.
  formatValue?: (value: number) => string
}
function SettingSlider({
label,
value,
min,
max,
step = 1,
onValueChange,
formatValue,
}: SettingSliderProps) {
return (
{label}
{formatValue ? formatValue(value) : value}
{
const current = nextValue[0]
if (typeof current === 'number') {
onValueChange(current)
}
}}
/>
)
}
function App() {
const [inputText, setInputText] = useState(defaultInputText)
const [controls, setControls] = useState(initialControls)
const [speakerGroups, setSpeakerGroups] = useState([initialSpeakerGroup])
const [pendingReference, setPendingReference] = useState(null)
const [openSpeakerIds, setOpenSpeakerIds] = useState([initialSpeakerGroup.id])
const [metrics, setMetrics] = useState(null)
const [isGenerating, setIsGenerating] = useState(false)
const [copyLabel, setCopyLabel] = useState('Copy')
const [isRequestPreviewOpen, setIsRequestPreviewOpen] = useState(false)
const [statusMessage, setStatusMessage] = useState(null)
const [downloadUrl, setDownloadUrl] = useState(null)
const [downloadName, setDownloadName] = useState('generated-audio.mp3')
const audioRef = useRef(null)
const fileInputRef = useRef(null)
const speakerGroupsRef = useRef([])
const uploadTargetSpeakerIdRef = useRef(null)
const downloadUrlRef = useRef(null)
const mediaSourceUrlRef = useRef(null)
speakerGroupsRef.current = speakerGroups
// On unmount, release every object URL this component created. The effect runs
// only once, so current values are read through refs rather than captured state.
useEffect(() => {
  return () => {
    for (const speakerGroup of speakerGroupsRef.current) {
      for (const reference of speakerGroup.references) {
        URL.revokeObjectURL(reference.previewUrl)
      }
    }
    for (const url of [downloadUrlRef.current, mediaSourceUrlRef.current]) {
      if (url) {
        URL.revokeObjectURL(url)
      }
    }
  }
}, [])
/** Append a new, empty speaker group and expand it in the list. */
function addSpeaker() {
  const speaker = createSpeakerGroup()
  setSpeakerGroups((groups) => groups.concat(speaker))
  setOpenSpeakerIds((ids) => ids.concat(speaker.id))
}
/**
 * Remove a speaker group, release its clips' preview URLs, and drop it from the
 * expanded set.
 *
 * If the last group is removed, a fresh empty group takes its place so the list is
 * never empty. A pending reference dialog for the removed speaker is closed.
 */
function removeSpeaker(speakerId: number) {
  // Revoke outside the state updater: React may invoke updaters more than once
  // (e.g. in StrictMode), so they must stay free of side effects.
  const targetSpeaker = speakerGroups.find((speakerGroup) => speakerGroup.id === speakerId)
  targetSpeaker?.references.forEach((reference) => {
    URL.revokeObjectURL(reference.previewUrl)
  })
  setSpeakerGroups((current) => {
    const next = current.filter((speakerGroup) => speakerGroup.id !== speakerId)
    return next.length > 0 ? next : [createSpeakerGroup()]
  })
  setOpenSpeakerIds((current) => current.filter((currentSpeakerId) => currentSpeakerId !== speakerId))
  if (pendingReference?.speakerId === speakerId) {
    setPendingReference(null)
  }
}
/**
 * Append an uploaded clip to a speaker group.
 *
 * A preview object URL is created for the clip. Its MIME type is chosen from the
 * file extension; unknown extensions fall back to `audio/mpeg`. The id and URL are
 * created once, outside the state updater, because updaters must be pure.
 *
 * @param speakerId - target group id
 * @param name - original file name (also used to pick the preview MIME type)
 * @param audio - raw file bytes
 * @param text - reference text entered by the user
 */
function addReference(speakerId: number, name: string, audio: ArrayBuffer, text: string) {
  const previewMimeByExtension: Record<string, string> = {
    mp3: formatMimeMap.mp3,
    wav: formatMimeMap.wav,
    ogg: 'audio/ogg',
    opus: 'audio/ogg',
    flac: 'audio/flac',
    m4a: 'audio/mp4',
    webm: 'audio/webm',
  }
  const extension = name.slice(name.lastIndexOf('.') + 1).toLowerCase()
  const previewMime = previewMimeByExtension[extension] ?? formatMimeMap.mp3
  const reference: ReferenceItem = {
    id: createId(),
    name,
    audio,
    text,
    previewUrl: URL.createObjectURL(new Blob([audio], { type: previewMime })),
  }
  setSpeakerGroups((current) =>
    current.map((speakerGroup) =>
      speakerGroup.id === speakerId
        ? { ...speakerGroup, references: [...speakerGroup.references, reference] }
        : speakerGroup,
    ),
  )
}
/**
 * Remove one clip from a speaker group and release its preview URL.
 *
 * The URL is revoked outside the state updater, because React may invoke updaters
 * more than once and they must be free of side effects.
 */
function removeReference(speakerId: number, referenceId: number) {
  const target = speakerGroups
    .find((speakerGroup) => speakerGroup.id === speakerId)
    ?.references.find((reference) => reference.id === referenceId)
  if (target) {
    URL.revokeObjectURL(target.previewUrl)
  }
  setSpeakerGroups((current) =>
    current.map((speakerGroup) =>
      speakerGroup.id === speakerId
        ? {
            ...speakerGroup,
            references: speakerGroup.references.filter((reference) => reference.id !== referenceId),
          }
        : speakerGroup,
    ),
  )
}
/** Replace the reference text of one clip; all other groups and clips are left untouched. */
function updateReferenceText(speakerId: number, referenceId: number, text: string) {
  setSpeakerGroups((current) =>
    current.map((speakerGroup) => {
      if (speakerGroup.id !== speakerId) {
        return speakerGroup
      }
      const references = speakerGroup.references.map((reference) =>
        reference.id === referenceId ? { ...reference, text } : reference,
      )
      return { ...speakerGroup, references }
    }),
  )
}
/** Revoke the current download object URL (if any) and hide the download link. */
function clearDownloadUrl() {
  const url = downloadUrlRef.current
  if (url) {
    URL.revokeObjectURL(url)
    downloadUrlRef.current = null
  }
  setDownloadUrl(null)
}
/** Revoke the object URL of the previous streaming MediaSource, if one exists. */
function clearMediaSourceUrl() {
  const url = mediaSourceUrlRef.current
  if (!url) {
    return
  }
  URL.revokeObjectURL(url)
  mediaSourceUrlRef.current = null
}
/**
 * Handle the hidden file input's `change` event.
 *
 * Reads the chosen file and opens the reference-text dialog in 'create' mode for
 * the speaker whose Upload button was clicked. A read failure is shown in the
 * status banner instead of becoming an unhandled promise rejection.
 */
async function handleReferenceUpload(event: React.ChangeEvent<HTMLInputElement>) {
  const file = event.target.files?.[0]
  const speakerId = uploadTargetSpeakerIdRef.current
  // Reset the input so that picking the same file again still fires `change`.
  event.target.value = ''
  uploadTargetSpeakerIdRef.current = null
  if (!file || typeof speakerId !== 'number') {
    return
  }
  let audio: ArrayBuffer
  try {
    audio = await file.arrayBuffer()
  } catch (error) {
    setStatusMessage({
      tone: 'error',
      message: `Failed to read reference audio "${file.name}": ${getErrorMessage(error)}`,
    })
    return
  }
  setPendingReference({
    mode: 'create',
    speakerId,
    name: file.name,
    audio,
    text: '',
  })
}
/**
 * Commit the reference dialog. In 'create' mode the uploaded clip is added; in
 * 'edit' mode the clip's text is updated. Then the dialog and status banner are
 * cleared.
 */
function savePendingReference() {
  if (!pendingReference) {
    return
  }
  const { mode, speakerId, referenceId, name, audio, text } = pendingReference
  if (mode === 'create' && audio) {
    addReference(speakerId, name, audio, text)
  } else if (mode === 'edit' && typeof referenceId === 'number') {
    updateReferenceText(speakerId, referenceId, text)
  }
  setPendingReference(null)
  setStatusMessage(null)
}
/**
 * Copy the pretty-printed request preview to the clipboard. On success the button
 * shows "Copied" for two seconds; on failure the error appears in the status banner.
 */
async function copyRequestPreview() {
  const previewJson = JSON.stringify(
    buildPreviewPayload(inputText, controls, speakerGroups),
    null,
    2,
  )
  try {
    await navigator.clipboard.writeText(previewJson)
  } catch (error) {
    setStatusMessage({
      tone: 'error',
      message: `Failed to copy request preview: ${getErrorMessage(error)}`,
    })
    return
  }
  setCopyLabel('Copied')
  window.setTimeout(() => setCopyLabel('Copy'), 2000)
}
/**
 * Send the current request to `/v1/tts` and handle the response.
 *
 * - When Media Source Extensions support the format, the audio plays while it streams.
 * - Otherwise the complete file is loaded into the audio element once generation finishes.
 * - In both cases the full result is offered as a download.
 *
 * Metrics (text length, time to first chunk, KiB received) are updated as chunks
 * arrive. Failures are reported in the status banner, and `isGenerating` is always
 * reset.
 */
async function handleGenerateAudio() {
  const audioElement = audioRef.current
  if (!audioElement) {
    return
  }
  const mime = formatMimeMap[controls.format]
  const useStreamingPlayback = canUseStreamingPlayback(controls.format)
  clearDownloadUrl()
  clearMediaSourceUrl()
  setMetrics(null)
  setStatusMessage(null)
  setIsGenerating(true)
  // Polled by the MediaSource feeder below. It is declared outside `try` so that
  // `finally` sets it on every exit path. Otherwise a failed read would leave the
  // feeder polling forever.
  let readingDone = false
  // Time-to-first-chunk is measured from when the request is issued, so it includes
  // the wait for the response headers.
  const startTime = performance.now()
  try {
    const response = await fetch('/v1/tts', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(buildRequestPayload(inputText, controls, speakerGroups)),
    })
    if (!response.ok || !response.body) {
      throw new Error(`Failed to generate audio (HTTP ${response.status})`)
    }
    const reader = response.body.getReader()
    let mediaSource: MediaSource | null = null
    if (useStreamingPlayback) {
      mediaSource = new MediaSource()
      const streamUrl = URL.createObjectURL(mediaSource)
      mediaSourceUrlRef.current = streamUrl
      audioElement.src = streamUrl
    } else {
      audioElement.removeAttribute('src')
      audioElement.load()
    }
    const allChunks: ArrayBuffer[] = []
    // Chunks waiting to be appended to the SourceBuffer (streaming playback only).
    const playQueue: ArrayBuffer[] = []
    let receivedLength = 0
    let ttftMs = -1
    if (mediaSource) {
      const source = mediaSource
      await new Promise<void>((resolve, reject) => {
        source.addEventListener(
          'sourceopen',
          () => {
            try {
              const sourceBuffer = source.addSourceBuffer(mime)
              // Append queued chunks one at a time until reading has finished and the
              // queue is drained, then end the stream. Stop early once the MediaSource
              // is no longer open, e.g. after the element's source has been replaced.
              const processQueue = async () => {
                while (source.readyState === 'open') {
                  if (readingDone && playQueue.length === 0) {
                    await waitForSourceBuffer(sourceBuffer)
                    if (source.readyState === 'open') {
                      source.endOfStream()
                    }
                    return
                  }
                  const chunk = playQueue.shift()
                  if (!chunk) {
                    await new Promise<void>((resolveSleep) => {
                      window.setTimeout(resolveSleep, 50)
                    })
                    continue
                  }
                  await waitForSourceBuffer(sourceBuffer)
                  sourceBuffer.appendBuffer(chunk)
                  await waitForSourceBuffer(sourceBuffer)
                }
              }
              processQueue().catch((error: unknown) => {
                // A closed MediaSource has been detached from the element, so its
                // failures are no longer relevant. Report only failures of a stream
                // that is still attached.
                if (source.readyState !== 'closed') {
                  setStatusMessage({
                    tone: 'error',
                    message: `Streaming playback failed: ${getErrorMessage(error)}`,
                  })
                }
              })
              resolve()
            } catch (error) {
              reject(error)
            }
          },
          { once: true },
        )
      })
    }
    while (true) {
      const { done, value } = await reader.read()
      if (done) {
        readingDone = true
        break
      }
      receivedLength += value.byteLength
      if (ttftMs < 0) {
        ttftMs = performance.now() - startTime
      }
      setMetrics({
        textLength: inputText.length,
        ttftMs,
        receivedKb: Math.round(receivedLength / 1024),
      })
      const chunk = value.buffer.slice(value.byteOffset, value.byteOffset + value.byteLength)
      playQueue.push(chunk)
      allChunks.push(chunk)
      if (useStreamingPlayback && audioElement.paused) {
        void audioElement.play().catch(() => undefined)
      }
    }
    const audioBlob = new Blob(allChunks, { type: mime })
    const nextDownloadUrl = URL.createObjectURL(audioBlob)
    downloadUrlRef.current = nextDownloadUrl
    setDownloadUrl(nextDownloadUrl)
    setDownloadName(`${createFileName(inputText)}.${controls.format}`)
    if (!useStreamingPlayback) {
      audioElement.src = nextDownloadUrl
      audioElement.load()
      setStatusMessage({
        tone: 'info',
        message: `Format "${controls.format}" is not supported for in-browser playback. The file is ready to download after generation completes.`,
      })
    }
  } catch (error) {
    setStatusMessage({
      tone: 'error',
      message: `Audio generation failed: ${getErrorMessage(error)}`,
    })
  } finally {
    // Lets the feeder drain what it has and terminate, even after an error.
    readingDone = true
    setIsGenerating(false)
  }
}
// Pretty-printed request body shown in the preview panel (audio bytes omitted).
const previewPayload = buildPreviewPayload(inputText, controls, speakerGroups)
const requestPreview = JSON.stringify(previewPayload, null, 2)
// Number of reference clips across all speaker groups.
const totalReferenceCount = speakerGroups.flatMap((speakerGroup) => speakerGroup.references).length
return (
Input
Enter the text to synthesize and inspect the outgoing request payload.
Input Text
Request Preview
Live snapshot of the payload sent to the backend.
{copyLabel}
{isRequestPreviewOpen ? 'Collapse' : 'Expand'}
{isGenerating ? (
) : (
)}
{isGenerating ? 'Generating Audio...' : 'Generate Audio'}
{statusMessage ? (
{statusMessage.tone === 'error' ? (
) : (
)}
{statusMessage.tone === 'error' ? 'Error' : 'Notice'}
{statusMessage.message}
) : null}
Stream the result when supported, then preview or download the final file.
{metrics ? (
<>
Text length: {metrics.textLength}
TTFT: {metrics.ttftMs.toFixed(2)} ms
Received: {metrics.receivedKb} KB
>
) : (
No output yet
)}
Reference Audio
Build one or more speaker groups. Each speaker can have multiple reference clips.
{speakerGroups.length > 0 ? (
speakerGroups.map((speakerGroup, speakerIndex) => (
{
setOpenSpeakerIds((current) =>
open
? [...current, speakerGroup.id]
: current.filter(
(currentSpeakerId) => currentSpeakerId !== speakerGroup.id,
),
)
}}
>
Speaker {speakerIndex}
{speakerGroup.references.length} reference
{speakerGroup.references.length === 1 ? '' : 's'}
{
uploadTargetSpeakerIdRef.current = speakerGroup.id
fileInputRef.current?.click()
}}
>
Upload
{speakerGroups.length > 1 ? (
removeSpeaker(speakerGroup.id)}
>
Remove
) : null}
{speakerGroup.references.length > 0 ? (
speakerGroup.references.map((reference) => (
setPendingReference({
mode: 'edit',
speakerId: speakerGroup.id,
referenceId: reference.id,
name: reference.name,
text: reference.text,
})
}
>
Edit Text
removeReference(speakerGroup.id, reference.id)
}
>
Remove
))
) : (
No references yet.
)}
))
) : (
No speaker groups configured yet.
)}
Generation Settings
Adjust sampling and output parameters.
Latency Mode
{
if (value) {
setControls((current) => ({
...current,
latency: value as LatencyMode,
}))
}
}}
>
balanced
normal
Low uses incremental local decode for faster first audio. Normal waits for the
full LLM result, then decodes once.
Format
{
if (value) {
setControls((current) => ({
...current,
format: value as AudioFormat,
}))
}
}}
>
mp3
wav
pcm
opus
Normalize
Normalize text before synthesis to keep input formatting consistent.
setControls((current) => ({
...current,
normalize: checked,
}))
}
/>
setControls((current) => ({
...current,
chunkLength: value,
}))
}
/>
setControls((current) => ({
...current,
maxNewTokens: value,
}))
}
/>
value.toFixed(2)}
onValueChange={(value) =>
setControls((current) => ({
...current,
temperature: value,
}))
}
/>
value.toFixed(2)}
onValueChange={(value) =>
setControls((current) => ({
...current,
topP: value,
}))
}
/>
value.toFixed(2)}
onValueChange={(value) =>
setControls((current) => ({
...current,
repetitionPenalty: value,
}))
}
/>
!open && setPendingReference(null)}>
{pendingReference?.mode === 'create' ? 'Save Reference Text' : 'Edit Reference Text'}
{pendingReference
? `Speaker ${speakerGroups.findIndex(
(speakerGroup) => speakerGroup.id === pendingReference.speakerId,
)}`
: ''}
setPendingReference(null)}>
Cancel
Save
)
}
export default App
================================================
FILE: awesome_webui/src/components/ui/alert.tsx
================================================
import * as React from 'react'
import { cva, type VariantProps } from 'class-variance-authority'
import { cn } from '@/lib/utils'
// Tailwind class generator for the alert container: shared layout classes plus one
// colour scheme per `variant` ('default' when no variant is given).
const alertVariants = cva('relative w-full rounded-lg border px-4 py-3 text-sm', {
  variants: {
    variant: {
      default: 'bg-card text-card-foreground',
      destructive: 'border-destructive/20 bg-destructive/5 text-destructive',
      warning: 'border-amber-200 bg-amber-50 text-amber-900',
    },
  },
  defaultVariants: {
    variant: 'default',
  },
})
function Alert({
className,
variant,
...props
}: React.ComponentProps<'div'> & VariantProps) {
return
}
function AlertTitle({ className, ...props }: React.ComponentProps<'h5'>) {
return
}
function AlertDescription({ className, ...props }: React.ComponentProps<'div'>) {
return
}
export { Alert, AlertDescription, AlertTitle }
================================================
FILE: awesome_webui/src/components/ui/badge.tsx
================================================
/* eslint-disable react-refresh/only-export-components */
import * as React from 'react'
import { cva, type VariantProps } from 'class-variance-authority'
import { cn } from '@/lib/utils'
// Tailwind class generator for badges (exported alongside Badge): compact pill
// styling with three colour variants; 'default' when no variant is given.
const badgeVariants = cva(
  'inline-flex items-center rounded-md border px-2 py-0.5 text-xs font-medium transition-colors',
  {
    variants: {
      variant: {
        default: 'border-transparent bg-primary text-primary-foreground',
        secondary: 'border-transparent bg-secondary text-secondary-foreground',
        outline: 'text-foreground',
      },
    },
    defaultVariants: {
      variant: 'default',
    },
  },
)
function Badge({
className,
variant,
...props
}: React.ComponentProps<'div'> & VariantProps) {
return
}
export { Badge, badgeVariants }
================================================
FILE: awesome_webui/src/components/ui/button.tsx
================================================
/* eslint-disable react-refresh/only-export-components */
import * as React from 'react'
import { Slot } from '@radix-ui/react-slot'
import { cva, type VariantProps } from 'class-variance-authority'
import { cn } from '@/lib/utils'
// Tailwind class generator for buttons (exported alongside Button).
// The base classes cover layout, the disabled state and the focus-visible ring.
// `variant` selects the colour scheme and `size` selects height/padding; both
// default to 'default'.
const buttonVariants = cva(
  'inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-colors disabled:pointer-events-none disabled:opacity-50 outline-none focus-visible:ring-2 focus-visible:ring-ring/70 focus-visible:ring-offset-2 focus-visible:ring-offset-background',
  {
    variants: {
      variant: {
        default: 'bg-primary text-primary-foreground hover:bg-primary/90',
        destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/90',
        outline: 'border bg-card hover:bg-accent hover:text-accent-foreground',
        secondary: 'bg-secondary text-secondary-foreground hover:bg-secondary/80',
        ghost: 'hover:bg-accent hover:text-accent-foreground',
      },
      size: {
        default: 'h-9 px-4 py-2',
        sm: 'h-8 rounded-md px-3 text-xs',
        lg: 'h-11 rounded-md px-6',
        // Square button sized for a single icon.
        icon: 'size-9',
      },
    },
    defaultVariants: {
      variant: 'default',
      size: 'default',
    },
  },
)
type ButtonProps = React.ComponentProps<'button'> &
VariantProps & {
asChild?: boolean
}
function Button({ className, variant, size, asChild = false, ...props }: ButtonProps) {
const Comp = asChild ? Slot : 'button'
return
}
export { Button, buttonVariants }
================================================
FILE: awesome_webui/src/components/ui/card.tsx
================================================
import * as React from 'react'
import { cn } from '@/lib/utils'
function Card({ className, ...props }: React.ComponentProps<'div'>) {
return (
)
}
function CardHeader({ className, ...props }: React.ComponentProps<'div'>) {
return
}
function CardTitle({ className, ...props }: React.ComponentProps<'div'>) {
return
}
function CardDescription({ className, ...props }: React.ComponentProps<'div'>) {
return
}
function CardContent({ className, ...props }: React.ComponentProps<'div'>) {
return
}
export { Card, CardContent, CardDescription, CardHeader, CardTitle }
================================================
FILE: awesome_webui/src/components/ui/collapsible.tsx
================================================
import * as CollapsiblePrimitive from '@radix-ui/react-collapsible'
// Thin aliases over the Radix Collapsible primitives, using the short
// `Root` / `Trigger` / `Content` export names (as dialog.tsx does for Dialog).
const Collapsible = CollapsiblePrimitive.Root
const CollapsibleTrigger = CollapsiblePrimitive.Trigger
const CollapsibleContent = CollapsiblePrimitive.Content
export { Collapsible, CollapsibleContent, CollapsibleTrigger }
================================================
FILE: awesome_webui/src/components/ui/dialog.tsx
================================================
import * as React from 'react'
import * as DialogPrimitive from '@radix-ui/react-dialog'
import { X } from 'lucide-react'
import { cn } from '@/lib/utils'
// Aliases for the Radix Dialog primitives.
// NOTE(review): the component `return` statements in this copy have no JSX;
// confirm against the original source.
const Dialog = DialogPrimitive.Root
const DialogTrigger = DialogPrimitive.Trigger
// NOTE(review): DialogPortal and DialogClose are not in the export list below.
// With `noUnusedLocals` enabled in tsconfig.app.json they must be referenced
// from DialogContent's (missing) JSX — confirm.
const DialogPortal = DialogPrimitive.Portal
const DialogClose = DialogPrimitive.Close
/** Backdrop shown behind the dialog; accepts `className` plus primitive props. */
function DialogOverlay({
className,
...props
}: React.ComponentProps) {
return (
)
}
/**
 * Dialog panel. Renders `children` and a close control; the `X` icon import
 * and the `Close` text suggest an icon button with a screen-reader label —
 * TODO confirm.
 */
function DialogContent({
className,
children,
...props
}: React.ComponentProps) {
return (
{children}
Close
)
}
/** Layout wrapper for the dialog title/description area. */
function DialogHeader({ className, ...props }: React.ComponentProps<'div'>) {
return
}
/** Layout wrapper for dialog action buttons. */
function DialogFooter({ className, ...props }: React.ComponentProps<'div'>) {
return
}
/** Dialog title (presumably wraps `DialogPrimitive.Title` for accessibility). */
function DialogTitle({ className, ...props }: React.ComponentProps) {
return (
)
}
/** Dialog description (presumably wraps `DialogPrimitive.Description`). */
function DialogDescription({
className,
...props
}: React.ComponentProps) {
return (
)
}
export {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
DialogTrigger,
}
================================================
FILE: awesome_webui/src/components/ui/label.tsx
================================================
import * as React from 'react'
import * as LabelPrimitive from '@radix-ui/react-label'
import { cn } from '@/lib/utils'
/**
 * Styled form label built on `@radix-ui/react-label`; `className` is presumably
 * merged with the defaults via `cn`.
 * NOTE(review): the JSX is missing from this copy — confirm against the source.
 */
function Label({ className, ...props }: React.ComponentProps) {
return (
)
}
export { Label }
================================================
FILE: awesome_webui/src/components/ui/scroll-area.tsx
================================================
import * as React from 'react'
import * as ScrollAreaPrimitive from '@radix-ui/react-scroll-area'
import { cn } from '@/lib/utils'
// Scroll container primitives built on `@radix-ui/react-scroll-area`.
// NOTE(review): the JSX is missing from this copy — confirm against the source.
/** Scrollable viewport that renders `children`. */
function ScrollArea({
className,
children,
...props
}: React.ComponentProps) {
return (
{children}
)
}
/** Custom scrollbar; `orientation` defaults to `'vertical'`. */
function ScrollBar({
className,
orientation = 'vertical',
...props
}: React.ComponentProps) {
return (
)
}
export { ScrollArea, ScrollBar }
================================================
FILE: awesome_webui/src/components/ui/separator.tsx
================================================
import * as React from 'react'
import * as SeparatorPrimitive from '@radix-ui/react-separator'
import { cn } from '@/lib/utils'
/**
 * Visual divider built on `@radix-ui/react-separator`. Defaults:
 * `orientation = 'horizontal'` and `decorative = true` (Radix treats a
 * decorative separator as purely visual for assistive technology).
 * NOTE(review): the JSX is missing from this copy — confirm against the source.
 */
function Separator({
className,
orientation = 'horizontal',
decorative = true,
...props
}: React.ComponentProps) {
return (
)
}
export { Separator }
================================================
FILE: awesome_webui/src/components/ui/slider.tsx
================================================
import * as React from 'react'
import * as SliderPrimitive from '@radix-ui/react-slider'
import { cn } from '@/lib/utils'
/**
 * Styled range slider built on `@radix-ui/react-slider`; other props are
 * presumably forwarded to the primitive root.
 * NOTE(review): the JSX is missing from this copy — confirm against the source.
 */
function Slider({
className,
...props
}: React.ComponentProps) {
return (
)
}
export { Slider }
================================================
FILE: awesome_webui/src/components/ui/switch.tsx
================================================
import * as React from 'react'
import * as SwitchPrimitive from '@radix-ui/react-switch'
import { cn } from '@/lib/utils'
/**
 * Styled toggle switch built on `@radix-ui/react-switch`; other props are
 * presumably forwarded to the primitive root.
 * NOTE(review): the JSX is missing from this copy — confirm against the source.
 */
function Switch({
className,
...props
}: React.ComponentProps) {
return (
)
}
export { Switch }
================================================
FILE: awesome_webui/src/components/ui/textarea.tsx
================================================
import * as React from 'react'
import { cn } from '@/lib/utils'
/**
 * Styled native `textarea`; accepts all `textarea` props plus `className`.
 * NOTE(review): the JSX is missing from this copy — confirm against the source.
 */
function Textarea({ className, ...props }: React.ComponentProps<'textarea'>) {
return (
)
}
export { Textarea }
================================================
FILE: awesome_webui/src/components/ui/toggle-group.tsx
================================================
import * as React from 'react'
import * as ToggleGroupPrimitive from '@radix-ui/react-toggle-group'
import { cva, type VariantProps } from 'class-variance-authority'
import { cn } from '@/lib/utils'
// Class recipe for toggle-group items. The pressed state (`data-[state=on]`)
// uses the primary colors. `size` is one of default / sm / lg and defaults to
// 'default'.
const toggleGroupItemVariants = cva(
'inline-flex items-center justify-center rounded-md text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus-visible:ring-2 focus-visible:ring-ring/70 focus-visible:ring-offset-2 focus-visible:ring-offset-background disabled:pointer-events-none disabled:opacity-50 data-[state=on]:bg-primary data-[state=on]:text-primary-foreground border border-border bg-card',
{
variants: {
size: {
default: 'h-9 px-3',
sm: 'h-8 px-2.5 text-xs',
lg: 'h-10 px-4',
},
},
defaultVariants: {
size: 'default',
},
},
)
/**
 * Group container built on `@radix-ui/react-toggle-group`.
 * NOTE(review): the JSX is missing from this copy — confirm against the source.
 */
function ToggleGroup({
className,
...props
}: React.ComponentProps) {
return (
)
}
/**
 * A single item in a ToggleGroup. `size` presumably selects classes from
 * `toggleGroupItemVariants`.
 * NOTE(review): the JSX is missing from this copy — confirm against the source.
 */
function ToggleGroupItem({
className,
size,
...props
}: React.ComponentProps &
VariantProps) {
return (
)
}
export { ToggleGroup, ToggleGroupItem }
================================================
FILE: awesome_webui/src/index.css
================================================
@import "tailwindcss";
/* Light-theme design tokens, stored as bare HSL channel triplets ("H S% L%").
   The @theme block below wraps each one as hsl(var(--token)). */
:root {
--background: 0 0% 96%;
--foreground: 240 10% 3.9%;
--card: 0 0% 100%;
--card-foreground: 240 10% 3.9%;
--popover: 0 0% 100%;
--popover-foreground: 240 10% 3.9%;
--primary: 240 5.9% 10%;
--primary-foreground: 0 0% 98%;
--secondary: 240 4.8% 95.9%;
--secondary-foreground: 240 5.9% 10%;
--muted: 240 4.8% 95.9%;
--muted-foreground: 240 3.8% 46.1%;
--accent: 240 4.8% 95.9%;
--accent-foreground: 240 5.9% 10%;
--destructive: 0 72.2% 50.6%;
--destructive-foreground: 0 0% 98%;
--border: 240 5.9% 88%;
--input: 240 5.9% 88%;
--ring: 240 5% 64.9%;
/* Base corner radius; the --radius-* scale in @theme is derived from it. */
--radius: 0.75rem;
}
/* Expose the tokens as Tailwind v4 theme variables (--color-*, --radius-*) so
   that utilities such as bg-background, text-foreground and border-border
   (used in the base layer below) resolve to them. */
@theme inline {
--color-background: hsl(var(--background));
--color-foreground: hsl(var(--foreground));
--color-card: hsl(var(--card));
--color-card-foreground: hsl(var(--card-foreground));
--color-popover: hsl(var(--popover));
--color-popover-foreground: hsl(var(--popover-foreground));
--color-primary: hsl(var(--primary));
--color-primary-foreground: hsl(var(--primary-foreground));
--color-secondary: hsl(var(--secondary));
--color-secondary-foreground: hsl(var(--secondary-foreground));
--color-muted: hsl(var(--muted));
--color-muted-foreground: hsl(var(--muted-foreground));
--color-accent: hsl(var(--accent));
--color-accent-foreground: hsl(var(--accent-foreground));
--color-destructive: hsl(var(--destructive));
--color-destructive-foreground: hsl(var(--destructive-foreground));
--color-border: hsl(var(--border));
--color-input: hsl(var(--input));
--color-ring: hsl(var(--ring));
/* Radius scale: sm/md shrink, xl grows relative to --radius (= lg). */
--radius-sm: calc(var(--radius) - 4px);
--radius-md: calc(var(--radius) - 2px);
--radius-lg: var(--radius);
--radius-xl: calc(var(--radius) + 4px);
}
@layer base {
/* Default border color for every element. */
* {
@apply border-border;
}
/* Keep the layout from shrinking below 320px wide. */
html {
min-width: 320px;
}
body {
@apply bg-background text-foreground antialiased;
font-family: "Inter", "Avenir Next", "Segoe UI", sans-serif;
}
/* Form controls inherit the page font instead of user-agent defaults. */
button,
input,
textarea {
font: inherit;
}
}
================================================
FILE: awesome_webui/src/main.tsx
================================================
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
import './index.css'
import App from './App.tsx'
// Mount the React app into the #root element. The `!` assumes index.html
// provides that element.
// NOTE(review): the render argument is missing from this copy; given the
// imports it is presumably `<StrictMode><App /></StrictMode>` — confirm.
createRoot(document.getElementById('root')!).render(
,
)
================================================
FILE: awesome_webui/tsconfig.app.json
================================================
{
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
"target": "ES2022",
"useDefineForClassFields": true,
"lib": [
"ES2022",
"DOM",
"DOM.Iterable"
],
"module": "ESNext",
"types": [
"vite/client"
],
"skipLibCheck": true,
"baseUrl": ".",
"paths": {
"@/*": [
"./src/*"
]
},
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"verbatimModuleSyntax": true,
"moduleDetection": "force",
"noEmit": true,
"jsx": "react-jsx",
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true
},
"include": [
"src"
]
}
================================================
FILE: awesome_webui/tsconfig.json
================================================
{
"files": [],
"references": [
{ "path": "./tsconfig.app.json" },
{ "path": "./tsconfig.node.json" }
]
}
================================================
FILE: awesome_webui/tsconfig.node.json
================================================
{
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
"target": "ES2023",
"lib": [
"ES2023"
],
"module": "ESNext",
"types": [
"node"
],
"skipLibCheck": true,
"baseUrl": ".",
"paths": {
"@/*": [
"./src/*"
]
},
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"verbatimModuleSyntax": true,
"moduleDetection": "force",
"noEmit": true,
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true
},
"include": [
"vite.config.ts"
]
}
================================================
FILE: awesome_webui/vite.config.ts
================================================
import fs from 'node:fs'
import { defineConfig, type Plugin } from 'vite'
import react from '@vitejs/plugin-react-swc'
import tailwindcss from '@tailwindcss/vite'
import path from 'node:path'
/**
 * Build-only Vite plugin that post-processes the emitted `index.html`. It folds
 * built entry assets into the page, then deletes the standalone files (and the
 * `assets/` directory), so the build output is a single self-contained HTML
 * file.
 */
function inlineEntryAssets(): Plugin {
// Absolute output directory, captured once the config is resolved.
let resolvedOutDir = ''
return {
name: 'inline-entry-assets',
// Only active for `vite build`, not the dev server.
apply: 'build',
configResolved(config) {
resolvedOutDir = path.resolve(config.root, config.build.outDir)
},
// Reads index.html back from disk, so this relies on closeBundle running
// after the output has been written.
closeBundle() {
const indexHtmlPath = path.join(resolvedOutDir, 'index.html')
if (!fs.existsSync(indexHtmlPath)) {
return
}
// Absolute paths of files that were inlined or dropped; deleted only after
// the rewritten HTML has been saved.
const filesToDelete = new Set()
// Neutralize closing tags inside inlined code so they cannot end the
// surrounding <script>/<style> element early.
const escapeInlineScript = (code: string) => code.replace(/<\/script/gi, '<\\/script')
const escapeInlineStyle = (code: string) => code.replace(/<\/style/gi, '<\\/style')
// Turn an href/src such as "/assets/x.js" or "./assets/x.js" into a path
// relative to the output directory.
const normalizeFileName = (assetPath: string) =>
assetPath.replace(/^\//, '').replace(/^\.\//, '')
// Returns the asset's UTF-8 contents (marking the file for deletion), or
// null when the file does not exist in the output directory.
const readBuiltAsset = (assetPath: string) => {
const fileName = normalizeFileName(assetPath)
const absolutePath = path.join(resolvedOutDir, fileName)
if (!fs.existsSync(absolutePath)) {
return null
}
filesToDelete.add(absolutePath)
return fs.readFileSync(absolutePath, 'utf8')
}
let html = fs.readFileSync(indexHtmlPath, 'utf8')
// Pass 1: remove matching tags outright and schedule their files for deletion.
// NOTE(review): the regex literal below looks truncated (its tag prefix is
// missing in this copy); confirm the intended pattern.
html = html.replace(
/ ]+href="([^"]+)"[^>]*>/g,
(_fullMatch, href: string) => {
const absolutePath = path.join(resolvedOutDir, normalizeFileName(href))
if (fs.existsSync(absolutePath)) {
filesToDelete.add(absolutePath)
}
return ''
},
)
// Pass 2: replace matching tags with the inlined file contents (presumably a
// stylesheet wrapped via escapeInlineStyle — the template is empty in this copy).
html = html.replace(
/ ]+href="([^"]+)"[^>]*>/g,
(fullMatch, href: string) => {
const assetSource = readBuiltAsset(href)
// NOTE(review): an empty asset ("") keeps the original tag, but
// readBuiltAsset has already scheduled that file for deletion, leaving a
// dangling reference. Consider checking `assetSource === null` instead.
if (!assetSource) {
return fullMatch
}
return ``
},
)
// Pass 3: presumably inlines the entry script(s) via escapeInlineScript.
// NOTE(review): this statement is truncated in this copy — confirm.
html = html.replace(
/`
},
)
fs.writeFileSync(indexHtmlPath, html)
for (const filePath of filesToDelete) {
fs.rmSync(filePath, { force: true })
}
// Drop leftovers that are no longer referenced once everything is inlined.
fs.rmSync(path.join(resolvedOutDir, 'vite.svg'), { force: true })
fs.rmSync(path.join(resolvedOutDir, 'assets'), { recursive: true, force: true })
},
}
}
// https://vite.dev/config/
// Local API server that the dev server forwards API and health requests to.
const apiBackend = 'http://localhost:8888'
export default defineConfig({
plugins: [react(), tailwindcss(), inlineEntryAssets()],
// Nothing from a public/ directory is copied into the build output.
publicDir: false,
resolve: {
alias: { '@': path.resolve(__dirname, './src') },
},
// Single-file output: inline every asset regardless of size, keep CSS and JS
// in single chunks, and skip module preloading, so inlineEntryAssets() can
// fold everything into index.html.
build: {
assetsInlineLimit: Number.MAX_SAFE_INTEGER,
cssCodeSplit: false,
modulePreload: false,
rollupOptions: { output: { inlineDynamicImports: true } },
},
server: {
proxy: Object.fromEntries(
['/v1', '/v2', '/health'].map((prefix) => [prefix, apiBackend]),
),
},
})
================================================
FILE: compose.base.yml
================================================
# Shared service template. The `webui` and `server` services in compose.yml
# extend `app-base` and select a Dockerfile target.
services:
app-base:
# BACKEND selects the base-cuda or base-cpu stage in docker/Dockerfile.
build:
context: .
dockerfile: docker/Dockerfile
args:
BACKEND: ${BACKEND:-cuda} # or cpu
UV_VERSION: ${UV_VERSION:-0.8.15}
# Checkpoints are not baked into the image; mount them (and the references
# directory) from the host.
volumes:
- ./checkpoints:/app/checkpoints
- ./references:/app/references
# COMPILE=1 (or true) makes the entrypoint pass --compile.
environment:
COMPILE: ${COMPILE:-0}
# GPU (remove this block if CPU-only):
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
tty: true
stdin_open: true
================================================
FILE: compose.yml
================================================
# Compose project; start a service with `--profile webui` or `--profile server`.
name: fish-speech
services:
# Gradio WebUI: Dockerfile target `webui`, published on host port GRADIO_PORT (default 7860).
webui:
extends:
file: compose.base.yml
service: app-base
build:
target: webui
# Same default as compose.base.yml; repeated here for visibility.
environment:
COMPILE: ${COMPILE:-0}
profiles: ["webui"]
ports:
- "${GRADIO_PORT:-7860}:7860"
# HTTP API server: Dockerfile target `server`, published on host port API_PORT (default 8080).
server:
extends:
file: compose.base.yml
service: app-base
build:
target: server
# Same default as compose.base.yml; repeated here for visibility.
environment:
COMPILE: ${COMPILE:-0}
profiles: ["server"]
ports:
- "${API_PORT:-8080}:8080"
================================================
FILE: docker/Dockerfile
================================================
# docker/Dockerfile
# IMPORTANT: The docker images do not contain the checkpoints. You need to mount the checkpoints to the container.
# Build the image:
# docker build \
# --platform linux/amd64 \
# -f docker/Dockerfile \
# --build-arg BACKEND=[cuda, cpu] \
# --target [webui, server] \
# -t fish-speech-[webui, server]:[cuda, cpu] .
# e.g. for building the webui:
# docker build \
# --platform linux/amd64 \
# -f docker/Dockerfile \
# --build-arg BACKEND=cuda \
# --target webui \
# -t fish-speech-webui:cuda .
# e.g. for building the server:
# docker build \
# --platform linux/amd64 \
# -f docker/Dockerfile \
# --build-arg BACKEND=cuda \
# --target server \
# -t fish-speech-server:cuda .
# Multi-platform build:
# docker buildx build \
# --platform linux/amd64,linux/arm64 \
# -f docker/Dockerfile \
# --build-arg BACKEND=cpu \
# --target webui \
# -t fish-speech-webui:cpu .
# Running the image interactively:
# docker run \
# --gpus all \
# -v /path/to/fish-speech/checkpoints:/app/checkpoints \
# -e COMPILE=1 \ ... or -e COMPILE=0 \
# -it fish-speech-[webui, server]:[cuda, cpu]
# E.g. running the webui:
# docker run \
# --gpus all \
# -v ./checkpoints:/app/checkpoints \
# -e COMPILE=1 \
# -p 7860:7860 \
# fish-speech-webui:cuda
# E.g. running the server:
# docker run \
# --gpus all \
# -v ./checkpoints:/app/checkpoints \
# -p 8080:8080 \
# -it fish-speech-server:cuda
#
# Global build arguments. ARGs declared before the first FROM only expand in
# FROM lines; every stage that uses one must redeclare it with a bare `ARG NAME`.
# Select the specific cuda version (see https://hub.docker.com/r/nvidia/cuda/)
ARG CUDA_VER=12.6.0
# Adapt the uv extra to fit the cuda version (one of [cu126, cu128, cu129])
ARG UV_EXTRA=cu126
# Base stage to build on: "cuda" -> base-cuda, "cpu" -> base-cpu (see `FROM base-${BACKEND}`).
ARG BACKEND=cuda
# Ubuntu release of the nvidia/cuda base image (CUDA backend only).
ARG UBUNTU_VER=24.04
# Python version: tag of the python:*-slim CPU base image and the version given to `uv python pin`.
ARG PY_VER=3.12
# Tag of the ghcr.io/astral-sh/uv image the uv binaries are copied from.
ARG UV_VERSION=0.8.15
# Identity of the non-root user; the user itself is created in the app-base stage.
ARG USERNAME=fish
ARG USER_UID=1000
ARG USER_GID=1000
##############################################################
# Base stage per backend
##############################################################
# --- CUDA (x86_64) ---
FROM nvidia/cuda:${CUDA_VER}-cudnn-runtime-ubuntu${UBUNTU_VER} AS base-cuda
ENV DEBIAN_FRONTEND=noninteractive
# Install system dependencies in a single layer.
# /var/cache/apt and /var/lib/apt are BuildKit cache mounts and are not part of
# the image layer. They are deliberately left populated: removing docker-clean
# and enabling Keep-Downloaded-Packages keeps .deb files and package lists for
# the next build. Running `apt-get clean` or `rm -rf /var/lib/apt/lists/*` here
# would only empty that cache without making the image any smaller.
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
set -eux \
&& rm -f /etc/apt/apt.conf.d/docker-clean \
&& echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
python3-pip \
python3-dev \
git \
ca-certificates \
curl
# --- CPU-only (python:slim is multi-arch: amd64 and arm64) ---
FROM python:${PY_VER}-slim AS base-cpu
# Selects the cpu extra for `uv sync` in later stages (an ENV of the same name
# takes precedence over the UV_EXTRA build ARG).
ENV UV_EXTRA=cpu
# Install system dependencies in a single layer. The apt directories are
# BuildKit cache mounts (not part of the image layer), so they are left
# populated for the next build instead of being cleaned.
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
set -eux \
&& rm -f /etc/apt/apt.conf.d/docker-clean \
&& echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
git \
ca-certificates \
curl
##############################################################
# UV stage
##############################################################
# Source of the pinned uv/uvx binaries copied into app-base. ${UV_VERSION}
# resolves to the global ARG declared before the first FROM; an `ARG UV_VERSION`
# placed here would belong to the preceding stage and have no effect.
FROM ghcr.io/astral-sh/uv:${UV_VERSION} AS uv-bin
##############################################################
# Shared app base stage
##############################################################
FROM base-${BACKEND} AS app-base
# Redeclare the global ARGs this stage uses (ARG scope is per stage).
ARG PY_VER
ARG BACKEND
ARG USERNAME
ARG USER_UID
ARG USER_GID
ARG UV_VERSION
# NOTE(review): with BACKEND=cpu, the `ENV UV_EXTRA=cpu` inherited from
# base-cpu is expected to take precedence over this ARG; with BACKEND=cuda the
# global default applies. Confirm this if either definition changes.
ARG UV_EXTRA
ENV BACKEND=${BACKEND} \
DEBIAN_FRONTEND=noninteractive \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
# System dependencies for audio processing
ARG DEPENDENCIES=" \
libsox-dev \
build-essential \
cmake \
libasound-dev \
portaudio19-dev \
libportaudio2 \
libportaudiocpp0 \
ffmpeg"
# Install system dependencies. The apt directories are BuildKit cache mounts
# (not part of the image layer), so they are left populated for the next build
# instead of being cleaned.
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
set -eux \
&& rm -f /etc/apt/apt.conf.d/docker-clean \
&& echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \
&& apt-get update \
&& apt-get install -y --no-install-recommends ${DEPENDENCIES}
# Install the pinned uv / uvx binaries from the uv-bin stage
COPY --from=uv-bin /uv /uvx /bin/
# Create non-root user (or reuse an existing user that already has USER_UID)
RUN set -eux; \
if getent group ${USER_GID} >/dev/null 2>&1; then \
echo "Group ${USER_GID} already exists"; \
else \
groupadd -g ${USER_GID} ${USERNAME}; \
fi; \
if id -u ${USER_UID} >/dev/null 2>&1; then \
echo "User ${USER_UID} already exists, using existing user"; \
EXISTING_USER=$(id -un ${USER_UID}); \
mkdir -p /app /home/${EXISTING_USER}/.cache; \
chown -R ${USER_UID}:${USER_GID} /app /home/${EXISTING_USER}/.cache; \
else \
useradd -m -u ${USER_UID} -g ${USER_GID} ${USERNAME}; \
mkdir -p /app /home/${USERNAME}/.cache; \
chown -R ${USERNAME}:${USERNAME} /app /home/${USERNAME}/.cache; \
fi
# Create references directory with proper permissions for the non-root user
RUN mkdir -p /app/references \
&& chown -R ${USER_UID}:${USER_GID} /app/references \
&& chmod 755 /app/references
# Set working directory
WORKDIR /app
# Copy dependency files first for better caching
COPY --chown=${USER_UID}:${USER_GID} pyproject.toml uv.lock README.md ./
# Switch to non-root user for package installation
USER ${USER_UID}:${USER_GID}
# Install Python dependencies (cacheable by lockfiles).
# uv only uses the username-independent cache mount if UV_CACHE_DIR points at
# it; the default is ~/.cache/uv. UV_LINK_MODE=copy avoids failed hardlink
# attempts between the cache mount and the project .venv, which sit on
# different filesystems.
RUN --mount=type=cache,target=/tmp/uv-cache,uid=${USER_UID},gid=${USER_GID} \
export UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy \
&& uv python pin ${PY_VER} \
&& uv sync --extra ${UV_EXTRA} --frozen --no-install-project
# Copy application code
COPY --chown=${USER_UID}:${USER_GID} . .
# Install the local package after copying source code. This step uses the same
# cache mount, so nothing is written to a cache inside the image layer.
RUN --mount=type=cache,target=/tmp/uv-cache,uid=${USER_UID},gid=${USER_GID} \
export UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy \
&& uv sync --extra ${UV_EXTRA} --frozen
# Create the shared helper library /app/common.sh, which the webui and server
# entrypoint scripts source. USER_UID/USER_GID are baked in at build time
# (expanded outside the single quotes).
RUN printf '%s\n' \
'#!/bin/bash' \
'set -euo pipefail' \
'' \
'# Set user info from build args' \
'USER_UID='${USER_UID} \
'USER_GID='${USER_GID} \
'' \
'# Logging function' \
'log() { echo "[$(date +"%Y-%m-%d %H:%M:%S")] $*" >&2; }' \
'' \
'# Validate environment' \
'validate_env() {' \
' if [ ! -d "/app/checkpoints" ]; then' \
' log "WARNING: /app/checkpoints directory not found. Please mount your checkpoints."' \
' fi' \
' if [ ! -d "/app/references" ]; then' \
' log "WARNING: /app/references directory not found. Please mount your references."' \
' else' \
' # Check if we can write to references directory' \
' if [ ! -w "/app/references" ]; then' \
' log "ERROR: Cannot write to /app/references directory. Please ensure the mounted directory has proper permissions for user with UID ${USER_UID}."' \
' log "You can fix this by running: sudo chown -R ${USER_UID}:${USER_GID} /path/to/your/references"' \
' exit 1' \
' fi' \
' fi' \
'}' \
'' \
'# Build device arguments' \
'build_device_args() {' \
' if [ "${BACKEND:-}" = "cpu" ]; then' \
' echo "--device cpu"' \
' fi' \
'}' \
'' \
'# Build compile arguments' \
'# Emit "--compile" when the first argument is "compile" (which is consumed)' \
'# or when COMPILE is 1/true; every other argument is passed through unchanged.' \
'build_compile_args() {' \
' local want_compile=0' \
' if [ "${1:-}" = "compile" ]; then' \
' want_compile=1' \
' shift' \
' fi' \
' if [ "${COMPILE:-}" = "1" ] || [ "${COMPILE:-}" = "true" ]; then' \
' want_compile=1' \
' fi' \
' if [ "${want_compile}" = "1" ]; then' \
' echo "--compile"' \
' fi' \
' echo "$@"' \
'}' \
'' \
'# Health check function' \
'health_check() {' \
' local port=${1:-7860}' \
' local endpoint=${2:-/health}' \
' curl -f http://localhost:${port}${endpoint} 2>/dev/null || exit 1' \
'}' \
> /app/common.sh && chmod +x /app/common.sh
##############################################################
# App stages
##############################################################
# Gradio WebUI
FROM app-base AS webui
# Repeats values already set in app-base (harmless).
ENV PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
# Build-time defaults. They are exported as ENV below and read by the entry
# script at run time, so `docker run -e NAME=...` can override them.
ARG GRADIO_SERVER_NAME="0.0.0.0"
ARG GRADIO_SERVER_PORT=7860
ARG LLAMA_CHECKPOINT_PATH="checkpoints/s2-pro"
ARG DECODER_CHECKPOINT_PATH="checkpoints/s2-pro/codec.pth"
ARG DECODER_CONFIG_NAME="modded_dac_vq"
# Expose port
EXPOSE ${GRADIO_SERVER_PORT}
# Set environment variables
ENV GRADIO_SERVER_NAME=${GRADIO_SERVER_NAME}
ENV GRADIO_SERVER_PORT=${GRADIO_SERVER_PORT}
ENV LLAMA_CHECKPOINT_PATH=${LLAMA_CHECKPOINT_PATH}
ENV DECODER_CHECKPOINT_PATH=${DECODER_CHECKPOINT_PATH}
ENV DECODER_CONFIG_NAME=${DECODER_CONFIG_NAME}
# Create the webui entrypoint. It sources /app/common.sh, validates the mounted
# directories, then execs tools/run_webui.py with the checkpoint settings from
# ENV plus optional device/compile flags. Arguments given to `docker run` are
# forwarded to build_compile_args.
RUN printf '%s\n' \
'#!/bin/bash' \
'source /app/common.sh' \
'' \
'log "Starting Fish Speech WebUI..."' \
'validate_env' \
'' \
'DEVICE_ARGS=$(build_device_args)' \
'COMPILE_ARGS=$(build_compile_args "$@")' \
'' \
'log "Device args: ${DEVICE_ARGS:-none}"' \
'log "Compile args: ${COMPILE_ARGS}"' \
'log "Server: ${GRADIO_SERVER_NAME}:${GRADIO_SERVER_PORT}"' \
'' \
'exec uv run tools/run_webui.py \' \
' --llama-checkpoint-path "${LLAMA_CHECKPOINT_PATH}" \' \
' --decoder-checkpoint-path "${DECODER_CHECKPOINT_PATH}" \' \
' --decoder-config-name "${DECODER_CONFIG_NAME}" \' \
' ${DEVICE_ARGS} ${COMPILE_ARGS}' \
> /app/start_webui.sh && chmod +x /app/start_webui.sh
# Health check. Shell form, so ${GRADIO_SERVER_PORT} is read from ENV at run time.
# NOTE(review): confirm that the WebUI process actually serves /health.
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
CMD curl -f http://localhost:${GRADIO_SERVER_PORT}/health || exit 1
ENTRYPOINT ["/app/start_webui.sh"]
# API Server
FROM app-base AS server
# Repeats values already set in app-base (harmless).
ENV PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
# Build-time defaults. They are exported as ENV below and read by the entry
# script at run time, so `docker run -e NAME=...` can override them.
ARG API_SERVER_NAME="0.0.0.0"
ARG API_SERVER_PORT=8080
ARG LLAMA_CHECKPOINT_PATH="checkpoints/s2-pro"
ARG DECODER_CHECKPOINT_PATH="checkpoints/s2-pro/codec.pth"
ARG DECODER_CONFIG_NAME="modded_dac_vq"
# Expose port
EXPOSE ${API_SERVER_PORT}
# Set environment variables
ENV API_SERVER_NAME=${API_SERVER_NAME}
ENV API_SERVER_PORT=${API_SERVER_PORT}
ENV LLAMA_CHECKPOINT_PATH=${LLAMA_CHECKPOINT_PATH}
ENV DECODER_CHECKPOINT_PATH=${DECODER_CHECKPOINT_PATH}
ENV DECODER_CONFIG_NAME=${DECODER_CONFIG_NAME}
# Create the server entrypoint. It sources /app/common.sh, validates the mounted
# directories, then execs tools/api_server.py listening on
# API_SERVER_NAME:API_SERVER_PORT. Arguments given to `docker run` are
# forwarded to build_compile_args.
RUN printf '%s\n' \
'#!/bin/bash' \
'source /app/common.sh' \
'' \
'log "Starting Fish Speech API Server..."' \
'validate_env' \
'' \
'DEVICE_ARGS=$(build_device_args)' \
'COMPILE_ARGS=$(build_compile_args "$@")' \
'' \
'log "Device args: ${DEVICE_ARGS:-none}"' \
'log "Compile args: ${COMPILE_ARGS}"' \
'log "Server: ${API_SERVER_NAME}:${API_SERVER_PORT}"' \
'' \
'exec uv run tools/api_server.py \' \
' --listen "${API_SERVER_NAME}:${API_SERVER_PORT}" \' \
' --llama-checkpoint-path "${LLAMA_CHECKPOINT_PATH}" \' \
' --decoder-checkpoint-path "${DECODER_CHECKPOINT_PATH}" \' \
' --decoder-config-name "${DECODER_CONFIG_NAME}" \' \
' ${DEVICE_ARGS} ${COMPILE_ARGS}' \
> /app/start_server.sh && chmod +x /app/start_server.sh
# Health check against the API's /v1/health endpoint (port read from ENV at run time).
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
CMD curl -f http://localhost:${API_SERVER_PORT}/v1/health || exit 1
ENTRYPOINT ["/app/start_server.sh"]
# Development stage
FROM app-base AS dev
# ARG scope is per stage, and ARGs are not inherited from app-base. Without
# redeclaring these, `USER ${USER_UID}:${USER_GID}` would expand to `USER :`,
# and UV_EXTRA would be empty for the cuda backend.
ARG USER_UID
ARG USER_GID
ARG UV_EXTRA
USER root
# Install development tools. The apt directories are cache mounts (not part of
# the image layer), so they are not cleaned.
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update \
&& apt-get install -y --no-install-recommends \
vim \
htop \
strace \
gdb
USER ${USER_UID}:${USER_GID}
# Install development dependencies strictly from uv.lock, as app-base does
RUN uv sync --extra ${UV_EXTRA} --frozen --dev
# Default to bash for development
ENTRYPOINT ["/bin/bash"]
================================================
FILE: dockerfile.dev
================================================
# Development image layered on a published fish-speech image.
ARG VERSION=dev
ARG BASE_IMAGE=ghcr.io/fishaudio/fish-speech:${VERSION}
FROM ${BASE_IMAGE}
# Extra system packages for an interactive development environment.
ARG TOOLS=" \
git \
curl \
build-essential \
ffmpeg \
libsm6 \
libxext6 \
libjpeg-dev \
zlib1g-dev \
aria2 \
zsh \
openssh-server \
sudo \
protobuf-compiler \
libasound-dev \
portaudio19-dev \
libportaudio2 \
libportaudiocpp0 \
cmake"
# NOTE(review): apt-get and chsh need root; confirm that BASE_IMAGE's default USER is root.
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
set -ex \
&& apt-get update \
&& apt-get -y install --no-install-recommends ${TOOLS}
# Install oh-my-zsh so your terminal looks nice.
# Download the installer into a variable first, so a failed download (curl -f
# also fails on HTTP errors) aborts the build instead of silently running
# `sh -c ""`.
RUN set -e; \
installer="$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)"; \
sh -c "$installer" "" --unattended
# Set zsh as default shell
RUN chsh -s /usr/bin/zsh
ENV SHELL=/usr/bin/zsh
================================================
FILE: docs/CNAME
================================================
speech.fish.audio
================================================
FILE: docs/README.ar.md
================================================
Fish Speech
[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | **العربية**
> [!IMPORTANT]
> **إشعار الترخيص**
> يتم إصدار قاعدة الأكواد هذه وأوزان النماذج المرتبطة بها تحت **[FISH AUDIO RESEARCH LICENSE](../LICENSE)**. يرجى الرجوع إلى ملف [LICENSE](../LICENSE) لمزيد من التفاصيل.
> [!WARNING]
> **إخلاء المسؤولية القانونية**
> نحن لا نتحمل أي مسؤولية عن أي استخدام غير قانوني لقاعدة الأكواد. يرجى الرجوع إلى القوانين المحلية المتعلقة بـ DMCA والقوانين الأخرى ذات الصلة.
## البداية السريعة
### روابط التوثيق
هذا هو التوثيق الرسمي لـ Fish Audio S2، يرجى اتباع التعليمات للبدء بسهولة.
- [التثبيت](https://speech.fish.audio/ar/install/)
- [الاستدلال عبر خط الأوامر](https://speech.fish.audio/ar/inference/)
- [الاستدلال عبر واجهة الويب](https://speech.fish.audio/ar/inference/)
- [استدلال الخادم](https://speech.fish.audio/ar/server/)
- [نشر Docker](https://speech.fish.audio/ar/install/)
> [!IMPORTANT]
> **إذا كنت ترغب في استخدام خادم SGLang، فيرجى الرجوع إلى [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md).**
### دليل وكيل LLM
```
يرجى قراءة https://speech.fish.audio/ar/install/ أولاً، وتثبيت وتكوين Fish Audio S2 وفقاً للوثائق.
```
## Fish Audio S2 Pro
**نظام تحويل النص إلى كلام (TTS) متعدد اللغات الرائد في الصناعة، والذي يعيد تعريف حدود توليد الصوت.**
Fish Audio S2 Pro هو أحدث طراز متعدد الوسائط تم تطويره بواسطة [Fish Audio](https://fish.audio/). تم تدريبه على أكثر من **10 ملايين ساعة** من البيانات الصوتية الهائلة، التي تغطي أكثر من **80 لغة** حول العالم. من خلال بنية **ثنائية الانحدار الذاتي (Dual-AR)** المبتكرة وتقنية توافق التعلم التعزيزي (RL)، يمكن لـ S2 Pro توليد كلام يتمتع بإحساس طبيعي وواقعي وعمق عاطفي كبير، مما يجعله رائداً في المنافسة بين الأنظمة المفتوحة والمغلقة المصدر.
تكمن القوة الضاربة لـ S2 Pro في دعمه للتحكم الدقيق للغاية في النبرة والعاطفة على مستوى **ما دون الكلمة (Sub-word Level)** من خلال وسوم اللغة الطبيعية (مثل `[whisper]` و `[excited]` و `[angry]`) ، مع دعم أصلي لتوليد متحدثين متعددين وحوارات متعددة الجولات بسياق طويل جداً.
تفضل بزيارة [موقع Fish Audio الرسمي](https://fish.audio/) الآن لتجربة العرض المباشر، أو اقرأ [تقريرنا الفني](https://arxiv.org/abs/2603.08823) و[مقال المدونة](https://fish.audio/blog/fish-audio-open-sources-s2/) للتعرف على المزيد.
### متغيرات النموذج
| النموذج | الحجم | التوفر | الوصف |
|------|------|-------------|-------------|
| S2-Pro | 4 مليار معلمة | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | النموذج الرائد كامل الميزات، مع أعلى جودة واستقرار |
لمزيد من التفاصيل حول النموذج، يرجى مراجعة [التقرير الفني لـ Fish Audio S2](https://arxiv.org/abs/2603.08823).
## نتائج الاختبارات المرجعية (Benchmarks)
| الاختبار | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER (الصينية) | **0.54%** (الأفضل إجمالاً) |
| Seed-TTS Eval — WER (الإنجليزية) | **0.99%** (الأفضل إجمالاً) |
| Audio Turing Test (مع التعليمات) | **0.515** متوسط خلفي (Posterior mean) |
| EmergentTTS-Eval — معدل الفوز | **81.88%** (الأعلى إجمالاً) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — الجودة | **4.51 / 5.0** |
| متعدد اللغات (MiniMax Testset) — أفضل WER | **11** لغة من أصل **24** |
| متعدد اللغات (MiniMax Testset) — أفضل SIM | **17** لغة من أصل **24** |
في تقييم Seed-TTS، حقق S2 أقل معدل خطأ في الكلمات (WER) بين جميع النماذج التي تم تقييمها (بما في ذلك الأنظمة مغلقة المصدر): Qwen3-TTS (0.77/1.24)، و MiniMax Speech-02 (0.99/1.90)، و Seed-TTS (1.12/2.25). وفي اختبار Audio Turing Test، سجل S2 قيمة 0.515 بزيادة قدرها 24% مقارنة بـ Seed-TTS (0.417) و 33% مقارنة بـ MiniMax-Speech (0.387). وفي EmergentTTS-Eval، تميز S2 بشكل خاص في أبعاد مثل اللغويات المصاحبة (معدل فوز 91.61%)، والجمل الاستفهامية (84.41%)، والتعقيد النحوي (83.39%).
## أبرز المميزات
### تحكم دقيق للغاية عبر اللغة الطبيعية
يمنح S2 Pro الصوت "روحاً" لا مثيل لها. من خلال صيغة `[tag]` البسيطة، يمكنك تضمين تعليمات عاطفية بدقة في أي موضع من النص.
- **دعم أكثر من 15,000 وسم فريد**: لا يقتصر على الإعدادات المسبقة الثابتة، بل يدعم **أوصاف النص الحر**. يمكنك تجربة `[whisper in small voice]` (همس بصوت منخفض)، أو `[professional broadcast tone]` (نبرة إذاعية احترافية)، أو `[pitch up]` (رفع طبقة الصوت).
- **مكتبة عواطف غنية**:
`[pause]` `[emphasis]` `[laughing]` `[inhale]` `[chuckle]` `[tsk]` `[singing]` `[excited]` `[laughing tone]` `[interrupting]` `[chuckling]` `[excited tone]` `[volume up]` `[echo]` `[angry]` `[low volume]` `[sigh]` `[low voice]` `[whisper]` `[screaming]` `[shouting]` `[loud]` `[surprised]` `[short pause]` `[exhale]` `[delight]` `[panting]` `[audience laughter]` `[with strong accent]` `[volume down]` `[clearing throat]` `[sad]` `[moaning]` `[shocked]`
### بنية مبتكرة ثنائية الانحدار الذاتي (Dual-Autoregressive)
يعتمد S2 Pro بنية Dual-AR بنظام "رئيسي-تابع"، تتكون من Decoder-only Transformer وترميز صوتي RVQ (10 قواميس أكواد، بمعدل إطارات يبلغ حوالي 21 هرتز):
- **Slow AR (4 مليار معلمة)**: يعمل على طول المحور الزمني، ويتنبأ بقاموس الأكواد الدلالي الأساسي.
- **Fast AR (400 مليون معلمة)**: يولد الـ 9 قواميس المتبقية في كل خطوة زمنية، لاستعادة أدق التفاصيل الصوتية ببراعة.
يحقق هذا التصغير غير المتماثل أقصى درجات الدقة الصوتية مع زيادة سرعة الاستدلال بشكل كبير.
### توافق التعلم التعزيزي (RL Alignment)
يستخدم S2 Pro تقنية **Group Relative Policy Optimization (GRPO)** للتوافق بعد التدريب. نستخدم نفس مجموعة النماذج المستخدمة في تنظيف البيانات وتصنيفها مباشرة كنماذج مكافأة (Reward Model)، مما يحل بشكل مثالي مشكلة عدم التطابق بين توزيع بيانات ما قبل التدريب وأهداف ما بعد التدريب.
- **إشارات مكافأة متعددة الأبعاد**: تقييم شامل للدقة الدلالية، والقدرة على اتباع التعليمات، وتسجيل التفضيل الصوتي، وتماثل نبرة الصوت، لضمان أن كل ثانية من الكلام المولد تتوافق مع الحدس البشري.
### أداء استدلال تدفقي فائق (يعتمد على SGLang)
نظراً لأن بنية Dual-AR تتماثل هيكلياً مع بنية LLM القياسية، فإن S2 Pro يدعم أصلاً جميع ميزات تسريع الاستدلال في SGLang، بما في ذلك الدفعات المستمرة (Continuous Batching)، و Paged KV Cache، و CUDA Graph، والتخزين المؤقت للبادئة القائم على RadixAttention.
**أداء وحدة معالجة رسومات NVIDIA H200 واحدة:**
- **عامل الوقت الحقيقي (RTF)**: 0.195
- **تأخر الصوت الأول (TTFA)**: حوالي 100 مللي ثانية
- **إنتاجية فائقة السرعة**: تصل إلى أكثر من 3000 رمز صوتي/ثانية مع الحفاظ على RTF < 0.5
### دعم قوي للغات المتعددة
يدعم S2 Pro أكثر من 80 لغة، مما يتيح تركيباً عالي الجودة دون الحاجة إلى وحدات صوتية (phonemes) أو معالجة خاصة بكل لغة:
- **المستوى الأول (Tier 1)**: اليابانية (ja)، الإنجليزية (en)، الصينية (zh)
- **المستوى الثاني (Tier 2)**: الكورية (ko)، الإسبانية (es)، البرتغالية (pt)، العربية (ar)، الروسية (ru)، الفرنسية (fr)، الألمانية (de)
- **تغطية عالمية**: sv, it, tr, no, nl, cy, eu, ca, da, gl, ta, hu, fi, pl, et, hi, la, ur, th, vi, jw, bn, yo, xsl, cs, sw, nn, he, ms, uk, id, kk, bg, lv, my, tl, sk, ne, fa, af, el, bo, hr, ro, sn, mi, yi, am, be, km, is, az, sd, br, sq, ps, mn, ht, ml, sr, sa, te, ka, bs, pa, lt, kn, si, hy, mr, as, gu, fo والمزيد.
### توليد متحدثين متعددين أصلي
يسمح Fish Audio S2 للمستخدمين بتحميل عينة مرجعية تحتوي على متحدثين متعددين، وسيقوم النموذج بمعالجة ميزات كل متحدث عبر وسم `<|speaker:i|>`. بعد ذلك، يمكنك التحكم في أداء النموذج عبر وسم معرف المتحدث، مما يتيح لتوليد واحد أن يتضمن متحدثين متعددين. لم تعد هناك حاجة لتحميل عينة مرجعية منفصلة وتوليد صوت لكل متحدث على حدة كما كان في السابق.
### توليد حوارات متعددة الجولات
بفضل توسيع سياق النموذج، يمكن لنموذجنا الآن الاستفادة من المعلومات السابقة لتحسين التعبير في المحتوى المولد لاحقاً، مما يعزز من طبيعية المحتوى.
### استنساخ الصوت السريع
يدعم Fish Audio S2 استنساخاً دقيقاً للصوت باستخدام عينات مرجعية قصيرة (عادةً 10-30 ثانية). يلتقط النموذج نبرة الصوت وأسلوب الكلام والميول العاطفية، مما يولد أصواتاً مستنسخة واقعية ومتسقة دون الحاجة إلى ضبط دقيق إضافي.
لاستخدام خادم SGLang، يرجى الرجوع إلى [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md).
---
## شكر وتقدير
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## التقرير الفني
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/README.ja.md
================================================
Fish Speech
[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | **日本語** | [한국어](README.ko.md) | [العربية](README.ar.md)
> [!IMPORTANT]
> **ライセンス注意事項**
> このコードベースおよび関連するモデルウェイトは **[FISH AUDIO RESEARCH LICENSE](../LICENSE)** の下でリリースされています。詳細については [LICENSE](../LICENSE) をご参照ください。
> [!WARNING]
> **法的免責事項**
> 私たちはコードベースの不法な使用について一切の責任を負いません。DMCA 及びその他の関連法律について、現地の法律をご参照ください。
## クイックスタート
### ドキュメント入口
Fish Audio S2 の公式ドキュメントです。以下からすぐに始められます。
- [インストール](https://speech.fish.audio/ja/install/)
- [コマンドライン推論](https://speech.fish.audio/ja/inference/)
- [WebUI 推論](https://speech.fish.audio/ja/inference/)
- [サーバー推論](https://speech.fish.audio/ja/server/)
- [Docker デプロイ](https://speech.fish.audio/ja/install/)
> [!IMPORTANT]
> **SGLang サーバーについては [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md) を参照してください。**
### LLM Agent 指南
```
https://speech.fish.audio/ja/install/ の手順に従って、Fish Audio S2 をインストール・設定してください。
```
## Fish Audio S2 Pro
**業界最先端の多言語テキスト読み上げ (TTS) システム。音声生成の限界を再定義します。**
Fish Audio S2 Pro は [Fish Audio](https://fish.audio/) が開発した最高峰のマルチモーダルモデルです。世界 **80 言語以上**、**1,000 万時間** を超える膨大な音声データで学習されています。革新的な **二重自己回帰 (Dual-AR)** アーキテクチャと強化学習 (RL) アライメント技術を組み合わせることで、極めて自然でリアル、かつ感情豊かな音声を生成し、オープンソースおよびクローズドソースの双方でリーダーシップを発揮しています。
S2 Pro の最大の特徴は、自然言語タグ(例:`[whisper]`、`[excited]`、`[angry]`)による韻律や感情の **サブワードレベル (Sub-word Level)** での極めて細やかなインライン制御が可能である点です。また、マルチスピーカー生成や長文コンテキストのマルチターン対話生成にもネイティブ対応しています。
今すぐ [Fish Audio 公式サイト](https://fish.audio/) でプレイグラウンドを体験するか、[技術レポート](https://arxiv.org/abs/2603.08823) や [ブログ記事](https://fish.audio/blog/fish-audio-open-sources-s2/) を読んで詳細を確認してください。
### モデルバリアント
| モデル | サイズ | 利用可能性 | 説明 |
|------|------|-------------|-------------|
| S2-Pro | 4B パラメータ | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | 品質と安定性を最大化した、フル機能のフラッグシップモデル |
モデルの詳細は[技術レポート](https://arxiv.org/abs/2411.01156)をご参照ください。
## ベンチマーク結果
| ベンチマーク | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER(中国語) | **0.54%**(全体最良) |
| Seed-TTS Eval — WER(英語) | **0.99%**(全体最良) |
| Audio Turing Test(指示あり) | **0.515** 事後平均値 |
| EmergentTTS-Eval — 勝率 | **81.88%**(全体最高) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — 品質 | **4.51 / 5.0** |
| 多言語(MiniMax Testset)— 最良 WER | **24 言語中 11 言語** |
| 多言語(MiniMax Testset)— 最良 SIM | **24 言語中 17 言語** |
Seed-TTS Eval では、S2 はクローズドソースを含む全評価モデルの中で最小 WER を達成しました:Qwen3-TTS(0.77/1.24)、MiniMax Speech-02(0.99/1.90)、Seed-TTS(1.12/2.25)。Audio Turing Test では 0.515 を記録し、Seed-TTS(0.417)比で 24%、MiniMax-Speech(0.387)比で 33% 上回りました。EmergentTTS-Eval では、副言語情報(91.61%)、疑問文(84.41%)、統語的複雑性(83.39%)で特に高い成績を示しています。
## ハイライト
### 自然言語による細粒度インライン制御
S2 Pro は音声にこれまでにない「魂」を宿らせます。シンプルな `[tag]` 構文を使用して、テキスト内の任意の場所に感情の指示を正確に埋め込むことができます。
- **1万5,000以上のユニークタグに対応**:固定のプリセットに限定されず、**自由形式のテキスト記述** をサポートします。`[whisper in small voice]` (ささやき声で), `[professional broadcast tone]` (プロのナレーション風), `[pitch up]` (ピッチを上げる) などを試してみてください。
- **豊富な感情ライブラリ**:
`[pause]` `[emphasis]` `[laughing]` `[inhale]` `[chuckle]` `[tsk]` `[singing]` `[excited]` `[laughing tone]` `[interrupting]` `[chuckling]` `[excited tone]` `[volume up]` `[echo]` `[angry]` `[low volume]` `[sigh]` `[low voice]` `[whisper]` `[screaming]` `[shouting]` `[loud]` `[surprised]` `[short pause]` `[exhale]` `[delight]` `[panting]` `[audience laughter]` `[with strong accent]` `[volume down]` `[clearing throat]` `[sad]` `[moaning]` `[shocked]`
### 革新的な二重自己回帰 (Dual-Autoregressive) アーキテクチャ
S2 Pro は、Decoder-only Transformer と RVQ オーディオコーデック(10 コードブック、約 21 Hz)で構成されるマスター・スレーブ型の Dual-AR アーキテクチャを採用しています:
- **Slow AR (4B パラメータ)**: 時間軸方向に動作し、核となるセマンティックコードブックを予測。
- **Fast AR (400M パラメータ)**: 各時間ステップで残り 9 個の残差コードブックを生成し、極めて繊細な音響ディテールを復元。
この非対称設計により、究極のオーディオ忠実度を維持しながら、推論速度を大幅に向上させています。
### 強化学習 (RL) アライメント
S2 Pro は、事後学習アライメントに **Group Relative Policy Optimization (GRPO)** 技術を採用しています。データのクリーニングとアノテーションに使用したモデルセットをそのまま報酬モデル (Reward Model) として使用することで、事前学習データの分布と事後学習の目標との間のミスマッチを完璧に解決しました。
- **多次元の報酬信号**: 意味の正確性、指示追従性、音響的な好み、音色の類似性を総合的に評価し、生成される一秒一秒の音声が人間の直感に沿うようにしています。
### SGLang による究極のストリーミング推論性能
Dual-AR アーキテクチャは標準的な LLM 構造と同型であるため、S2 Pro は SGLang のすべての推論加速機能をネイティブにサポートしています。これには、Continuous Batching、Paged KV Cache、CUDA Graph、RadixAttention ベースの Prefix Caching が含まれます。
**NVIDIA H200 GPU 1枚でのパフォーマンス:**
- **リアルタイム係数 (RTF)**: 0.195
- **初回音声出力までの時間 (TTFA)**: 約 100 ms
- **超高速スループット**: RTF < 0.5 を維持しつつ 3,000+ acoustic tokens/s
### 強力な多言語サポート
S2 Pro は 80 以上の言語をサポートしており、音素や特定の言語に対する前処理なしで高品質な合成を実現します:
- **第1層 (Tier 1)**: 日本語 (ja), 英語 (en), 中国語 (zh)
- **第2層 (Tier 2)**: 韓国語 (ko), スペイン語 (es), ポルトガル語 (pt), アラビア語 (ar), ロシア語 (ru), フランス語 (fr), ドイツ語 (de)
- **グローバルカバレッジ**: sv, it, tr, no, nl, cy, eu, ca, da, gl, ta, hu, fi, pl, et, hi, la, ur, th, vi, jw, bn, yo, xsl, cs, sw, nn, he, ms, uk, id, kk, bg, lv, my, tl, sk, ne, fa, af, el, bo, hr, ro, sn, mi, yi, am, be, km, is, az, sd, br, sq, ps, mn, ht, ml, sr, sa, te, ka, bs, pa, lt, kn, si, hy, mr, as, gu, fo など。
### ネイティブなマルチスピーカー生成
Fish Audio S2 では、複数のスピーカーを含む参照オーディオをアップロードでき、モデルは `<|speaker:i|>` トークンを介して各スピーカーの特徴を処理します。スピーカー ID トークンを使用してモデルの出力を制御することで、1回の生成に複数のスピーカーを混在させることが可能です。個別のスピーカーごとに参照オーディオをアップロードし直す手間はもう不要です。
### マルチターン対話生成
コンテキストの拡張により、以前のターンの情報を利用して後続の生成内容の表現力を高めることができ、対話としての自然さが大幅に向上しました。
### 高速音声クローニング
Fish Audio S2 は、短い参照サンプル(通常 10〜30 秒)を使用した正確な音声クローニングをサポートしています。モデルは音色、話し方、感情を捉え、追加の微調整なしでリアルで一貫したクローン音声を生成します。
SGLang サーバーの利用については、[SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md) を参照してください。
---
## 謝辞
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## 技術レポート
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/README.ko.md
================================================
Fish Speech
[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | **한국어** | [العربية](README.ar.md)
> [!IMPORTANT]
> **라이선스 고지**
> 이 코드베이스 및 관련 모델 가중치는 **[FISH AUDIO RESEARCH LICENSE](../LICENSE)** 에 따라 배포됩니다. 자세한 내용은 [LICENSE](../LICENSE)를 참조하십시오.
> [!WARNING]
> **법적 면책 조항**
> 당사는 코드베이스의 불법적인 사용에 대해 어떠한 책임도 지지 않습니다. 해당 지역의 DMCA 및 기타 관련 법률을 참조하십시오.
## 빠른 시작
### 문서 바로가기
Fish Audio S2의 공식 문서입니다. 지침에 따라 쉽게 시작하십시오.
- [설치](https://speech.fish.audio/ko/install/)
- [명령줄 추론](https://speech.fish.audio/ko/inference/)
- [WebUI 추론](https://speech.fish.audio/ko/inference/)
- [서버 추론](https://speech.fish.audio/ko/server/)
- [Docker 배포](https://speech.fish.audio/ko/install/)
> [!IMPORTANT]
> **SGLang 서버를 사용하려면 [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md)를 참조하십시오.**
### LLM Agent 가이드
```
먼저 https://speech.fish.audio/ko/install/ 을 읽고 문서에 따라 Fish Audio S2를 설치 및 구성하십시오.
```
## Fish Audio S2 Pro
**음성 생성의 경계를 재정의하는 업계 최고의 다국어 텍스트 음성 변환(TTS) 시스템.**
Fish Audio S2 Pro는 [Fish Audio](https://fish.audio/)에서 개발한 최첨단 멀티모달 모델입니다. 전 세계 **80개 이상의 언어**를 아우르는 **1,000만 시간** 이상의 방대한 오디오 데이터로 학습되었습니다. 혁신적인 **이중 자기회귀(Dual-AR)** 아키텍처와 강화 학습(RL) 정렬 기술을 통해 S2 Pro는 극도로 자연스럽고 사실적이며 감정이 풍부한 음성을 생성하며, 오픈 소스와 클로즈드 소스 경쟁 모두에서 선두를 달리고 있습니다.
S2 Pro의 핵심 강점은 자연어 태그(예: `[whisper]`, `[excited]`, `[angry]`)를 통해 운율과 감정을 **하위 단어 수준(Sub-word Level)**에서 매우 세밀하게 인라인 제어할 수 있다는 점입니다. 또한 다중 화자 생성 및 긴 컨텍스트의 다중 턴 대화 생성을 기본적으로 지원합니다.
지금 바로 [Fish Audio 공식 웹사이트](https://fish.audio/)에서 온라인 데모를 체험하거나, [기술 보고서](https://arxiv.org/abs/2603.08823) 및 [블로그 게시물](https://fish.audio/blog/fish-audio-open-sources-s2/)을 통해 자세히 알아보십시오.
### 모델 변형
| 모델 | 크기 | 가용성 | 설명 |
|------|------|-------------|-------------|
| S2-Pro | 4B 파라미터 | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | 최고의 품질과 안정성을 갖춘 모든 기능을 갖춘 플래그십 모델 |
모델에 대한 자세한 내용은 [기술 보고서](https://arxiv.org/abs/2411.01156)를 참조하십시오.
## 벤치마크 결과
| 벤치마크 | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER(중국어) | **0.54%** (전체 최고) |
| Seed-TTS Eval — WER(영어) | **0.99%** (전체 최고) |
| Audio Turing Test (지침 포함) | **0.515** 사후 평균 |
| EmergentTTS-Eval — 승률 | **81.88%** (전체 최고) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — 품질 | **4.51 / 5.0** |
| 다국어 (MiniMax Testset) — 최고 WER | **24개 언어 중 11개** |
| 다국어 (MiniMax Testset) — 최고 SIM | **24개 언어 중 17개** |
Seed-TTS Eval에서 S2는 클로즈드 소스 시스템을 포함한 모든 평가 모델 중 가장 낮은 WER을 달성했습니다: Qwen3-TTS (0.77/1.24), MiniMax Speech-02 (0.99/1.90), Seed-TTS (1.12/2.25). Audio Turing Test에서 S2의 0.515는 Seed-TTS (0.417) 대비 24%, MiniMax-Speech (0.387) 대비 33% 향상된 수치입니다. EmergentTTS-Eval에서 S2는 부차 언어학(91.61% 승률), 의문문(84.41%), 구문 복잡성(83.39%) 등의 측면에서 특히 두드러진 성과를 보였습니다.
## 하이라이트
### 자연어를 통한 초미세 인라인 제어
S2 Pro는 음성에 전례 없는 "영혼"을 부여합니다. 간단한 `[tag]` 구문을 사용하여 텍스트의 어느 위치에나 감정 지침을 정확하게 삽입할 수 있습니다.
- **15,000개 이상의 고유 태그 지원**: 고정된 사전 설정에 국한되지 않고 **자유 형식의 텍스트 설명**을 지원합니다. `[whisper in small voice]` (작은 목소리로 속삭임), `[professional broadcast tone]` (전문 방송 톤), `[pitch up]` (음높이 높임) 등을 시도해 보십시오.
- **풍부한 감정 라이브러리**:
`[pause]` `[emphasis]` `[laughing]` `[inhale]` `[chuckle]` `[tsk]` `[singing]` `[excited]` `[laughing tone]` `[interrupting]` `[chuckling]` `[excited tone]` `[volume up]` `[echo]` `[angry]` `[low volume]` `[sigh]` `[low voice]` `[whisper]` `[screaming]` `[shouting]` `[loud]` `[surprised]` `[short pause]` `[exhale]` `[delight]` `[panting]` `[audience laughter]` `[with strong accent]` `[volume down]` `[clearing throat]` `[sad]` `[moaning]` `[shocked]`
### 혁신적인 이중 자기회귀 (Dual-Autoregressive) 아키텍처
S2 Pro는 Decoder-only Transformer와 RVQ 오디오 코덱(10개 코드북, 약 21Hz 프레임 속도)으로 구성된 마스터-슬레이브 방식의 Dual-AR 아키텍처를 채택했습니다.
- **Slow AR (4B 파라미터)**: 시간 축을 따라 작동하며 핵심 의미 코드북을 예측합니다.
- **Fast AR (400M 파라미터)**: 각 타임스텝에서 나머지 9개의 잔차 코드북을 생성하여 극도로 정교한 음향 세부 사항을 복원합니다.
이러한 비대칭 설계는 오디오의 최고 충실도를 보장하는 동시에 추론 속도를 대폭 향상시킵니다.
### 강화 학습 (RL) 정렬
S2 Pro는 사후 학습 정렬을 위해 **Group Relative Policy Optimization (GRPO)** 기술을 채택했습니다. 데이터 정제 및 주석 처리에 사용된 것과 동일한 모델 세트를 보상 모델(Reward Model)로 직접 사용함으로써 사전 학습 데이터 분포와 사후 학습 목표 간의 불일치 문제를 완벽하게 해결했습니다.
- **다차원 보상 신호**: 의미 체계의 정확성, 지침 준수 능력, 음향 선호도 점수 및 음색 유사성을 종합적으로 평가하여 생성된 음성의 매초가 인간의 직관에 부합하도록 보장합니다.
### SGLang 기반의 극한 스트리밍 추론 성능
Dual-AR 아키텍처는 표준 LLM 구조와 동형이므로 S2 Pro는 Continuous Batching, Paged KV Cache, CUDA Graph 및 RadixAttention 기반 Prefix Caching을 포함한 SGLang의 모든 추론 가속 기능을 기본적으로 지원합니다.
**단일 NVIDIA H200 GPU 성능 지표:**
- **실시간 계수 (RTF)**: 0.195
- **첫 음성 지연 (TTFA)**: 약 100 ms
- **초고속 처리량**: RTF < 0.5 유지 시 처리량 3,000+ acoustic tokens/s 달성
### 강력한 다국어 지원
S2 Pro는 음소나 특정 언어 처리가 필요 없는 고품질 합성을 80개 이상의 언어에서 지원합니다.
- **1계층 (Tier 1)**: 일본어 (ja), 영어 (en), 중국어 (zh)
- **2계층 (Tier 2)**: 한국어 (ko), 스페인어 (es), 포르투갈어 (pt), 아랍어 (ar), 러시아어 (ru), 프랑스어 (fr), 독일어 (de)
- **글로벌 커버리지**: sv, it, tr, no, nl, cy, eu, ca, da, gl, ta, hu, fi, pl, et, hi, la, ur, th, vi, jw, bn, yo, xsl, cs, sw, nn, he, ms, uk, id, kk, bg, lv, my, tl, sk, ne, fa, af, el, bo, hr, ro, sn, mi, yi, am, be, km, is, az, sd, br, sq, ps, mn, ht, ml, sr, sa, te, ka, bs, pa, lt, kn, si, hy, mr, as, gu, fo 등.
### 네이티브 다중 화자 생성
Fish Audio S2를 사용하면 사용자가 여러 화자가 포함된 참조 오디오를 업로드할 수 있으며, 모델은 `<|speaker:i|>` 토큰을 통해 각 화자의 특징을 처리합니다. 이후 화자 ID 토큰을 사용하여 모델의 표현을 제어함으로써 한 번의 생성에 여러 화자를 포함할 수 있습니다. 더 이상 화자마다 별도의 참조 오디오를 업로드하고 음성을 생성할 필요가 없습니다.
### 다중 턴 대화 생성
모델 컨텍스트 확장에 힘입어 이제 이전 정보의 도움을 받아 후속 생성 내용의 표현력을 높이고 콘텐츠의 자연스러움을 향상시킬 수 있습니다.
### 고속 음성 복제
Fish Audio S2는 짧은 참조 샘플(보통 10-30초)을 사용한 정확한 음성 복제를 지원합니다. 모델은 음색, 말하기 스타일 및 감정적 경향을 포착하여 추가적인 미세 조정 없이도 사실적이고 일관된 복제 음성을 생성합니다.
SGLang 서버 사용에 대해서는 [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md)를 참조하십시오.
---
## 감사의 말
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## 기술 보고서
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/README.pt-BR.md
================================================
Fish Speech
[English](../README.md) | [简体中文](README.zh.md) | **Portuguese** | [日本語](README.ja.md) | [한국어](README.ko.md) | [العربية](README.ar.md)
> [!IMPORTANT]
> **Aviso de Licença**
> Este repositório de código e seus pesos de modelo associados são lançados sob a **[FISH AUDIO RESEARCH LICENSE](../LICENSE)**. Consulte [LICENSE](../LICENSE) para obter mais detalhes.
> [!WARNING]
> **Aviso Legal**
> Não nos responsabilizamos por qualquer uso ilegal deste repositório. Consulte as leis locais sobre DMCA e outras regulamentações relevantes.
## Início Rápido
### Links da Documentação
Esta é a documentação oficial do Fish Audio S2, siga as instruções para começar facilmente.
- [Instalação](https://speech.fish.audio/install/)
- [Inferência por Linha de Comando](https://speech.fish.audio/inference/)
- [Inferência por WebUI](https://speech.fish.audio/inference/)
- [Inferência por Servidor](https://speech.fish.audio/server/)
- [Implantação Docker](https://speech.fish.audio/install/)
> [!IMPORTANT]
> **Caso deseje utilizar o SGLang Server, consulte o [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md).**
### Guia para Agentes de LLM
```
Leia primeiro https://speech.fish.audio/install/ e siga a documentação para instalar e configurar o Fish Audio S2.
```
## Fish Audio S2 Pro
**O sistema de conversão de texto em fala (TTS) multilíngue líder do setor, redefinindo as fronteiras da geração de voz.**
Fish Audio S2 Pro é o modelo multimodal mais avançado desenvolvido pela [Fish Audio](https://fish.audio/). Treinado em mais de **10 milhões de horas** de dados de áudio massivos, cobrindo mais de **80 idiomas** globais. Através de uma arquitetura inovadora de **Dual-Autoregressive (Dual-AR)** e tecnologia de alinhamento por aprendizado por reforço (RL), o S2 Pro é capaz de gerar fala com um senso de naturalidade, realismo e riqueza emocional extremos, liderando tanto em competições de código aberto quanto proprietário.
O grande diferencial do S2 Pro reside em seu suporte para controle inline de granularidade ultra-fina de prosódia e emoção ao nível de **sub-palavra (Sub-word Level)** via tags de linguagem natural (como `[whisper]`, `[excited]`, `[angry]`), além de suporte nativo para múltiplos falantes e geração de diálogos de múltiplos turnos com contexto ultra-longo.
Visite agora o [site oficial da Fish Audio](https://fish.audio/) para experimentar a demonstração online, ou leia nosso [relatório técnico](https://arxiv.org/abs/2603.08823) e [artigo no blog](https://fish.audio/blog/fish-audio-open-sources-s2/) para saber mais.
### Variantes de Modelo
| Modelo | Tamanho | Disponibilidade | Descrição |
|------|------|-------------|-------------|
| S2-Pro | 4B parâmetros | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | Modelo flagship completo, com máxima qualidade e estabilidade |
Para mais detalhes sobre os modelos, consulte o [relatório técnico](https://arxiv.org/abs/2411.01156).
## Resultados de Benchmark
| Benchmark | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER (Chinês) | **0.54%** (Melhor geral) |
| Seed-TTS Eval — WER (Inglês) | **0.99%** (Melhor geral) |
| Audio Turing Test (Com instrução) | **0.515** Média posterior |
| EmergentTTS-Eval — Taxa de Vitória | **81.88%** (Maior geral) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — Qualidade | **4.51 / 5.0** |
| Multilíngue (MiniMax Testset) — Melhor WER | **11 de 24** idiomas |
| Multilíngue (MiniMax Testset) — Melhor SIM | **17 de 24** idiomas |
No Seed-TTS Eval, o S2 alcançou o menor WER entre todos os modelos avaliados (incluindo sistemas proprietários): Qwen3-TTS (0.77/1.24), MiniMax Speech-02 (0.99/1.90), Seed-TTS (1.12/2.25). No Audio Turing Test, o valor de 0.515 do S2 representa um aumento de 24% em relação ao Seed-TTS (0.417) e 33% em relação ao MiniMax-Speech (0.387). No EmergentTTS-Eval, o S2 destacou-se especialmente em dimensões como paralinguística (taxa de vitória de 91.61%), frases interrogativas (84.41%) e complexidade sintática (83.39%).
## Destaques
### Controle Inline de Granularidade Ultra-Fina via Linguagem Natural
S2 Pro confere à voz uma "espiritualidade" sem precedentes. Através de uma sintaxe simples de `[tag]`, você pode inserir instruções emocionais precisamente em qualquer posição do texto.
- **Suporte para mais de 15.000 tags únicas**: Não limitado a predefinições fixas, suporta **descrições textuais de formato livre**. Você pode tentar `[whisper in small voice]` (sussurrando), `[professional broadcast tone]` (tom de locução profissional) ou `[pitch up]` (aumentar o tom).
- **Rica biblioteca de emoções**:
`[pause]` `[emphasis]` `[laughing]` `[inhale]` `[chuckle]` `[tsk]` `[singing]` `[excited]` `[laughing tone]` `[interrupting]` `[chuckling]` `[excited tone]` `[volume up]` `[echo]` `[angry]` `[low volume]` `[sigh]` `[low voice]` `[whisper]` `[screaming]` `[shouting]` `[loud]` `[surprised]` `[short pause]` `[exhale]` `[delight]` `[panting]` `[audience laughter]` `[with strong accent]` `[volume down]` `[clearing throat]` `[sad]` `[moaning]` `[shocked]`
### Arquitetura Inovadora Dual-Autoregressive (Dual-AR)
S2 Pro adota uma arquitetura Dual-AR mestre-escravo, consistindo de um Decoder-only Transformer e um codec de áudio RVQ (10 codebooks, cerca de 21 Hz de taxa de frames):
- **Slow AR (4B parâmetros)**: Atua ao longo do eixo temporal, prevendo o codebook semântico central.
- **Fast AR (400M parâmetros)**: Gera os 9 codebooks residuais restantes em cada passo de tempo, restaurando detalhes acústicos extremos com delicadeza.
Este design assimétrico garante fidelidade extrema ao áudio enquanto aumenta significativamente a velocidade de inferência.
### Alinhamento por Aprendizado por Reforço (RL Alignment)
S2 Pro utiliza a tecnologia **Group Relative Policy Optimization (GRPO)** para o alinhamento pós-treinamento. Utilizamos o mesmo conjunto de modelos para limpeza e anotação de dados diretamente como modelos de recompensa (Reward Model), resolvendo perfeitamente o problema de descasamento entre a distribuição dos dados de pré-treinamento e os objetivos de pós-treinamento.
- **Sinais de recompensa multidimensionais**: Avalia de forma abrangente a precisão semântica, a capacidade de seguir instruções, a pontuação de preferência acústica e a similaridade de timbre, garantindo que cada segundo de fala gerada esteja alinhado com a intuição humana.
### Desempenho de Inferência de Streaming Extremo (Baseado em SGLang)
Como a arquitetura Dual-AR é estruturalmente isomorfa à estrutura padrão de LLMs, o S2 Pro suporta nativamente todos os recursos de aceleração de inferência do SGLang, incluindo loteamento contínuo (Continuous Batching), Paged KV Cache, CUDA Graph e cache de prefixo baseado em RadixAttention.
**Desempenho em uma única GPU NVIDIA H200:**
- **Fator em Tempo Real (RTF)**: 0.195
- **Latência do Primeiro Áudio (TTFA)**: aprox. 100 ms
- **Taxa de Transferência Ultrarrápida**: Alcance de 3.000+ acoustic tokens/s mantendo RTF < 0.5
### Poderoso Suporte Multilíngue
S2 Pro suporta mais de 80 idiomas, possibilitando síntese de alta qualidade sem a necessidade de fonemas ou processamento específico por idioma:
- **Tier 1**: Japonês (ja), Inglês (en), Chinês (zh)
- **Tier 2**: Coreano (ko), Espanhol (es), Português (pt), Árabe (ar), Russo (ru), Francês (fr), Alemão (de)
- **Cobertura Global**: sv, it, tr, no, nl, cy, eu, ca, da, gl, ta, hu, fi, pl, et, hi, la, ur, th, vi, jw, bn, yo, xsl, cs, sw, nn, he, ms, uk, id, kk, bg, lv, my, tl, sk, ne, fa, af, el, bo, hr, ro, sn, mi, yi, am, be, km, is, az, sd, br, sq, ps, mn, ht, ml, sr, sa, te, ka, bs, pa, lt, kn, si, hy, mr, as, gu, fo, etc.
### Geração Nativa Multi-falante
O Fish Audio S2 permite que os usuários enviem áudio de referência contendo múltiplos falantes, e o modelo processará as características de cada falante via o token `<|speaker:i|>`. Em seguida, você pode controlar o desempenho do modelo através do token de ID do falante, permitindo incluir múltiplos falantes em uma única geração. Não é mais necessário enviar áudios de referência separadamente para cada falante.
### Geração de Diálogos Multiturnos
Graças à expansão do contexto do modelo, nosso modelo agora pode aproveitar as informações prévias para aumentar a expressividade dos conteúdos gerados subsequentemente, elevando assim a naturalidade dos diálogos.
### Clonagem de Voz Rápida
O Fish Audio S2 suporta clonagem de voz precisa usando curtas amostras de referência (normalmente 10-30 segundos). O modelo captura o timbre, o estilo de fala e as tendências emocionais, gerando vozes clonadas realistas e consistentes sem necessidade de ajustes finos adicionais.
Caso deseje utilizar o SGLang Server, consulte o [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md).
---
## Agradecimentos
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## Relatório Técnico
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/README.zh.md
================================================
Fish Speech
[English](../README.md) | **简体中文** | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [العربية](README.ar.md)
> [!IMPORTANT]
> **许可证声明**
> 此代码库及其相关的模型权重均在 **[FISH AUDIO RESEARCH LICENSE](../LICENSE)** 下发布。更多详情请参考 [LICENSE](../LICENSE)。
> [!WARNING]
> **法律免责声明**
> 我们不对代码库的任何非法使用承担责任。请参考您当地关于 DMCA 和其他相关法律的法规。
## 快速开始
### 文档入口
这里是 Fish Audio S2 的官方文档,请按照说明轻松入门。
- [安装](https://speech.fish.audio/zh/install/)
- [命令行推理](https://speech.fish.audio/zh/inference/)
- [WebUI 推理](https://speech.fish.audio/zh/inference/)
- [服务端推理](https://speech.fish.audio/zh/server/)
- [Docker 部署](https://speech.fish.audio/zh/install/)
> [!IMPORTANT]
> **如需使用 SGLang Server,请参考 [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md)。**
### LLM Agent 指南
```
请先阅读 https://speech.fish.audio/zh/install/ ,并按文档安装和配置 Fish Audio S2。
```
## Fish Audio S2 Pro
**行业顶尖的多语言文本转语音 (TTS) 系统,重新定义声音生成的边界。**
Fish Audio S2 Pro 是 [Fish Audio](https://fish.audio/) 开发的最先进的多模态模型。S2 Pro 训练自超过 **1000 万小时** 的海量音频数据,覆盖全球 **80 多种语言**。通过创新的 **双自回归 (Dual-AR)** 架构与强化学习 (RL) 对齐技术,S2 Pro 能生成极具自然感、真实感且情感饱满的语音,在开源与闭源竞争中均处于领先地位。
S2 Pro 的杀手锏在于支持通过自然语言标签(如 `[whisper]`、`[excited]`、`[angry]`)对韵律与情绪进行 **亚词级(Sub-word Level)** 的极细粒度行内控制,同时原生支持多说话人与超长上下文的多轮对话生成。
立即访问 [Fish Audio 官网](https://fish.audio/) 体验在线演示,或阅读我们的[技术报告](https://arxiv.org/abs/2603.08823)与[博客文章](https://fish.audio/blog/fish-audio-open-sources-s2/)深入了解。
### 模型变体
| 模型 | 大小 | 可用性 | 描述 |
|------|------|-------------|-------------|
| S2-Pro | 4B 参数 | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | 功能齐全的旗舰模型,具有最高质量和稳定性 |
有关模型的更多详情,请参见[技术报告](https://arxiv.org/abs/2411.01156)。
## 基准测试结果
| 基准 | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER(中文) | **0.54%**(总体最佳) |
| Seed-TTS Eval — WER(英文) | **0.99%**(总体最佳) |
| Audio Turing Test(含指令) | **0.515** 后验均值 |
| EmergentTTS-Eval — 胜率 | **81.88%**(总体最高) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — 质量 | **4.51 / 5.0** |
| 多语言(MiniMax Testset)— 最佳 WER | **24** 种语言中的 **11** 种 |
| 多语言(MiniMax Testset)— 最佳 SIM | **24** 种语言中的 **17** 种 |
在 Seed-TTS Eval 上,S2 在所有已评估模型(包括闭源系统)中实现了最低 WER:Qwen3-TTS(0.77/1.24)、MiniMax Speech-02(0.99/1.90)、Seed-TTS(1.12/2.25)。在 Audio Turing Test 上,S2 的 0.515 相比 Seed-TTS(0.417)提升 24%,相比 MiniMax-Speech(0.387)提升 33%。在 EmergentTTS-Eval 中,S2 在副语言学(91.61% 胜率)、疑问句(84.41%)和句法复杂度(83.39%)等维度表现尤为突出。
## 亮点
### 通过自然语言进行极细粒度行内控制
S2 Pro 赋予了语音前所未有的“灵性”。通过简单的 `[tag]` 语法,你可以在文本的任何位置精准嵌入情感指令。
- **15,000+ 独特标签支持**:不局限于固定的预设,支持 **自由格式的文本描述**。你可以尝试 `[whisper in small voice]` (低声耳语), `[professional broadcast tone]` (专业播音腔), 或 `[pitch up]` (提高音调)。
- **丰富的情绪库**:
`[pause]` `[emphasis]` `[laughing]` `[inhale]` `[chuckle]` `[tsk]` `[singing]` `[excited]` `[laughing tone]` `[interrupting]` `[chuckling]` `[excited tone]` `[volume up]` `[echo]` `[angry]` `[low volume]` `[sigh]` `[low voice]` `[whisper]` `[screaming]` `[shouting]` `[loud]` `[surprised]` `[short pause]` `[exhale]` `[delight]` `[panting]` `[audience laughter]` `[with strong accent]` `[volume down]` `[clearing throat]` `[sad]` `[moaning]` `[shocked]`
### 创新的双自回归 (Dual-Autoregressive) 架构
S2 Pro 采用了主从式 Dual-AR 架构,由 Decoder-only Transformer 与 RVQ 音频编解码器(10 个码本,约 21 Hz 帧率)组成:
- **Slow AR (4B 参数)**:沿时间轴工作,预测核心的语义码本。
- **Fast AR (400M 参数)**:在每个时间步生成剩余 9 个残差码本,细腻还原极致的音频细节。
这种非对称设计在保证音频极致保真度的同时,大幅提升了推理速度。
### 强化学习对齐 (RL Alignment)
S2 Pro 采用了 **Group Relative Policy Optimization (GRPO)** 技术进行后训练对齐。我们将用于数据清洗与标注的同一套模型直接作为奖励模型 (Reward Model),完美解决了预训练数据分布与后训练目标之间的不匹配问题。
- **多维奖励信号**:综合评估语义准确性、指令遵循能力、声学偏好评分以及音色相似度,确保生成的每一秒语音都符合人类直觉。
### 极致的流式推理性能 (基于 SGLang)
由于 Dual-AR 架构与标准 LLM 结构同构,S2 Pro 原生支持 SGLang 的所有推理加速特性,包括连续批处理 (Continuous Batching)、分页 KV Cache、CUDA Graph 与基于 RadixAttention 的前缀缓存。
**单张 NVIDIA H200 GPU 性能表现:**
- **实时因子 (RTF)**:0.195
- **首音延迟 (TTFA)**:约 100 ms
- **极速吞吐**:在保持 RTF < 0.5 时,吞吐量达到 3,000+ acoustic tokens/s
### 强大的多语言支持
S2 Pro 支持 80 多种语言,无需音素或特定语言的处理即可实现高质量合成:
- **第一梯队 (Tier 1)**:日语 (ja), 英语 (en), 中文 (zh)
- **第二梯队 (Tier 2)**:韩语 (ko), 西班牙语 (es), 葡萄牙语 (pt), 阿拉伯语 (ar), 俄语 (ru), 法语 (fr), 德语 (de)
- **全球覆盖**:sv, it, tr, no, nl, cy, eu, ca, da, gl, ta, hu, fi, pl, et, hi, la, ur, th, vi, jw, bn, yo, xsl, cs, sw, nn, he, ms, uk, id, kk, bg, lv, my, tl, sk, ne, fa, af, el, bo, hr, ro, sn, mi, yi, am, be, km, is, az, sd, br, sq, ps, mn, ht, ml, sr, sa, te, ka, bs, pa, lt, kn, si, hy, mr, as, gu, fo 等。
### 原生多说话人生成
Fish Audio S2 允许用户上传包含多个说话人的参考音频,模型将通过 `<|speaker:i|>` 令牌处理每个说话人的特征。之后您可以通过说话人 ID 令牌控制模型的表现,从而实现一次生成中包含多个说话人。再也不需要像以前那样针对每个说话人都单独上传参考音频与生成语音了。
### 多轮对话生成
得益于模型上下文的扩展,我们的模型现在可以借助上文的信息提高后续生成内容的表现力,从而提升内容的自然度。
### 快速语音克隆
Fish Audio S2 支持使用短参考样本(通常为 10-30 秒)进行准确的语音克隆。模型可以捕捉音色、说话风格和情感倾向,无需额外微调即可生成逼真且一致的克隆语音。
如需使用 SGLang Server,请参考 [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md) 。
---
## 致谢
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## 技术报告
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/ar/finetune.md
================================================
# الضبط الدقيق (Fine-tuning)
من الواضح أنك عندما فتحت هذه الصفحة، لم تكن راضيًا عن أداء النموذج المدرب مسبقًا في وضع zero-shot. أنت ترغب في إجراء ضبط دقيق لنموذج لتحسين أدائه على مجموعة البيانات الخاصة بك.
في الإصدار الحالي، ما عليك سوى إجراء الضبط الدقيق لجزء 'LLAMA'.
## الضبط الدقيق لـ LLAMA
### 1. إعداد مجموعة البيانات
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 30.1-32.71.lab
│ └── 30.1-32.71.mp3
└── SPK2
├── 38.79-40.85.lab
└── 38.79-40.85.mp3
```
تحتاج إلى تحويل مجموعة البيانات الخاصة بك إلى التنسيق أعلاه ووضعها تحت مجلد `data`. يمكن أن يكون للملف الصوتي الامتدادات `.mp3`، `.wav`، أو `.flac`، ويجب أن يكون لملف التعليقات التوضيحية الامتداد `.lab`.
!!! info "تنسيق مجموعة البيانات"
يحتاج ملف التعليقات التوضيحية `.lab` فقط إلى احتواء النص المكتوب للمقطع الصوتي، دون الحاجة إلى تنسيق خاص. على سبيل المثال، إذا كان محتوى `hi.mp3` هو "مرحبًا، وداعًا"، فسيحتوي ملف `hi.lab` على سطر واحد من النص: "مرحبًا، وداعًا".
!!! warning "تحذير"
يوصى بتطبيق تسوية جهارة الصوت (loudness normalization) على مجموعة البيانات. يمكنك استخدام [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) للقيام بذلك.
```bash
fap loudness-norm data-raw data --clean
```
### 2. الاستخراج الدفعي للرموز الدلالية (semantic tokens)
تأكد من أنك قمت بتنزيل أوزان VQGAN. إذا لم تكن قد فعلت، قم بتشغيل الأمر التالي:
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
يمكنك بعد ذلك تشغيل الأمر التالي لاستخراج الرموز الدلالية:
```bash
python tools/vqgan/extract_vq.py data \
--num-workers 1 --batch-size 16 \
--config-name "modded_dac_vq" \
--checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
```
!!! note "ملاحظة"
يمكنك ضبط `--num-workers` و `--batch-size` لزيادة سرعة الاستخراج، ولكن يرجى التأكد من عدم تجاوز حد ذاكرة وحدة معالجة الرسومات (GPU) الخاصة بك.
سيقوم هذا الأمر بإنشاء ملفات `.npy` في مجلد `data`، كما هو موضح أدناه:
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 21.15-26.44.npy
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 27.51-29.98.npy
│ ├── 30.1-32.71.lab
│ ├── 30.1-32.71.mp3
│ └── 30.1-32.71.npy
└── SPK2
├── 38.79-40.85.lab
├── 38.79-40.85.mp3
└── 38.79-40.85.npy
```
### 3. حزم مجموعة البيانات في protobuf
```bash
python tools/llama/build_dataset.py \
--input "data" \
--output "data/protos" \
--text-extension .lab \
--num-workers 16
```
بعد انتهاء تنفيذ الأمر، يجب أن ترى مجلد `protos` داخل مجلد `data`.
### 4. أخيرًا، الضبط الدقيق باستخدام LoRA
بالمثل، تأكد من أنك قمت بتنزيل أوزان `LLAMA`. إذا لم تكن قد فعلت، قم بتشغيل الأمر التالي:
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
أخيرًا، يمكنك بدء الضبط الدقيق عن طريق تشغيل الأمر التالي:
```bash
python fish_speech/train.py --config-name text2semantic_finetune \
project=$project \
+lora@model.model.lora_config=r_8_alpha_16
```
!!! note "ملاحظة"
يمكنك تعديل معلمات التدريب مثل `batch_size`، `gradient_accumulation_steps`، وما إلى ذلك لتناسب ذاكرة وحدة معالجة الرسومات الخاصة بك عن طريق تعديل `fish_speech/configs/text2semantic_finetune.yaml`.
!!! note "ملاحظة"
لمستخدمي Windows، يمكنك استخدام `trainer.strategy.process_group_backend=gloo` لتجنب مشكلات `nccl`.
بعد اكتمال التدريب، يمكنك الرجوع إلى قسم [الاستدلال (inference)](inference.md) لاختبار نموذجك.
!!! info "معلومات"
بشكل افتراضي، سيتعلم النموذج فقط أنماط كلام المتحدث وليس جرس الصوت (timbre). لا تزال بحاجة إلى استخدام التلقينات (prompts) لضمان استقرار جرس الصوت.
إذا كنت ترغب في تعلم جرس الصوت، يمكنك زيادة عدد خطوات التدريب، ولكن هذا قد يؤدي إلى الإفراط في التخصيص (overfitting).
بعد التدريب، تحتاج إلى تحويل أوزان LoRA إلى أوزان عادية قبل إجراء الاستدلال.
```bash
python tools/llama/merge_lora.py \
--lora-config r_8_alpha_16 \
--base-weight checkpoints/openaudio-s1-mini \
--lora-weight results/$project/checkpoints/step_000000010.ckpt \
--output checkpoints/openaudio-s1-mini-yth-lora/
```
!!! note "ملاحظة"
يمكنك أيضًا تجربة نقاط تحقق (checkpoints) أخرى. نقترح استخدام أقدم نقطة تحقق تلبي متطلباتك، حيث إنها غالبًا ما تؤدي أداءً أفضل على البيانات خارج التوزيع (OOD).
================================================
FILE: docs/ar/index.md
================================================
!!! info "تنبيه الترخيص"
يتم إصدار قاعدة الأكواد هذه وأوزان النماذج المرتبطة بها بموجب رخصة **FISH AUDIO RESEARCH LICENSE**. يرجى الرجوع إلى [LICENSE](https://github.com/fishaudio/fish-speech/blob/main/LICENSE) لمزيد من التفاصيل.
!!! warning "إخلاء المسؤولية القانونية"
نحن لا نتحمل أي مسؤولية عن أي استخدام غير قانوني لقاعدة الأكواد. يرجى مراجعة القوانين المحلية المتعلقة بـ DMCA والقوانين الأخرى ذات الصلة.
## البدء السريع
### ابدأ من الوثائق
هذه هي الوثائق الرسمية لـ Fish Audio S2، ويمكنك البدء مباشرة عبر الروابط التالية:
- [التثبيت](https://speech.fish.audio/ar/install/)
- [الاستدلال عبر سطر الأوامر](https://speech.fish.audio/ar/inference/)
- [استدلال WebUI](https://speech.fish.audio/ar/inference/)
- [الاستدلال عبر الخادم](https://speech.fish.audio/ar/server/)
- [إعداد Docker](https://speech.fish.audio/ar/install/)
> [!IMPORTANT]
> **بالنسبة لخادم SGLang، راجع [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md).**
### دليل وكلاء LLM
```
قم بتثبيت وإعداد Fish Audio S2 باتباع التعليمات في https://speech.fish.audio/ar/install/ .
```
## Fish Audio S2
**أفضل نظام لتحويل النص إلى كلام بين الأنظمة مفتوحة المصدر ومغلقة المصدر**
Fish Audio S2 هو أحدث نموذج من [Fish Audio](https://fish.audio/). تم تدريبه على أكثر من 10 ملايين ساعة صوتية عبر نحو 50 لغة، ويجمع بين المواءمة بالتعلم المعزز وبنية Dual-Autoregressive لإنتاج كلام طبيعي وواقعي وغني بالتعبير العاطفي.
يدعم S2 التحكم الدقيق في النبرة والعاطفة داخل النص نفسه باستخدام وسوم باللغة الطبيعية مثل `[laugh]` و`[whispers]` و`[super happy]`، كما يدعم بشكل أصيل توليد متحدثين متعددين وحوارات متعددة الأدوار.
يمكنك تجربة النموذج مباشرة عبر [موقع Fish Audio](https://fish.audio/)، وقراءة المزيد في [منشور المدونة](https://fish.audio/blog/fish-audio-open-sources-s2/) و[التقرير التقني](https://arxiv.org/abs/2603.08823).
### إصدارات النموذج
| النموذج | الحجم | التوفر | الوصف |
|------|------|-------------|-------------|
| S2-Pro | 4B معلمة | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | نموذج رائد كامل الميزات بأعلى مستوى من الجودة والاستقرار |
يمكن العثور على مزيد من التفاصيل في [التقرير التقني](https://arxiv.org/abs/2603.08823).
## نتائج القياس المعياري
| المعيار | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER (الصينية) | **0.54%** (الأفضل إجمالاً) |
| Seed-TTS Eval — WER (الإنجليزية) | **0.99%** (الأفضل إجمالاً) |
| Audio Turing Test (مع التعليمات) | **0.515** المتوسط البعدي |
| EmergentTTS-Eval — معدل الفوز | **81.88%** (الأعلى إجمالاً) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — الجودة | **4.51 / 5.0** |
| متعدد اللغات (MiniMax Testset) — أفضل WER | **11 من 24** لغة |
| متعدد اللغات (MiniMax Testset) — أفضل SIM | **17 من 24** لغة |
في Seed-TTS Eval، حقق S2 أقل WER بين جميع النماذج التي تم تقييمها، بما في ذلك الأنظمة المغلقة: Qwen3-TTS (0.77/1.24)، وMiniMax Speech-02 (0.99/1.90)، وSeed-TTS (1.12/2.25). وفي Audio Turing Test، تفوقت قيمة 0.515 على Seed-TTS (0.417) بنسبة 24% وعلى MiniMax-Speech (0.387) بنسبة 33%. وفي EmergentTTS-Eval، حقق S2 نتائج قوية بشكل خاص في الخصائص شبه اللغوية (91.61%)، والأسئلة (84.41%)، والتعقيد النحوي (83.39%).
## أبرز المميزات
### تحكم مضمّن دقيق عبر اللغة الطبيعية
يتيح Fish Audio S2 تحكمًا موضعيًا في توليد الكلام من خلال تضمين تعليمات باللغة الطبيعية مباشرة عند مواقع كلمات أو عبارات محددة داخل النص. وبدلًا من الاعتماد على مجموعة ثابتة من الوسوم المُعرّفة مسبقًا، يقبل S2 أوصافًا نصية حرة مثل `[whisper in small voice]` أو `[professional broadcast tone]` أو `[pitch up]`، مما يتيح تحكمًا مفتوحًا في التعبير على مستوى الكلمة.
### بنية Dual-Autoregressive
يعتمد S2 على Transformer أحادي الاتجاه (Decoder-only) مع مُرمّز صوتي قائم على RVQ (عدد 10 codebooks وبمعدل إطارات يقارب 21 هرتز). وتُقسّم بنية Dual-AR عملية التوليد إلى مرحلتين:
- **Slow AR** يعمل على المحور الزمني ويتنبأ بالـ semantic codebook الأساسي.
- **Fast AR** يولّد الـ 9 residual codebooks المتبقية في كل خطوة زمنية لإعادة بناء التفاصيل الصوتية الدقيقة.
هذا التصميم غير المتماثل (4B معلمة على المحور الزمني و400M على محور العمق) يرفع كفاءة الاستدلال مع الحفاظ على جودة الصوت.
### المواءمة بالتعلم المعزز
يستخدم S2 خوارزمية Group Relative Policy Optimization (GRPO) للمواءمة بعد التدريب. ويتم إعادة استخدام نفس النماذج التي استُخدمت لتصفية بيانات التدريب وتعليقها كنماذج مكافأة في التعلم المعزز مباشرة، مما يلغي عدم تطابق التوزيع بين بيانات ما قبل التدريب وأهداف ما بعد التدريب. وتجمع إشارة المكافأة بين الدقة الدلالية، والالتزام بالتعليمات، وتقييم التفضيل الصوتي، وتشابه النبرة.
### البث الإنتاجي عبر SGLang
لأن بنية Dual-AR متماثلة بنيويًا مع نماذج LLM autoregressive القياسية، فإن S2 يرث مباشرة تحسينات الخدمة الأصلية في SGLang، بما في ذلك: continuous batching، وpaged KV cache، وCUDA graph replay، وprefix caching المعتمد على RadixAttention.
على بطاقة NVIDIA H200 واحدة:
- **عامل الزمن الحقيقي (RTF):** 0.195
- **الزمن حتى أول مقطع صوتي:** حوالي 100 مللي ثانية
- **معدل المعالجة:** أكثر من 3,000 acoustic tokens/s مع الحفاظ على RTF أقل من 0.5
### دعم لغات متعددة
يدعم Fish Audio S2 تحويل النص إلى كلام متعدد اللغات بجودة عالية دون الحاجة إلى فونيمات (phonemes) أو معالجة مسبقة خاصة بكل لغة. ومن اللغات المدعومة:
**الإنجليزية، الصينية، اليابانية، الكورية، العربية، الألمانية، الفرنسية...**
**وأكثر من ذلك بكثير!**
القائمة في توسع مستمر، تحقق من [Fish Audio](https://fish.audio/) لمعرفة أحدث الإصدارات.
### توليد أصلي لمتحدثين متعددين
يسمح Fish Audio S2 للمستخدمين برفع صوت مرجعي يحتوي على متحدثين متعددين، وسيتعامل النموذج مع ميزات كل متحدث عبر رمز `<|speaker:i|>`. يمكنك بعد ذلك التحكم في أداء النموذج باستخدام رمز معرف المتحدث، مما يسمح بتوليد واحد يتضمن متحدثين متعددين. لم تعد بحاجة لرفع ملفات مرجعية منفصلة لكل متحدث.
### توليد حوارات متعددة الأدوار
بفضل توسيع سياق النموذج، يمكن لنموذجنا الآن استخدام المعلومات السابقة لتحسين التعبير في المحتوى المولد لاحقاً، مما يزيد من طبيعية المحتوى.
### استنساخ صوت سريع
يدعم Fish Audio S2 استنساخ الصوت بدقة باستخدام عينة مرجعية قصيرة (عادةً 10-30 ثانية). يلتقط النموذج نبرة الصوت، وأسلوب التحدث، والميول العاطفية، مما ينتج أصواتاً مستنسخة واقعية ومتسقة دون الحاجة إلى ضبط دقيق إضافي.
لاستخدام خادم SGLang، راجع [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md) .
---
## شكر وتقدير
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## التقرير التقني
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/ar/inference.md
================================================
# الاستنتاج
يتطلب نموذج Fish Audio S2 ذاكرة فيديو (VRAM) كبيرة. نوصي باستخدام وحدة معالجة رسومات (GPU) بسعة 24 جيجابايت على الأقل للاستنتاج.
## تحميل الأوزان
أولاً، تحتاج إلى تحميل أوزان النموذج:
```bash
hf download fishaudio/s2-pro --local-dir checkpoints/s2-pro
```
## الاستنتاج عبر خط الأوامر
!!! note
إذا كنت تخطط لترك النموذج يختار نغمة الصوت عشوائيًا ، فيمكنك تخطي هذه الخطوة.
### 1. الحصول على رموز VQ من الصوت المرجعي
```bash
python fish_speech/models/dac/inference.py \
-i "test.wav" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
يجب أن تحصل على `fake.npy` و `fake.wav`.
### 2. توليد الرموز الدلالية (Semantic tokens) من النص:
```bash
python fish_speech/models/text2semantic/inference.py \
--text "النص الذي تريد تحويله" \
--prompt-text "النص المرجعي الخاص بك" \
--prompt-tokens "fake.npy" \
# --compile
```
سيقوم هذا الأمر بإنشاء ملف `codes_N` في دليل العمل ، حيث N هو عدد صحيح يبدأ من 0.
!!! note
قد ترغب في استخدام `--compile` لدمج نوى CUDA لاستنتاج أسرع. ومع ذلك، نوصي باستخدام تسريع الاستنتاج المعتمد على SGLang.
معلمة `--compile` معلّقة (معطّلة) في المثال أعلاه؛ إذا أردت تفعيلها، فأزل علامة `#` من أمامها.
!!! info
بالنسبة لوحدات معالجة الرسومات التي لا تدعم bf16 ، قد تحتاج إلى استخدام معلمة `--half`.
### 3. توليد الصوت من الرموز الدلالية:
```bash
python fish_speech/models/dac/inference.py \
-i "codes_0.npy" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
بعد ذلك ستحصل على ملف `fake.wav`.
## استنتاج WebUI
### 1. Gradio WebUI
للحفاظ على التوافق، ما زلنا نحتفظ بواجهة Gradio WebUI السابقة.
```bash
python tools/run_webui.py # --compile إذا كنت بحاجة إلى تسريع
```
### 2. Awesome WebUI
تعد Awesome WebUI واجهة ويب حديثة تعتمد على TypeScript، وتوفر ميزات أغنى وتجربة مستخدم أفضل.
**بناء WebUI:**
يجب أن يكون لديك Node.js و npm مثبتين على جهازك المحلي أو الخادم.
1. ادخل إلى دليل `awesome_webui`:
```bash
cd awesome_webui
```
2. تثبيت التبعيات:
```bash
npm install
```
3. بناء WebUI:
```bash
npm run build
```
**بدء تشغيل خادم الخلفية:**
بعد بناء WebUI، عد إلى دليل جذر المشروع وقم بتشغيل خادم API:
```bash
python tools/api_server.py --listen 0.0.0.0:8888 --compile
```
**الوصول:**
بمجرد تشغيل الخادم، يمكنك الوصول إليه عبر المتصفح على العنوان التالي:
`http://localhost:8888/ui`
================================================
FILE: docs/ar/install.md
================================================
## المتطلبات
- ذاكرة وحدة معالجة الرسومات (GPU): 24 جيجابايت (للاستدلال)
- النظام: Linux, WSL
## إعداد النظام
يدعم Fish Audio S2 طرق تثبيت متعددة. اختر الطريقة التي تناسب بيئة التطوير الخاصة بك.
**المتطلبات الأساسية**: قم بتثبيت تبعيات النظام لمعالجة الصوت:
``` bash
apt install portaudio19-dev libsox-dev ffmpeg
```
### Conda
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# تثبيت نسخة GPU (اختر إصدار CUDA الخاص بك: cu126, cu128, cu129)
pip install -e .[cu129]
# تثبيت نسخة CPU فقط
pip install -e .[cpu]
# التثبيت الافتراضي (يستخدم فهرس PyTorch الافتراضي)
pip install -e .
# إذا واجهت خطأ أثناء التثبيت بسبب pyaudio، ففكر في استخدام الأمر التالي:
# conda install pyaudio
# ثم قم بتشغيل pip install -e . مرة أخرى
```
### UV
يوفر UV حلاً أسرع لتثبيت التبعيات:
```bash
# تثبيت نسخة GPU (اختر إصدار CUDA الخاص بك: cu126, cu128, cu129)
uv sync --python 3.12 --extra cu129
# تثبيت نسخة CPU فقط
uv sync --python 3.12 --extra cpu
```
### دعم Intel Arc XPU
لمستخدمي وحدات معالجة الرسومات Intel Arc، قم بالتثبيت مع دعم XPU على النحو التالي:
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# تثبيت مكتبة C++ القياسية المطلوبة
conda install libstdcxx -c conda-forge
# تثبيت PyTorch مع دعم Intel XPU
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
# تثبيت Fish Speech
pip install -e .
```
!!! warning
خيار `compile` غير مدعوم على أنظمة Windows و macOS. إذا كنت ترغب في التشغيل مع التجميع، ستحتاج إلى تثبيت Triton بنفسك.
## إعداد Docker
يوفر نموذج سلسلة Fish Audio S2 خيارات نشر متعددة مع Docker لتلبية الاحتياجات المختلفة. يمكنك استخدام الصور المعدة مسبقًا من Docker Hub، أو البناء محليًا باستخدام Docker Compose، أو بناء صور مخصصة يدويًا.
لقد قدمنا صور Docker لكل من واجهة المستخدم الرسومية (WebUI) وخادم API، لكل من وحدات معالجة الرسومات (GPU) (CUDA 12.6 افتراضيًا) ووحدات المعالجة المركزية (CPU). يمكنك استخدام الصور المعدة مسبقًا من Docker Hub، أو البناء محليًا باستخدام Docker Compose، أو بناء صور مخصصة يدويًا. إذا كنت ترغب في البناء محليًا، فاتبع الإرشادات أدناه. إذا كنت ترغب فقط في استخدام الصور المعدة مسبقًا، فاتبع مباشرةً [دليل الاستدلال](inference.md).
### المتطلبات الأساسية
- تثبيت Docker و Docker Compose
- تثبيت NVIDIA Docker runtime (لدعم GPU)
- ذاكرة GPU لا تقل عن 24 جيجابايت للاستدلال باستخدام CUDA
### استخدام Docker Compose
للتطوير أو التخصيص، يمكنك استخدام Docker Compose للبناء والتشغيل محليًا:
```bash
# أولاً، استنسخ المستودع
git clone https://github.com/fishaudio/fish-speech.git
cd fish-speech
# بدء واجهة المستخدم الرسومية (WebUI) مع CUDA
docker compose --profile webui up
# بدء واجهة المستخدم الرسومية (WebUI) مع تحسين التجميع
COMPILE=1 docker compose --profile webui up
# بدء خادم API
docker compose --profile server up
# بدء خادم API مع تحسين التجميع
COMPILE=1 docker compose --profile server up
# النشر باستخدام CPU فقط
BACKEND=cpu docker compose --profile webui up
```
#### متغيرات البيئة لـ Docker Compose
يمكنك تخصيص النشر باستخدام متغيرات البيئة:
```bash
# مثال على ملف .env
BACKEND=cuda # أو cpu
COMPILE=1 # تمكين تحسين التجميع
GRADIO_PORT=7860 # منفذ واجهة المستخدم الرسومية (WebUI)
API_PORT=8080 # منفذ خادم API
UV_VERSION=0.8.15 # إصدار مدير الحزم UV
```
سيقوم الأمر ببناء الصورة وتشغيل الحاوية. يمكنك الوصول إلى واجهة المستخدم الرسومية (WebUI) على `http://localhost:7860` وخادم API على `http://localhost:8080`.
### البناء اليدوي باستخدام Docker
للمستخدمين المتقدمين الذين يرغبون في تخصيص عملية البناء:
```bash
# بناء صورة واجهة المستخدم الرسومية (WebUI) مع دعم CUDA
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target webui \
-t fish-speech-webui:cuda .
# بناء صورة خادم API مع دعم CUDA
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target server \
-t fish-speech-server:cuda .
# بناء صورة CPU فقط (تدعم منصات متعددة)
docker build \
--platform linux/amd64,linux/arm64 \
-f docker/Dockerfile \
--build-arg BACKEND=cpu \
--target webui \
-t fish-speech-webui:cpu .
# بناء صورة التطوير
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--target dev \
-t fish-speech-dev:cuda .
```
#### وسيطات البناء
- `BACKEND`: `cuda` أو `cpu` (الافتراضي: `cuda`)
- `CUDA_VER`: إصدار CUDA (الافتراضي: `12.6.0`)
- `UV_EXTRA`: حزمة UV إضافية لـ CUDA (الافتراضي: `cu126`)
- `UBUNTU_VER`: إصدار Ubuntu (الافتراضي: `24.04`)
- `PY_VER`: إصدار Python (الافتراضي: `3.12`)
### تحميل المجلدات
تتطلب كلتا الطريقتين تحميل المجلدات التالية:
- `./checkpoints:/app/checkpoints` - مجلد أوزان النموذج
- `./references:/app/references` - مجلد ملفات الصوت المرجعية
### متغيرات البيئة
- `COMPILE=1` - تمكين `torch.compile` لتسريع الاستدلال (حوالي 10 أضعاف)
- `GRADIO_SERVER_NAME=0.0.0.0` - مضيف خادم واجهة المستخدم الرسومية (WebUI)
- `GRADIO_SERVER_PORT=7860` - منفذ خادم واجهة المستخدم الرسومية (WebUI)
- `API_SERVER_NAME=0.0.0.0` - مضيف خادم API
- `API_SERVER_PORT=8080` - منفذ خادم API
!!! note
تتوقع حاويات Docker أن يتم تحميل أوزان النموذج في `/app/checkpoints`. تأكد من تنزيل أوزان النموذج المطلوبة قبل بدء الحاويات.
!!! warning
يتطلب دعم GPU وجود NVIDIA Docker runtime. للنشر باستخدام CPU فقط، قم بإزالة علامة `--gpus all` واستخدم صور CPU.
================================================
FILE: docs/en/finetune.md
================================================
# Fine-tuning
!!! warning
We strongly recommend against fine-tuning an RL-trained model. Fine-tuning a model after RL can shift the model's distribution, which may lead to degraded performance.
In the current version, you only need to finetune the 'LLAMA' part.
## Fine-tuning LLAMA
### 1. Prepare the dataset
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 30.1-32.71.lab
│ └── 30.1-32.71.mp3
└── SPK2
├── 38.79-40.85.lab
└── 38.79-40.85.mp3
```
You need to convert your dataset into the above format and place it under `data`. The audio file can have the extensions `.mp3`, `.wav`, or `.flac`, and the annotation file should have the extension `.lab`.
!!! info
The `.lab` annotation file only needs to contain the transcription of the audio, with no special formatting required. For example, if `hi.mp3` says "Hello, goodbye," then the `hi.lab` file would contain a single line of text: "Hello, goodbye."
!!! warning
It's recommended to apply loudness normalization to the dataset. You can use [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) to do this.
```bash
fap loudness-norm data-raw data --clean
```
### 2. Batch extraction of semantic tokens
Make sure you have downloaded the VQGAN weights. If not, run the following command:
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
You can then run the following command to extract semantic tokens:
```bash
python tools/vqgan/extract_vq.py data \
--num-workers 1 --batch-size 16 \
--config-name "modded_dac_vq" \
--checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
```
!!! note
You can adjust `--num-workers` and `--batch-size` to increase extraction speed, but please make sure not to exceed your GPU memory limit.
This command will create `.npy` files in the `data` directory, as shown below:
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 21.15-26.44.npy
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 27.51-29.98.npy
│ ├── 30.1-32.71.lab
│ ├── 30.1-32.71.mp3
│ └── 30.1-32.71.npy
└── SPK2
├── 38.79-40.85.lab
├── 38.79-40.85.mp3
└── 38.79-40.85.npy
```
### 3. Pack the dataset into protobuf
```bash
python tools/llama/build_dataset.py \
--input "data" \
--output "data/protos" \
--text-extension .lab \
--num-workers 16
```
After the command finishes executing, you should see a `protos` directory inside the `data` directory.
### 4. Finally, fine-tuning with LoRA
Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
Finally, you can start the fine-tuning by running the following command:
```bash
python fish_speech/train.py --config-name text2semantic_finetune \
project=$project \
+lora@model.model.lora_config=r_8_alpha_16
```
!!! note
You can modify the training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. to fit your GPU memory by modifying `fish_speech/configs/text2semantic_finetune.yaml`.
!!! note
For Windows users, you can use `trainer.strategy.process_group_backend=gloo` to avoid `nccl` issues.
After training is complete, you can refer to the [inference](inference.md) section to test your model.
!!! info
By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability.
If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting.
After training, you need to convert the LoRA weights to regular weights before performing inference.
```bash
python tools/llama/merge_lora.py \
--lora-config r_8_alpha_16 \
--base-weight checkpoints/openaudio-s1-mini \
--lora-weight results/$project/checkpoints/step_000000010.ckpt \
--output checkpoints/openaudio-s1-mini-yth-lora/
```
!!! note
You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as earlier checkpoints often perform better on out-of-distribution (OOD) data.
================================================
FILE: docs/en/index.md
================================================
!!! info "License Notice"
This codebase and its associated model weights are released under **FISH AUDIO RESEARCH LICENSE**. Please refer to [LICENSE](https://github.com/fishaudio/fish-speech/blob/main/LICENSE) for more details. We will take action against any violation of the license.
!!! warning "Legal Disclaimer"
We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws.
## Quick Start
### For Human
Here is the official documentation for Fish Audio S2. Follow the instructions below to get started.
- [Installation](https://speech.fish.audio/install/)
- [Command Line Inference](https://speech.fish.audio/inference/#command-line-inference)
- [WebUI Inference](https://speech.fish.audio/inference/#webui-inference)
- [Server Inference](https://speech.fish.audio/server/)
- [Docker Setup](https://speech.fish.audio/install/#docker-setup)
> [!IMPORTANT]
> **For SGLang server, please read [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md).**
### For LLM Agent
```
Install and configure Fish-Audio S2 by following the instructions here: https://speech.fish.audio/install/
```
## Fish Audio S2
**Best text-to-speech system among both open source and closed source**
Fish Audio S2 is the latest model developed by [Fish Audio](https://fish.audio/). Trained on over 10 million hours of audio across approximately 50 languages, S2 combines reinforcement learning alignment with a Dual-Autoregressive architecture to generate speech that sounds natural, realistic, and emotionally rich.
S2 supports fine-grained inline control of prosody and emotion using natural-language tags like `[laugh]`, `[whispers]`, and `[super happy]`, as well as native multi-speaker and multi-turn generation.
Visit the [Fish Audio website](https://fish.audio/) for a live playground. Read the [blog post](https://fish.audio/blog/fish-audio-open-sources-s2/) and [technical report](https://arxiv.org/abs/2603.08823) for more details.
### Model Variants
| Model | Size | Availability | Description |
|------|------|-------------|-------------|
| S2-Pro | 4B parameters | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | Full-featured flagship model with maximum quality and stability |
More details about the model can be found in the [technical report](https://arxiv.org/abs/2603.08823).
## Benchmark Results
| Benchmark | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER (Chinese) | **0.54%** (best overall) |
| Seed-TTS Eval — WER (English) | **0.99%** (best overall) |
| Audio Turing Test (with instruction) | **0.515** posterior mean |
| EmergentTTS-Eval — Win Rate | **81.88%** (highest overall) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — Quality | **4.51 / 5.0** |
| Multilingual (MiniMax Testset) — Best WER | **11 of 24** languages |
| Multilingual (MiniMax Testset) — Best SIM | **17 of 24** languages |
On Seed-TTS Eval, S2 achieves the lowest WER among all evaluated models including closed-source systems: Qwen3-TTS (0.77/1.24), MiniMax Speech-02 (0.99/1.90), Seed-TTS (1.12/2.25). On the Audio Turing Test, 0.515 surpasses Seed-TTS (0.417) by 24% and MiniMax-Speech (0.387) by 33%. On EmergentTTS-Eval, S2 achieves particularly strong results in paralinguistics (91.61% win rate), questions (84.41%), and syntactic complexity (83.39%).
## Highlights
### Fine-Grained Inline Control via Natural Language
S2 enables localized control over speech generation by embedding natural-language instructions directly at specific word or phrase positions within the text. Rather than relying on a fixed set of predefined tags, S2 accepts free-form textual descriptions — such as `[whisper in small voice]`, `[professional broadcast tone]`, or `[pitch up]` — allowing open-ended expression control at the word level.
### Dual-Autoregressive Architecture
S2 builds on a decoder-only transformer combined with an RVQ-based audio codec (10 codebooks, ~21 Hz frame rate). The Dual-AR architecture splits generation into two stages:
- **Slow AR** operates along the time axis and predicts the primary semantic codebook.
- **Fast AR** generates the remaining 9 residual codebooks at each time step, reconstructing fine-grained acoustic detail.
This asymmetric design — 4B parameters along the time axis, 400M parameters along the depth axis — keeps inference efficient while preserving audio fidelity.
### Reinforcement Learning Alignment
S2 uses Group Relative Policy Optimization (GRPO) for post-training alignment. The same models used to filter and annotate training data are directly reused as reward models during RL — eliminating distribution mismatch between pre-training data and post-training objectives. The reward signal combines semantic accuracy, instruction adherence, acoustic preference scoring, and timbre similarity.
### Production Streaming via SGLang
Because the Dual-AR architecture is structurally isomorphic to standard autoregressive LLMs, S2 directly inherits all LLM-native serving optimizations from SGLang — including continuous batching, paged KV cache, CUDA graph replay, and RadixAttention-based prefix caching.
On a single NVIDIA H200 GPU:
- **Real-Time Factor (RTF):** 0.195
- **Time-to-first-audio:** ~100 ms
- **Throughput:** 3,000+ acoustic tokens/s while maintaining RTF below 0.5
### Multilingual Support
S2 supports high-quality multilingual text-to-speech without requiring phonemes or language-specific preprocessing. Supported languages include:
**English, Chinese, Japanese, Korean, Arabic, German, French...**
**AND MORE!**
The list is constantly expanding, check [Fish Audio](https://fish.audio/) for the latest releases.
### Native Multi-Speaker Generation
Fish Audio S2 allows users to upload reference audio containing multiple speakers; the model handles each speaker's characteristics via the `<|speaker:i|>` token. You can then control the model's output with the speaker ID token, allowing a single generation to include multiple speakers. You no longer need to upload reference audio separately for each speaker.
### Multi-Turn Generation
Thanks to the expansion of the model context, our model can now use previous information to improve the expressiveness of subsequent generated content, thereby increasing the naturalness of the content.
### Rapid Voice Cloning
Fish Audio S2 supports accurate voice cloning using a short reference sample (typically 10–30 seconds). The model captures timbre, speaking style, and emotional tendencies, producing realistic and consistent cloned voices without additional fine-tuning.
Please refer to [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md) to use the SGLang server.
---
## Credits
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## Tech Report
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/en/inference.md
================================================
# Inference
The Fish Audio S2 model requires a large amount of VRAM. We recommend using a GPU with at least 24GB for inference.
## Download Weights
First, you need to download the model weights:
```bash
hf download fishaudio/s2-pro --local-dir checkpoints/s2-pro
```
## Command Line Inference
!!! note
If you plan to let the model randomly choose a voice timbre, you can skip this step.
### 1. Get VQ tokens from reference audio
```bash
python fish_speech/models/dac/inference.py \
-i "test.wav" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
You should get a `fake.npy` and a `fake.wav`.
### 2. Generate Semantic tokens from text:
```bash
python fish_speech/models/text2semantic/inference.py \
--text "The text you want to convert" \
--prompt-text "Your reference text" \
--prompt-tokens "fake.npy" \
# --compile
```
This command will create a `codes_N.npy` file in the working directory, where N is an integer starting from 0.
!!! note
You may want to use `--compile` to fuse CUDA kernels for faster inference; however, we recommend using our SGLang-based inference acceleration instead.
The example above leaves `--compile` commented out; uncomment that line if you want to enable it.
!!! info
For GPUs that do not support bf16, you may need to use the `--half` parameter.
### 3. Generate audio from semantic tokens:
```bash
python fish_speech/models/dac/inference.py \
-i "codes_0.npy" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
After that, you will get a `fake.wav` file.
## WebUI Inference
### 1. Gradio WebUI
For compatibility, we still maintain the Gradio WebUI.
```bash
python tools/run_webui.py # --compile if you need acceleration
```
### 2. Awesome WebUI
Awesome WebUI is a modernized Web interface built with TypeScript, offering richer features and a better user experience.
**Build WebUI:**
You need to have Node.js and npm installed on your local machine or server.
1. Enter the `awesome_webui` directory:
```bash
cd awesome_webui
```
2. Install dependencies:
```bash
npm install
```
3. Build the WebUI:
```bash
npm run build
```
**Start Backend Server:**
After building the WebUI, return to the project root and start the API server:
```bash
python tools/api_server.py --listen 0.0.0.0:8888 --compile
```
**Access:**
Once the server is running, you can access it via your browser:
`http://localhost:8888/ui`
================================================
FILE: docs/en/install.md
================================================
## Requirements
- GPU Memory: 24GB (Inference)
- System: Linux, WSL
## System Setup
Fish Audio S2 supports multiple installation methods. Choose the one that best fits your development environment.
**Prerequisites**: Install system dependencies for audio processing:
``` bash
apt install portaudio19-dev libsox-dev ffmpeg
```
### Conda
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# GPU installation (choose your CUDA version: cu126, cu128, cu129)
pip install -e .[cu129]
# CPU-only installation
pip install -e .[cpu]
# Default installation (uses PyTorch default index)
pip install -e .
# If you encounter an error during installation due to pyaudio, consider using the following command:
# conda install pyaudio
# Then run pip install -e . again
```
### UV
UV provides faster dependency resolution and installation:
```bash
# GPU installation (choose your CUDA version: cu126, cu128, cu129)
uv sync --python 3.12 --extra cu129
# CPU-only installation
uv sync --python 3.12 --extra cpu
```
### Intel Arc XPU support
For Intel Arc GPU users, install with XPU support:
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# Install required C++ standard library
conda install libstdcxx -c conda-forge
# Install PyTorch with Intel XPU support
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
# Install Fish Speech
pip install -e .
```
!!! warning
The `compile` option is not supported on Windows and macOS. If you want to run with compile, you need to install Triton manually.
## Docker Setup
Fish Audio S2 series model provides multiple Docker deployment options to suit different needs. You can use pre-built images from Docker Hub, build locally with Docker Compose, or manually build custom images.
We provide Docker images for both WebUI and API server on both GPU (CUDA 12.6 by default) and CPU. You can use the pre-built images from Docker Hub, build locally with Docker Compose, or manually build custom images. If you want to build locally, follow the instructions below. If you only want to use pre-built images, follow the [inference guide](inference.md).
### Prerequisites
- Docker and Docker Compose installed
- NVIDIA Docker runtime (for GPU support)
- At least 24GB GPU memory for CUDA inference
### Use Docker Compose
For development or customization, you can use Docker Compose to build and run locally:
```bash
# Clone the repository first
git clone https://github.com/fishaudio/fish-speech.git
cd fish-speech
# Start WebUI with CUDA
docker compose --profile webui up
# Start WebUI with compile optimization
COMPILE=1 docker compose --profile webui up
# Start API server
docker compose --profile server up
# Start API server with compile optimization
COMPILE=1 docker compose --profile server up
# For CPU-only deployment
BACKEND=cpu docker compose --profile webui up
```
#### Environment Variables for Docker Compose
You can customize the deployment using environment variables:
```bash
# .env file example
BACKEND=cuda # or cpu
COMPILE=1 # Enable compile optimization
GRADIO_PORT=7860 # WebUI port
API_PORT=8080 # API server port
UV_VERSION=0.8.15 # UV package manager version
```
The command will build the image and run the container. You can access the WebUI at `http://localhost:7860` and the API server at `http://localhost:8080`.
### Manual Docker Build
For advanced users who want to customize the build process:
```bash
# Build WebUI image with CUDA support
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target webui \
-t fish-speech-webui:cuda .
# Build API server image with CUDA support
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target server \
-t fish-speech-server:cuda .
# Build CPU-only images (supports multi-platform)
docker build \
--platform linux/amd64,linux/arm64 \
-f docker/Dockerfile \
--build-arg BACKEND=cpu \
--target webui \
-t fish-speech-webui:cpu .
# Build development image
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--target dev \
-t fish-speech-dev:cuda .
```
#### Build Arguments
- `BACKEND`: `cuda` or `cpu` (default: `cuda`)
- `CUDA_VER`: CUDA version (default: `12.6.0`)
- `UV_EXTRA`: UV extra for CUDA (default: `cu126`)
- `UBUNTU_VER`: Ubuntu version (default: `24.04`)
- `PY_VER`: Python version (default: `3.12`)
### Volume Mounts
Both methods require mounting these directories:
- `./checkpoints:/app/checkpoints` - Model weights directory
- `./references:/app/references` - Reference audio files directory
### Environment Variables
- `COMPILE=1` - Enable torch.compile for faster inference (~10x speedup)
- `GRADIO_SERVER_NAME=0.0.0.0` - WebUI server host
- `GRADIO_SERVER_PORT=7860` - WebUI server port
- `API_SERVER_NAME=0.0.0.0` - API server host
- `API_SERVER_PORT=8080` - API server port
!!! note
The Docker containers expect model weights to be mounted at `/app/checkpoints`. Make sure to download the required model weights before starting the containers.
!!! warning
GPU support requires NVIDIA Docker runtime. For CPU-only deployment, remove the `--gpus all` flag and use CPU images.
================================================
FILE: docs/en/server.md
================================================
# Server
This page covers server-side inference for Fish Audio S2, plus quick links for WebUI inference and Docker deployment.
## API Server Inference
Fish Speech provides an HTTP API server entrypoint at `tools/api_server.py`.
### Start the server locally
```bash
python tools/api_server.py \
--llama-checkpoint-path checkpoints/s2-pro \
--decoder-checkpoint-path checkpoints/s2-pro/codec.pth \
--listen 0.0.0.0:8080
```
Common options:
- `--compile`: enable `torch.compile` optimization
- `--half`: use fp16 mode
- `--api-key`: require bearer token authentication
- `--workers`: set worker process count
### Health check
```bash
curl -X GET http://127.0.0.1:8080/v1/health
```
Expected response:
```json
{"status":"ok"}
```
### Main API endpoints
- `POST /v1/tts` for text-to-speech generation
- `POST /v1/vqgan/encode` for VQ encode
- `POST /v1/vqgan/decode` for VQ decode
## WebUI Inference
For WebUI usage, see:
- [WebUI Inference](https://speech.fish.audio/inference/#webui-inference)
## Docker
For Docker-based server or WebUI deployment, see:
- [Docker Setup](https://speech.fish.audio/install/#docker-setup)
You can also start the server profile directly with Docker Compose:
```bash
docker compose --profile server up
```
================================================
FILE: docs/ja/finetune.md
================================================
# ファインチューニング
このページを開いたということは、明らかに、事前学習済みモデルのゼロショット性能に満足していないということでしょう。データセットでより良い性能を発揮するようにモデルをファインチューニングしたいとお考えのはずです。
現在のバージョンでは、「LLAMA」部分のみをファインチューニングする必要があります。
## LLAMA のファインチューニング
### 1. データセットの準備
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 30.1-32.71.lab
│ └── 30.1-32.71.mp3
└── SPK2
├── 38.79-40.85.lab
└── 38.79-40.85.mp3
```
データセットを上記の形式に変換し、`data` ディレクトリに配置する必要があります。音声ファイルの拡張子は `.mp3`、`.wav`、または `.flac` が使用でき、注釈ファイルの拡張子は `.lab` にすることを推奨します。
!!! info
`.lab` 注釈ファイルには、音声の書き起こしテキストのみを含める必要があり、特別なフォーマット要件はありません。たとえば、`hi.mp3` の内容が「こんにちは、さようなら。」である場合、`hi.lab` ファイルには「こんにちは、さようなら。」という一行のテキストのみが含まれます。
!!! warning
データセットにラウドネス正規化を適用することをお勧めします。これには [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) を使用できます。
```bash
fap loudness-norm data-raw data --clean
```
### 2. セマンティックトークンの一括抽出
VQGANの重みをダウンロードしていることを確認してください。まだの場合は、次のコマンドを実行してください。
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
その後、次のコマンドを実行してセマンティックトークンを抽出できます。
```bash
python tools/vqgan/extract_vq.py data \
--num-workers 1 --batch-size 16 \
--config-name "modded_dac_vq" \
--checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
```
!!! note
`--num-workers` と `--batch-size` を調整して抽出速度を向上させることができますが、GPUメモリの制限を超えないように注意してください。
このコマンドは `data` ディレクトリに `.npy` ファイルを作成します。以下のようになります。
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 21.15-26.44.npy
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 27.51-29.98.npy
│ ├── 30.1-32.71.lab
│ ├── 30.1-32.71.mp3
│ └── 30.1-32.71.npy
└── SPK2
├── 38.79-40.85.lab
├── 38.79-40.85.mp3
    └── 38.79-40.85.npy
```
### 3. データセットを protobuf にパックする
```bash
python tools/llama/build_dataset.py \
--input "data" \
--output "data/protos" \
--text-extension .lab \
--num-workers 16
```
コマンドの実行が完了すると、`data` ディレクトリに `protos` ファイルが表示されるはずです。
### 4. 最後に LoRA でファインチューニング
同様に、`LLAMA` の重みをダウンロードしていることを確認してください。まだの場合は、次のコマンドを実行してください。
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
最後に、次のコマンドを実行してファインチューニングを開始できます。
```bash
python fish_speech/train.py --config-name text2semantic_finetune \
project=$project \
+lora@model.model.lora_config=r_8_alpha_16
```
!!! note
`fish_speech/configs/text2semantic_finetune.yaml` を変更することで、`batch_size` や `gradient_accumulation_steps` などのトレーニングパラメータをGPUメモリに合わせて変更できます。
!!! note
Windows ユーザーの場合、`trainer.strategy.process_group_backend=gloo` を使用して `nccl` の問題を回避できます。
トレーニングが完了したら、[推論](inference.md) のセクションを参照してモデルをテストできます。
!!! info
デフォルト設定では、モデルは話者の発音方法のみを学習し、音色は学習しません。音色の安定性を確保するためには、依然としてプロンプトを使用する必要があります。
音色を学習させたい場合は、トレーニングステップ数を増やしてください。ただし、これにより過学習が発生する可能性があります。
トレーニング後、推論を行う前に LoRA の重みを通常の重みに変換する必要があります。
```bash
python tools/llama/merge_lora.py \
--lora-config r_8_alpha_16 \
--base-weight checkpoints/openaudio-s1-mini \
--lora-weight results/$project/checkpoints/step_000000010.ckpt \
--output checkpoints/openaudio-s1-mini-yth-lora/
```
!!! note
他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、OOD(分布外)データに対してより良いパフォーマンスを発揮します。
================================================
FILE: docs/ja/index.md
================================================
!!! info "ライセンス通知"
このコードベースおよび関連するモデルの重みは **FISH AUDIO RESEARCH LICENSE** の下でリリースされています。詳細は [LICENSE](https://github.com/fishaudio/fish-speech/blob/main/LICENSE) を参照してください。
!!! warning "法的免責事項"
私たちは、コードベースのいかなる違法な使用に対しても責任を負いません。DMCA およびその他の関連法に関する現地の規制を参照してください。
## クイックスタート
### まずはドキュメントから
Fish Audio S2 の公式ドキュメントです。以下からすぐに始められます。
- [インストール](https://speech.fish.audio/ja/install/)
- [コマンドライン推論](https://speech.fish.audio/ja/inference/)
- [WebUI 推論](https://speech.fish.audio/ja/inference/)
- [サーバー推論](https://speech.fish.audio/ja/server/)
- [Docker セットアップ](https://speech.fish.audio/ja/install/)
> [!IMPORTANT]
> **SGLang サーバーについては [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md) を参照してください。**
### LLM Agent 向け
```
https://speech.fish.audio/ja/install/ の手順に従って、Fish Audio S2 をインストール・設定してください。
```
## Fish Audio S2
**オープンソースおよびクローズドソースの中で最も優れたテキスト読み上げシステム**
Fish Audio S2 は [Fish Audio](https://fish.audio/) が開発した最新モデルです。約 50 言語・1,000 万時間超の音声データで学習され、強化学習アラインメントと Dual-Autoregressive アーキテクチャを組み合わせることで、自然でリアルかつ感情表現豊かな音声を生成します。
S2 は `[laugh]`、`[whispers]`、`[super happy]` といった自然言語タグで、韻律や感情を文中の任意位置で細かく制御できます。さらに、マルチスピーカー生成とマルチターン生成にもネイティブ対応しています。
ライブデモは [Fish Audio ウェブサイト](https://fish.audio/) から、詳細は [ブログ記事](https://fish.audio/blog/fish-audio-open-sources-s2/) と [技術レポート](https://arxiv.org/abs/2603.08823) をご覧ください。
### モデルバリアント
| モデル | サイズ | 利用可能性 | 説明 |
|------|------|-------------|-------------|
| S2-Pro | 4B パラメータ | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | 品質と安定性を最大化したフル機能のフラッグシップモデル |
モデルの詳細は[技術レポート](https://arxiv.org/abs/2603.08823)をご参照ください。
## ベンチマーク結果
| ベンチマーク | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER(中国語) | **0.54%**(全体最良) |
| Seed-TTS Eval — WER(英語) | **0.99%**(全体最良) |
| Audio Turing Test(指示あり) | **0.515** 事後平均値 |
| EmergentTTS-Eval — 勝率 | **81.88%**(全体最高) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — 品質 | **4.51 / 5.0** |
| 多言語(MiniMax Testset)— 最良 WER | **24 言語中 11 言語** |
| 多言語(MiniMax Testset)— 最良 SIM | **24 言語中 17 言語** |
Seed-TTS Eval では、S2 はクローズドソースを含む全評価モデルの中で最小 WER を達成しました:Qwen3-TTS(0.77/1.24)、MiniMax Speech-02(0.99/1.90)、Seed-TTS(1.12/2.25)。Audio Turing Test では 0.515 を記録し、Seed-TTS(0.417)比で 24%、MiniMax-Speech(0.387)比で 33% 上回りました。EmergentTTS-Eval では、副言語情報(91.61%)、疑問文(84.41%)、統語的複雑性(83.39%)で特に高い成績を示しています。
## ハイライト
### 自然言語による細粒度インライン制御
Fish Audio S2 では、テキスト内の特定の単語やフレーズ位置に自然言語の指示を直接埋め込むことで、音声生成を局所的に制御できます。固定の事前定義タグに依存するのではなく、S2 は [whisper in small voice]、[professional broadcast tone]、[pitch up] のような自由形式のテキスト記述を受け付け、単語レベルで表現をオープンエンドに制御できます。
### 二重自己回帰(Dual-Autoregressive)アーキテクチャ
S2 はデコーダー専用 Transformer と RVQ ベースの音声コーデック(10 codebooks、約 21 Hz)を組み合わせています。Dual-AR は生成を 2 段階に分割します。
- **Slow AR** は時間軸方向に動作し、主となる semantic codebook を予測。
- **Fast AR** は各時刻で残り 9 個の residual codebook を生成し、細かな音響ディテールを復元。
この非対称設計(時間軸 4B パラメータ、深さ軸 400M パラメータ)により、音質を保ちながら推論効率を高めています。
### 強化学習アラインメント
S2 は後学習アラインメントに Group Relative Policy Optimization(GRPO)を採用しています。学習データのフィルタリングとアノテーションに使った同一モデル群を、そのまま RL の報酬モデルとして再利用することで、事前学習データ分布と事後学習目的のミスマッチを抑制しています。報酬信号には、意味的正確性、指示追従性、音響的選好スコア、音色類似度が含まれます。
### SGLang による本番向けストリーミング
Dual-AR は構造的に標準的な自己回帰 LLM と同型のため、S2 は SGLang の LLM 向け最適化をそのまま活用できます。たとえば continuous batching、paged KV cache、CUDA graph replay、RadixAttention ベースの prefix caching です。
単一の NVIDIA H200 GPU での実測:
- **RTF(Real-Time Factor):** 0.195
- **初回音声出力までの時間:** 約 100 ms
- **スループット:** RTF 0.5 未満を維持しつつ 3,000+ acoustic tokens/s
### 多言語サポート
Fish Audio S2 は、音素や言語固有の前処理を必要とせずに、高品質な多言語テキスト読み上げをサポートします。以下を含みます:
**英語、中国語、日本語、韓国語、アラビア語、ドイツ語、フランス語...**
**さらに多く!**
リストは常に拡大しています。最新のリリースについては [Fish Audio](https://fish.audio/) を確認してください。
### ネイティブなマルチスピーカー生成
Fish Audio S2 では、ユーザーが複数のスピーカーを含む参照オーディオをアップロードでき、モデルは `<|speaker:i|>` トークンを介して各スピーカーの特徴を処理します。その後、スピーカーIDトークンを使用してモデルのパフォーマンスを制御し、1回の生成で複数のスピーカーを含めることができます。以前のように各スピーカーに対して個別に参照オーディオをアップロードして音声を生成する必要はもうありません。
### マルチターン対話生成
モデルのコンテキストの拡張により、以前の情報を使用して後続の生成されたコンテンツの表現力を向上させ、コンテンツの自然さを高めることができるようになりました。
### 高速音声クローニング
Fish Audio S2 は、短い参照サンプル(通常10〜30秒)を使用した正確な音声クローニングをサポートしています。モデルは音色、話し方、感情的な傾向を捉え、追加の微調整なしでリアルで一貫したクローン音声を生成します。
SGLang サーバーの利用については [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md) を参照してください。
---
## クレジット
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## 技術レポート
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/ja/inference.md
================================================
# 推論
Fish Audio S2 モデルは大きなビデオメモリを必要とします。推論には少なくとも 24GB の GPU を使用することをお勧めします。
## 重みのダウンロード
まず、モデルの重みをダウンロードする必要があります:
```bash
hf download fishaudio/s2-pro --local-dir checkpoints/s2-pro
```
## コマンドライン推論
!!! note
モデルに音声をランダムに選択させる場合は、このステップをスキップできます。
### 1. リファレンスオーディオから VQ トークンを取得する
```bash
python fish_speech/models/dac/inference.py \
-i "test.wav" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
`fake.npy` と `fake.wav` が生成されるはずです。
### 2. テキストから Semantic トークンを生成する:
```bash
python fish_speech/models/text2semantic/inference.py \
--text "変換したいテキスト" \
--prompt-text "リファレンステキスト" \
--prompt-tokens "fake.npy" \
# --compile
```
このコマンドは、作業ディレクトリに `codes_N.npy` ファイルを作成します。ここで N は 0 から始まる整数です。
!!! note
より高速な推論のために CUDA カーネルを融合する `--compile` を使用したい場合がありますが、私たちの sglang 推論加速最適化を使用することをお勧めします。
上記の例では `--compile` はコメントアウトされています。加速を有効にする場合は、その行のコメントを外してください。
!!! info
bf16 をサポートしていない GPU の場合、`--half` パラメータを使用する必要があるかもしれません。
### 3. セマンティックトークンから音声を生成する:
```bash
python fish_speech/models/dac/inference.py \
-i "codes_0.npy" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
その後、`fake.wav` ファイルが取得できます。
## WebUI 推論
### 1. Gradio WebUI
互換性を維持するため、以前の Gradio WebUI も引き続き利用可能です。
```bash
python tools/run_webui.py # 加速が必要な場合は --compile
```
### 2. Awesome WebUI
Awesome WebUI は TypeScript で開発された、より豊富な機能と優れたユーザー体験を提供する最新の Web インターフェースです。
**WebUI のビルド:**
ローカルまたはサーバーに Node.js と npm がインストールされている必要があります。
1. `awesome_webui` ディレクトリに移動します:
```bash
cd awesome_webui
```
2. 依存関係をインストールします:
```bash
npm install
```
3. WebUI をビルドします:
```bash
npm run build
```
**バックエンドサーバーの起動:**
WebUI のビルドが完了したら、プロジェクトのルートに戻り、API サーバーを起動します:
```bash
python tools/api_server.py --listen 0.0.0.0:8888 --compile
```
**アクセス:**
サーバーが起動したら、ブラウザから以下のアドレスにアクセスして体験できます:
`http://localhost:8888/ui`
================================================
FILE: docs/ja/install.md
================================================
## 必要条件
- GPUメモリ: 24GB (推論時)
- システム: Linux, WSL
## システムセットアップ
Fish Audio S2は複数のインストール方法をサポートしています。ご自身の開発環境に最も適した方法をお選びください。
**前提条件**: 音声処理のためのシステム依存関係をインストールします:
``` bash
apt install portaudio19-dev libsox-dev ffmpeg
```
### Conda
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# GPU版のインストール (CUDAバージョンを選択: cu126, cu128, cu129)
pip install -e .[cu129]
# CPU版のみのインストール
pip install -e .[cpu]
# デフォルトインストール (PyTorchのデフォルトインデックスを使用)
pip install -e .
# pyaudioのインストールでエラーが発生する場合は、以下のコマンドを試してください:
# conda install pyaudio
# その後、再度 pip install -e . を実行してください
```
### UV
UVはより高速な依存関係の解決とインストールを実現します:
```bash
# GPU版のインストール (CUDAバージョンを選択: cu126, cu128, cu129)
uv sync --python 3.12 --extra cu129
# CPU版のみのインストール
uv sync --python 3.12 --extra cpu
```
### Intel Arc XPU サポート
Intel Arc GPUユーザーは、以下の手順でXPUサポートをインストールしてください:
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# 必要なC++標準ライブラリをインストール
conda install libstdcxx -c conda-forge
# Intel XPU対応のPyTorchをインストール
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
# Fish Speechのインストール
pip install -e .
```
!!! warning
`compile`オプションはWindowsとmacOSではサポートされていません。コンパイルを有効にして実行したい場合は、ご自身でTritonをインストールする必要があります。
## Dockerセットアップ
Fish Audio S2シリーズモデルは、さまざまなニーズに応えるため複数のDockerデプロイメントオプションを提供しています。Docker Hubのビルド済みイメージを使用するか、Docker Composeでローカルビルドするか、手動でカスタムイメージをビルドすることができます。
WebUIとAPIサーバーの両方について、GPU(デフォルトはCUDA 12.6)版とCPU版のDockerイメージを提供しています。Docker Hubのビルド済みイメージを使用するか、Docker Composeでローカルビルドするか、手動でカスタムイメージをビルドするかを選択できます。ローカルでビルドする場合は、以下の手順に従ってください。ビルド済みイメージを使用するだけの場合は、[推論ガイド](inference.md)を直接参照してください。
### 前提条件
- DockerとDocker Composeがインストール済みであること
- NVIDIA Dockerランタイムがインストール済みであること(GPUサポート用)
- CUDAによる推論のために、少なくとも24GBのGPUメモリがあること
### Docker Composeの使用
開発やカスタマイズのために、Docker Composeを使用してローカルでビルド・実行できます:
```bash
# まず、リポジトリをクローンします
git clone https://github.com/fishaudio/fish-speech.git
cd fish-speech
# CUDAでWebUIを起動
docker compose --profile webui up
# コンパイル最適化を有効にしてWebUIを起動
COMPILE=1 docker compose --profile webui up
# APIサーバーを起動
docker compose --profile server up
# コンパイル最適化を有効にしてAPIサーバーを起動
COMPILE=1 docker compose --profile server up
# CPUのみでのデプロイ
BACKEND=cpu docker compose --profile webui up
```
#### Docker Compose 環境変数
環境変数を使用してデプロイメントをカスタマイズできます:
```bash
# .env ファイルの例
BACKEND=cuda # または cpu
COMPILE=1 # コンパイル最適化を有効化
GRADIO_PORT=7860 # WebUIのポート
API_PORT=8080 # APIサーバーのポート
UV_VERSION=0.8.15 # UVパッケージマネージャーのバージョン
```
このコマンドはイメージをビルドし、コンテナを実行します。WebUIには`http://localhost:7860`で、APIサーバーには`http://localhost:8080`でアクセスできます。
### 手動でのDockerビルド
ビルドプロセスをカスタマイズしたい上級者向け:
```bash
# CUDAサポート付きのWebUIイメージをビルド
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target webui \
-t fish-speech-webui:cuda .
# CUDAサポート付きのAPIサーバーイメージをビルド
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target server \
-t fish-speech-server:cuda .
# CPUのみのイメージをビルド(マルチプラットフォーム対応)
docker build \
--platform linux/amd64,linux/arm64 \
-f docker/Dockerfile \
--build-arg BACKEND=cpu \
--target webui \
-t fish-speech-webui:cpu .
# 開発用イメージをビルド
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--target dev \
-t fish-speech-dev:cuda .
```
#### ビルド引数
- `BACKEND`: `cuda` または `cpu` (デフォルト: `cuda`)
- `CUDA_VER`: CUDAバージョン (デフォルト: `12.6.0`)
- `UV_EXTRA`: CUDA用のUV追加パッケージ (デフォルト: `cu126`)
- `UBUNTU_VER`: Ubuntuバージョン (デフォルト: `24.04`)
- `PY_VER`: Pythonバージョン (デフォルト: `3.12`)
### ボリュームマウント
どちらの方法でも、以下のディレクトリをマウントする必要があります:
- `./checkpoints:/app/checkpoints` - モデルの重みファイル用ディレクトリ
- `./references:/app/references` - 参照音声ファイル用ディレクトリ
### 環境変数
- `COMPILE=1` - `torch.compile`を有効にして推論を高速化(約10倍)
- `GRADIO_SERVER_NAME=0.0.0.0` - WebUIサーバーのホスト
- `GRADIO_SERVER_PORT=7860` - WebUIサーバーのポート
- `API_SERVER_NAME=0.0.0.0` - APIサーバーのホスト
- `API_SERVER_PORT=8080` - APIサーバーのポート
!!! note
Dockerコンテナは、モデルの重みが`/app/checkpoints`にマウントされることを想定しています。コンテナを起動する前に、必要なモデルの重みをダウンロードしてください。
!!! warning
GPUサポートにはNVIDIA Dockerランタイムが必要です。CPUのみでデプロイする場合は、`--gpus all`フラグを削除し、CPU用のイメージを使用してください。
================================================
FILE: docs/ko/finetune.md
================================================
# 미세 조정 (Fine-tuning)
이 페이지를 열었다는 것은, 사전 훈련된 모델의 제로샷(zero-shot) 성능에 만족하지 못했다는 의미일 것입니다. 여러분의 데이터셋에서 더 나은 성능을 내도록 모델을 미세 조정하고 싶으실 겁니다.
현재 버전에서는 'LLAMA' 부분만 미세 조정하면 됩니다.
## LLAMA 미세 조정
### 1. 데이터셋 준비
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 30.1-32.71.lab
│ └── 30.1-32.71.mp3
└── SPK2
├── 38.79-40.85.lab
└── 38.79-40.85.mp3
```
데이터셋을 위 형식으로 변환하여 `data` 폴더 아래에 배치해야 합니다. 오디오 파일 확장자는 `.mp3`, `.wav` 또는 `.flac`일 수 있으며, 주석 파일 확장자는 `.lab`을 권장합니다.
!!! info
`.lab` 주석 파일에는 오디오의 전사 텍스트만 포함하면 되며, 특별한 형식 요구사항은 없습니다. 예를 들어 `hi.mp3`의 내용이 "안녕하세요, 안녕히 가세요."라면, `hi.lab` 파일에는 "안녕하세요, 안녕히 가세요."라는 한 줄의 텍스트만 포함하면 됩니다.
!!! warning
데이터셋에 음량 정규화를 적용하는 것이 좋습니다. 이를 위해 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess)를 사용할 수 있습니다.
```bash
fap loudness-norm data-raw data --clean
```
### 2. 시맨틱 토큰 일괄 추출
VQGAN 가중치를 다운로드했는지 확인하세요. 그렇지 않은 경우 다음 명령을 실행하세요.
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
그런 다음 다음 명령을 실행하여 시맨틱 토큰을 추출할 수 있습니다.
```bash
python tools/vqgan/extract_vq.py data \
--num-workers 1 --batch-size 16 \
--config-name "modded_dac_vq" \
--checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
```
!!! note
`--num-workers`와 `--batch-size`를 조정하여 추출 속도를 높일 수 있지만, GPU 메모리 한도를 초과하지 않도록 주의하세요.
이 명령은 `data` 디렉토리에 `.npy` 파일을 생성합니다. 결과는 다음과 같습니다.
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 21.15-26.44.npy
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 27.51-29.98.npy
│ ├── 30.1-32.71.lab
│ ├── 30.1-32.71.mp3
│ └── 30.1-32.71.npy
└── SPK2
├── 38.79-40.85.lab
├── 38.79-40.85.mp3
└── 38.79-40.85.npy
```
### 3. 데이터셋을 protobuf로 패킹하기
```bash
python tools/llama/build_dataset.py \
--input "data" \
--output "data/protos" \
--text-extension .lab \
--num-workers 16
```
명령 실행이 완료되면 `data` 디렉토리에서 `protos` 파일을 볼 수 있어야 합니다.
### 4. 마지막으로, LoRA로 미세 조정하기
마찬가지로, `LLAMA` 가중치를 다운로드했는지 확인하세요. 그렇지 않은 경우 다음 명령을 실행하세요.
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
마지막으로, 다음 명령을 실행하여 미세 조정을 시작할 수 있습니다.
```bash
python fish_speech/train.py --config-name text2semantic_finetune \
project=$project \
+lora@model.model.lora_config=r_8_alpha_16
```
!!! note
`fish_speech/configs/text2semantic_finetune.yaml` 파일을 수정하여 `batch_size`, `gradient_accumulation_steps` 등 훈련 매개변수를 GPU 메모리에 맞게 조정할 수 있습니다.
!!! note
Windows 사용자의 경우, `trainer.strategy.process_group_backend=gloo`를 사용하여 `nccl` 관련 문제를 피할 수 있습니다.
훈련이 완료되면 [추론](inference.md) 섹션을 참조하여 모델을 테스트할 수 있습니다.
!!! info
기본 설정에서는 모델이 화자의 발음 방식만 학습하고 음색은 학습하지 않습니다. 음색 안정성을 보장하려면 여전히 프롬프트를 사용해야 합니다.
음색을 학습시키고 싶다면 훈련 스텝 수를 늘리되, 이는 과적합(overfitting)으로 이어질 수 있습니다.
훈련 후, 추론을 수행하기 전에 LoRA 가중치를 일반 가중치로 변환해야 합니다.
```bash
python tools/llama/merge_lora.py \
--lora-config r_8_alpha_16 \
--base-weight checkpoints/openaudio-s1-mini \
--lora-weight results/$project/checkpoints/step_000000010.ckpt \
--output checkpoints/openaudio-s1-mini-yth-lora/
```
!!! note
다른 체크포인트를 시도해 볼 수도 있습니다. 요구 사항을 충족하는 가장 이른 체크포인트를 사용하는 것이 좋습니다. 이러한 체크포인트는 보통 OOD(분포 외) 데이터에서 더 나은 성능을 보입니다.
================================================
FILE: docs/ko/index.md
================================================
!!! info "라이선스 공지"
이 코드베이스 및 관련 모델 가중치는 **FISH AUDIO RESEARCH LICENSE** 하에 릴리스되었습니다. 자세한 내용은 [LICENSE](https://github.com/fishaudio/fish-speech/blob/main/LICENSE)를 참조하십시오.
!!! warning "법적 면책 조항"
코드베이스의 불법적인 사용에 대해 당사는 어떠한 책임도 지지 않습니다. DMCA 및 기타 관련 법률에 관한 현지 규정을 참조하십시오.
## 빠른 시작
### 문서로 바로 시작하기
Fish Audio S2 공식 문서입니다. 아래 링크에서 바로 시작할 수 있습니다.
- [설치](https://speech.fish.audio/ko/install/)
- [커맨드라인 추론](https://speech.fish.audio/ko/inference/)
- [WebUI 추론](https://speech.fish.audio/ko/inference/)
- [서버 추론](https://speech.fish.audio/ko/server/)
- [Docker 설정](https://speech.fish.audio/ko/install/)
> [!IMPORTANT]
> **SGLang 서버는 [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md)를 참고하세요.**
### LLM Agent 가이드
```
https://speech.fish.audio/ko/install/ 문서를 따라 Fish Audio S2를 설치하고 구성하세요.
```
## Fish Audio S2
**오픈 소스와 클로즈드 소스 모두에서 가장 뛰어난 텍스트 음성 변환 시스템**
Fish Audio S2는 [Fish Audio](https://fish.audio/)가 개발한 최신 모델입니다. 약 50개 언어, 1,000만 시간 이상의 오디오 데이터로 학습되었고, 강화학습 정렬과 Dual-Autoregressive 아키텍처를 결합해 자연스럽고 사실적이며 감정 표현이 풍부한 음성을 생성합니다.
S2는 `[laugh]`, `[whispers]`, `[super happy]` 같은 자연어 태그를 사용해 운율과 감정을 문장 내부에서 세밀하게 제어할 수 있으며, 멀티 화자/멀티 턴 생성도 네이티브로 지원합니다.
실시간 데모는 [Fish Audio 웹사이트](https://fish.audio/)에서, 자세한 내용은 [블로그 글](https://fish.audio/blog/fish-audio-open-sources-s2/)과 [기술 보고서](https://arxiv.org/abs/2603.08823)에서 확인할 수 있습니다.
### 모델 변형
| 모델 | 크기 | 가용성 | 설명 |
|------|------|-------------|-------------|
| S2-Pro | 4B 매개변수 | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | 최고 수준의 품질과 안정성을 제공하는 풀기능 플래그십 모델 |
모델 상세는 [기술 보고서](https://arxiv.org/abs/2603.08823)를 참고하세요.
## 벤치마크 결과
| 벤치마크 | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER (중국어) | **0.54%** (전체 최고) |
| Seed-TTS Eval — WER (영어) | **0.99%** (전체 최고) |
| Audio Turing Test (지시 포함) | **0.515** 사후 평균 |
| EmergentTTS-Eval — 승률 | **81.88%** (전체 최고) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — 품질 | **4.51 / 5.0** |
| 다국어 (MiniMax Testset) — 최고 WER | **24개 언어 중 11개** |
| 다국어 (MiniMax Testset) — 최고 SIM | **24개 언어 중 17개** |
Seed-TTS Eval에서 S2는 클로즈드 소스 시스템을 포함한 전체 비교 모델 중 가장 낮은 WER를 기록했습니다: Qwen3-TTS (0.77/1.24), MiniMax Speech-02 (0.99/1.90), Seed-TTS (1.12/2.25). Audio Turing Test에서는 0.515를 기록해 Seed-TTS (0.417) 대비 24%, MiniMax-Speech (0.387) 대비 33% 높았습니다. EmergentTTS-Eval에서는 파라언어 표현(91.61%), 의문문(84.41%), 구문 복잡도(83.39%)에서 특히 강한 성능을 보였습니다.
## 주요 특징
### 자연어 기반 세밀한 인라인 제어
Fish Audio S2는 텍스트의 특정 단어 또는 구문 위치에 자연어 지시를 직접 삽입해 음성 생성을 국소적으로 제어할 수 있습니다. 고정된 사전 정의 태그에 의존하는 대신, S2는 [whisper in small voice], [professional broadcast tone], [pitch up] 같은 자유 형식 텍스트 설명을 받아 단어 수준의 개방형 표현 제어를 지원합니다.
### Dual-Autoregressive 아키텍처
S2는 decoder-only Transformer와 RVQ 기반 오디오 코덱(10 codebooks, 약 21 Hz 프레임레이트)을 결합합니다. Dual-AR은 생성 과정을 두 단계로 나눕니다.
- **Slow AR**: 시간축을 따라 동작하며 주 semantic codebook을 예측
- **Fast AR**: 각 시점에서 나머지 9개 residual codebook을 생성해 세밀한 음향 디테일을 복원
이 비대칭 설계(시간축 4B 파라미터, 깊이축 400M 파라미터)는 음질을 유지하면서 추론 효율을 높입니다.
### 강화학습 정렬
S2는 후학습 정렬을 위해 Group Relative Policy Optimization(GRPO)을 사용합니다. 학습 데이터 필터링/라벨링에 쓰인 동일한 모델을 RL 보상 모델로 재사용해, 사전학습 데이터 분포와 후학습 목표 간의 분포 불일치를 줄였습니다. 보상 신호는 의미 정확도, 지시 준수도, 음향 선호 점수, 음색 유사도를 함께 반영합니다.
### SGLang 기반 프로덕션 스트리밍
Dual-AR 구조는 표준 자기회귀 LLM과 구조적으로 동형이기 때문에, S2는 SGLang의 LLM 서빙 최적화를 그대로 활용합니다. 예: continuous batching, paged KV cache, CUDA graph replay, RadixAttention 기반 prefix caching.
NVIDIA H200 단일 GPU 기준:
- **실시간 계수(RTF):** 0.195
- **첫 오디오 출력까지 시간:** 약 100 ms
- **처리량:** RTF 0.5 미만 유지 시 3,000+ acoustic tokens/s
### 다국어 지원
Fish Audio S2는 음소나 언어별 전처리 없이 고품질 다국어 텍스트 음성 변환을 지원합니다. 포함 사항:
**영어, 중국어, 일본어, 한국어, 아랍어, 독일어, 프랑스어...**
**그리고 더 많이!**
목록은 계속 확장되고 있습니다. 최신 릴리스는 [Fish Audio](https://fish.audio/)를 확인하세요.
### 네이티브 멀티 화자 생성
Fish Audio S2는 사용자가 여러 화자가 포함된 참조 오디오를 업로드할 수 있도록 하며, 모델은 `<|speaker:i|>` 토큰을 통해 각 화자의 특징을 처리합니다. 그런 다음 화자 ID 토큰으로 모델의 성능을 제어하여 한 번의 생성으로 여러 화자를 포함할 수 있습니다. 이전처럼 각 화자마다 별도로 참조 오디오를 업로드하고 음성을 생성할 필요가 없습니다.
### 멀티 턴 대화 생성
모델 컨텍스트의 확장 덕분에 이제 이전 정보를 활용하여 후속 생성 콘텐츠의 표현력을 높이고 콘텐츠의 자연스러움을 향상시킬 수 있습니다.
### 빠른 음성 복제
Fish Audio S2는 짧은 참조 샘플(일반적으로 10-30초)을 사용하여 정확한 음성 복제를 지원합니다. 모델은 음색, 말하기 스타일 및 감정적 경향을 캡처하여 추가 미세 조정 없이 사실적이고 일관된 복제 음성을 생성합니다.
SGLang 서버 사용은 [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md)를 참고하세요.
---
## 크레딧
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## 기술 보고서
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/ko/inference.md
================================================
# 추론
Fish Audio S2 모델은 큰 비디오 메모리(VRAM)가 필요합니다. 추론을 위해 최소 24GB 이상의 GPU를 사용하는 것을 권장합니다.
## 가중치 다운로드
먼저 모델 가중치를 다운로드해야 합니다:
```bash
hf download fishaudio/s2-pro --local-dir checkpoints/s2-pro
```
## 명령줄 추론
!!! note
모델이 음색을 무작위로 선택하게 하려면 이 단계를 건너뛸 수 있습니다.
### 1. 참조 오디오에서 VQ 토큰 가져오기
```bash
python fish_speech/models/dac/inference.py \
-i "test.wav" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
`fake.npy`와 `fake.wav` 파일이 생성됩니다.
### 2. 텍스트에서 Semantic 토큰 생성:
```bash
python fish_speech/models/text2semantic/inference.py \
--text "변환하려는 텍스트" \
--prompt-text "참조 텍스트" \
--prompt-tokens "fake.npy" \
# --compile
```
이 명령은 작업 디렉토리에 `codes_N` 파일을 생성합니다. 여기서 N은 0부터 시작하는 정수입니다.
!!! note
더 빠른 추론을 위해 CUDA 커널을 병합하는 `--compile`을 사용하고 싶을 수 있지만, 당사의 sglang 추론 가속 최적화를 사용하는 것을 더 권장합니다.
위 예시에서는 `--compile`이 주석 처리되어 있습니다. 가속을 사용하려면 주석을 해제하고, 사용하지 않으려면 그대로 두면 됩니다.
!!! info
bf16을 지원하지 않는 GPU의 경우 `--half` 매개변수를 사용해야 할 수 있습니다.
### 3. 시맨틱 토큰에서 음성 생성:
```bash
python fish_speech/models/dac/inference.py \
-i "codes_0.npy" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
이후 `fake.wav` 파일을 얻게 됩니다.
## WebUI 추론
### 1. Gradio WebUI
호환성을 유지하기 위해 기존의 Gradio WebUI를 보존하고 있습니다.
```bash
python tools/run_webui.py # 가속이 필요한 경우 --compile
```
### 2. Awesome WebUI
Awesome WebUI는 TypeScript 기반으로 개발된 현대적인 웹 인터페이스로, 더 풍부한 기능과 향상된 사용자 경험을 제공합니다.
**WebUI 빌드:**
로컬 또는 서버에 Node.js와 npm이 설치되어 있어야 합니다.
1. `awesome_webui` 디렉토리로 이동합니다:
```bash
cd awesome_webui
```
2. 의존성 설치:
```bash
npm install
```
3. WebUI 빌드:
```bash
npm run build
```
**백엔드 서버 실행:**
WebUI 빌드가 완료되면 프로젝트 루트로 돌아가 API 서버를 실행합니다:
```bash
python tools/api_server.py --listen 0.0.0.0:8888 --compile
```
**접속:**
서버가 실행된 후 브라우저를 통해 다음 주소로 접속하면 체험할 수 있습니다:
`http://localhost:8888/ui`
================================================
FILE: docs/ko/install.md
================================================
## 요구 사양
- GPU 메모리: 24GB (추론 시)
- 시스템: Linux, WSL
## 시스템 설정
Fish Audio S2는 다양한 설치 방법을 지원합니다. 자신의 개발 환경에 가장 적합한 방법을 선택하세요.
**사전 요구사항**: 오디오 처리를 위한 시스템 의존성을 설치합니다:
``` bash
apt install portaudio19-dev libsox-dev ffmpeg
```
### Conda
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# GPU 버전 설치 (CUDA 버전 선택: cu126, cu128, cu129)
pip install -e .[cu129]
# CPU 버전만 설치
pip install -e .[cpu]
# 기본 설치 (PyTorch 기본 인덱스 사용)
pip install -e .
# pyaudio 설치 중 오류가 발생하면 다음 명령을 사용해 보세요:
# conda install pyaudio
# 그런 다음 pip install -e . 를 다시 실행하세요
```
### UV
UV는 더 빠른 의존성 해결 및 설치를 제공합니다:
```bash
# GPU 버전 설치 (CUDA 버전 선택: cu126, cu128, cu129)
uv sync --python 3.12 --extra cu129
# CPU 버전만 설치
uv sync --python 3.12 --extra cpu
```
### Intel Arc XPU 지원
Intel Arc GPU 사용자는 다음을 통해 XPU 지원을 설치하세요:
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# 필요한 C++ 표준 라이브러리 설치
conda install libstdcxx -c conda-forge
# Intel XPU를 지원하는 PyTorch 설치
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
# Fish Speech 설치
pip install -e .
```
!!! warning
`compile` 옵션은 Windows와 macOS에서 지원되지 않습니다. 컴파일을 활성화하여 실행하려면 Triton을 직접 설치해야 합니다.
## Docker 설정
Fish Audio S2 시리즈 모델은 다양한 요구에 부응하기 위해 여러 Docker 배포 옵션을 제공합니다. Docker Hub의 사전 빌드된 이미지를 사용하거나, Docker Compose로 로컬에서 빌드하거나, 수동으로 사용자 정의 이미지를 빌드할 수 있습니다.
WebUI와 API 서버 모두에 대해 GPU(기본값 CUDA 12.6) 및 CPU 버전의 Docker 이미지를 제공합니다. 로컬에서 빌드하려면 아래 지침을 따르세요. 사전 빌드된 이미지를 사용하려면 [추론 가이드](inference.md)를 직접 참조하세요.
### 사전 요구사항
- Docker 및 Docker Compose 설치
- NVIDIA Docker 런타임 설치 (GPU 지원용)
- CUDA 추론을 위한 최소 24GB의 GPU 메모리
### Docker Compose 사용
개발 또는 사용자 정의를 위해 Docker Compose를 사용하여 로컬에서 빌드하고 실행할 수 있습니다:
```bash
# 먼저 리포지토리를 클론합니다
git clone https://github.com/fishaudio/fish-speech.git
cd fish-speech
# CUDA로 WebUI 시작
docker compose --profile webui up
# 컴파일 최적화로 WebUI 시작
COMPILE=1 docker compose --profile webui up
# API 서버 시작
docker compose --profile server up
# 컴파일 최적화로 API 서버 시작
COMPILE=1 docker compose --profile server up
# CPU 전용 배포
BACKEND=cpu docker compose --profile webui up
```
#### Docker Compose 환경 변수
환경 변수를 사용하여 배포를 사용자 정의할 수 있습니다:
```bash
# .env 파일 예시
BACKEND=cuda # 또는 cpu
COMPILE=1 # 컴파일 최적화 활성화
GRADIO_PORT=7860 # WebUI 포트
API_PORT=8080 # API 서버 포트
UV_VERSION=0.8.15 # UV 패키지 관리자 버전
```
이 명령은 이미지를 빌드하고 컨테이너를 실행합니다. WebUI는 `http://localhost:7860`에서, API 서버는 `http://localhost:8080`에서 접근할 수 있습니다.
### 수동 Docker 빌드
빌드 프로세스를 사용자 정의하려는 고급 사용자를 위해:
```bash
# CUDA를 지원하는 WebUI 이미지 빌드
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target webui \
-t fish-speech-webui:cuda .
# CUDA를 지원하는 API 서버 이미지 빌드
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target server \
-t fish-speech-server:cuda .
# CPU 전용 이미지 빌드 (멀티 플랫폼 지원)
docker build \
--platform linux/amd64,linux/arm64 \
-f docker/Dockerfile \
--build-arg BACKEND=cpu \
--target webui \
-t fish-speech-webui:cpu .
# 개발용 이미지 빌드
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--target dev \
-t fish-speech-dev:cuda .
```
#### 빌드 인자
- `BACKEND`: `cuda` 또는 `cpu` (기본값: `cuda`)
- `CUDA_VER`: CUDA 버전 (기본값: `12.6.0`)
- `UV_EXTRA`: CUDA용 UV 추가 패키지 (기본값: `cu126`)
- `UBUNTU_VER`: Ubuntu 버전 (기본값: `24.04`)
- `PY_VER`: Python 버전 (기본값: `3.12`)
### 볼륨 마운트
두 방법 모두 다음 디렉토리를 마운트해야 합니다:
- `./checkpoints:/app/checkpoints` - 모델 가중치 디렉토리
- `./references:/app/references` - 참조 오디오 파일 디렉토리
### 환경 변수
- `COMPILE=1` - `torch.compile`을 활성화하여 추론 속도 향상 (약 10배)
- `GRADIO_SERVER_NAME=0.0.0.0` - WebUI 서버 호스트
- `GRADIO_SERVER_PORT=7860` - WebUI 서버 포트
- `API_SERVER_NAME=0.0.0.0` - API 서버 호스트
- `API_SERVER_PORT=8080` - API 서버 포트
!!! note
Docker 컨테이너는 모델 가중치가 `/app/checkpoints`에 마운트될 것으로 예상합니다. 컨테이너를 시작하기 전에 필요한 모델 가중치를 다운로드했는지 확인하세요.
!!! warning
GPU 지원에는 NVIDIA Docker 런타임이 필요합니다. CPU 전용 배포의 경우 `--gpus all` 플래그를 제거하고 CPU 이미지를 사용하세요.
================================================
FILE: docs/pt/finetune.md
================================================
# Ajuste Fino (Fine-tuning)
Se você abriu esta página, provavelmente não está satisfeito com o desempenho do modelo pré-treinado em modo zero-shot e deseja fazer o ajuste fino de um modelo para melhorar o desempenho no seu conjunto de dados.
Na versão atual, você só precisa fazer o ajuste fino da parte 'LLAMA'.
## Ajuste Fino do LLAMA
### 1. Prepare o conjunto de dados
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 30.1-32.71.lab
│ └── 30.1-32.71.mp3
└── SPK2
├── 38.79-40.85.lab
└── 38.79-40.85.mp3
```
Você precisa converter seu conjunto de dados para o formato acima e colocá-lo no diretório `data`. O arquivo de áudio pode ter as extensões `.mp3`, `.wav` ou `.flac`, e o arquivo de anotação deve ter a extensão `.lab`.
!!! info
O arquivo de anotação `.lab` precisa conter apenas a transcrição do áudio, sem necessidade de formatação especial. Por exemplo, se `hi.mp3` contiver "Olá, adeus.", então o arquivo `hi.lab` conterá uma única linha de texto: "Olá, adeus.".
!!! warning
Recomenda-se aplicar a normalização de volume (loudness) ao conjunto de dados. Você pode usar o [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) para fazer isso.
```bash
fap loudness-norm data-raw data --clean
```
### 2. Extração em lote de tokens semânticos
Certifique-se de que você baixou os pesos do VQGAN. Se não, execute o seguinte comando:
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
Em seguida, você pode executar o seguinte comando para extrair os tokens semânticos:
```bash
python tools/vqgan/extract_vq.py data \
--num-workers 1 --batch-size 16 \
--config-name "modded_dac_vq" \
--checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
```
!!! note
Você pode ajustar `--num-workers` e `--batch-size` para aumentar a velocidade de extração, mas certifique-se de não exceder o limite de memória da sua GPU.
Este comando criará arquivos `.npy` no diretório `data`, como mostrado abaixo:
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 21.15-26.44.npy
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 27.51-29.98.npy
│ ├── 30.1-32.71.lab
│ ├── 30.1-32.71.mp3
│ └── 30.1-32.71.npy
└── SPK2
├── 38.79-40.85.lab
├── 38.79-40.85.mp3
└── 38.79-40.85.npy
```
### 3. Empacote o conjunto de dados em protobuf
```bash
python tools/llama/build_dataset.py \
--input "data" \
--output "data/protos" \
--text-extension .lab \
--num-workers 16
```
Após a conclusão da execução do comando, você deverá ver o arquivo `protos` no diretório `data`.
### 4. Finalmente, ajuste fino com LoRA
Da mesma forma, certifique-se de que você baixou os pesos do `LLAMA`. Se não, execute o seguinte comando:
```bash
huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
```
Finalmente, você pode iniciar o ajuste fino executando o seguinte comando:
```bash
python fish_speech/train.py --config-name text2semantic_finetune \
project=$project \
+lora@model.model.lora_config=r_8_alpha_16
```
!!! note
Você pode modificar os parâmetros de treinamento, como `batch_size`, `gradient_accumulation_steps`, etc., para se adequar à memória da sua GPU, modificando `fish_speech/configs/text2semantic_finetune.yaml`.
!!! note
Para usuários do Windows, você pode usar `trainer.strategy.process_group_backend=gloo` para evitar problemas com `nccl`.
Após o treinamento ser concluído, você pode consultar a seção de [inferência](inference.md) para testar seu modelo.
!!! info
Por padrão, o modelo aprenderá apenas os padrões de fala do locutor e não o timbre. Você ainda precisará usar prompts para garantir a estabilidade do timbre.
Se você quiser aprender o timbre, pode aumentar o número de passos de treinamento, mas isso pode levar a um sobreajuste (overfitting).
Após o treinamento, você precisa converter os pesos do LoRA para pesos regulares antes de realizar a inferência.
```bash
python tools/llama/merge_lora.py \
--lora-config r_8_alpha_16 \
--base-weight checkpoints/openaudio-s1-mini \
--lora-weight results/$project/checkpoints/step_000000010.ckpt \
--output checkpoints/openaudio-s1-mini-yth-lora/
```
!!! note
Você também pode tentar outros checkpoints. Sugerimos usar o checkpoint mais antigo que atenda aos seus requisitos, pois eles geralmente têm um desempenho melhor em dados fora de distribuição (OOD).
================================================
FILE: docs/pt/index.md
================================================
!!! info "Aviso de Licença"
Este repositório e todos os pesos de modelo associados são lançados sob a **FISH AUDIO RESEARCH LICENSE**. Consulte [LICENSE](https://github.com/fishaudio/fish-speech/blob/main/LICENSE) para mais detalhes.
!!! warning "Isenção de Responsabilidade Legal"
Não nos responsabilizamos por qualquer uso ilegal da base de códigos. Consulte as regulamentações locais sobre DMCA e outras leis relacionadas.
## Início Rápido
### Comece pela documentação
Esta é a documentação oficial do Fish Audio S2. Você pode começar por aqui:
- [Instalação](https://speech.fish.audio/pt/install/)
- [Inferência por Linha de Comando](https://speech.fish.audio/pt/inference/)
- [Inferência WebUI](https://speech.fish.audio/pt/inference/)
- [Inferência via Servidor](https://speech.fish.audio/pt/server/)
- [Configuração Docker](https://speech.fish.audio/pt/install/)
> [!IMPORTANT]
> **Para servidor com SGLang, consulte o [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md).**
### Guia para agentes LLM
```
Instale e configure o Fish Audio S2 seguindo as instruções em https://speech.fish.audio/pt/install/ .
```
## Fish Audio S2
**O melhor sistema de conversão de texto em fala entre código aberto e código fechado**
O Fish Audio S2 é o modelo mais recente da [Fish Audio](https://fish.audio/). Treinado com mais de 10 milhões de horas de áudio em cerca de 50 idiomas, o S2 combina alinhamento por reforço com uma arquitetura Dual-Autoregressive para gerar fala natural, realista e emocionalmente expressiva.
O S2 permite controle fino de prosódia e emoção dentro da própria frase com tags em linguagem natural, como `[laugh]`, `[whispers]` e `[super happy]`, além de oferecer suporte nativo a múltiplos falantes e múltiplos turnos.
Visite o [site da Fish Audio](https://fish.audio/) para demonstrações ao vivo. Leia a [postagem no blog](https://fish.audio/blog/fish-audio-open-sources-s2/) e o [relatório técnico](https://arxiv.org/abs/2603.08823) para mais detalhes.
### Variantes do Modelo
| Modelo | Tamanho | Disponibilidade | Descrição |
|------|------|-------------|-------------|
| S2-Pro | 4B parâmetros | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | Modelo carro-chefe completo com máxima qualidade e estabilidade |
Mais detalhes podem ser encontrados no [relatório técnico](https://arxiv.org/abs/2603.08823).
## Resultados de Benchmark
| Benchmark | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER (Chinês) | **0.54%** (melhor geral) |
| Seed-TTS Eval — WER (Inglês) | **0.99%** (melhor geral) |
| Audio Turing Test (com instrução) | **0.515** média a posteriori |
| EmergentTTS-Eval — Taxa de vitória | **81.88%** (maior geral) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — Qualidade | **4.51 / 5.0** |
| Multilíngue (MiniMax Testset) — Melhor WER | **11 de 24** idiomas |
| Multilíngue (MiniMax Testset) — Melhor SIM | **17 de 24** idiomas |
No Seed-TTS Eval, o S2 obteve o menor WER entre todos os modelos avaliados, incluindo sistemas fechados: Qwen3-TTS (0.77/1.24), MiniMax Speech-02 (0.99/1.90) e Seed-TTS (1.12/2.25). No Audio Turing Test, o valor 0.515 supera o Seed-TTS (0.417) em 24% e o MiniMax-Speech (0.387) em 33%. No EmergentTTS-Eval, o S2 se destacou especialmente em paralinguística (91.61%), perguntas (84.41%) e complexidade sintática (83.39%).
## Destaques
### Controle Inline Refinado via Linguagem Natural
O Fish Audio S2 permite controle localizado da geração de fala ao incorporar instruções em linguagem natural diretamente em posições específicas de palavras ou frases no texto. Em vez de depender de um conjunto fixo de tags predefinidas, o S2 aceita descrições textuais livres, como [whisper in small voice], [professional broadcast tone] ou [pitch up], permitindo controle de expressão aberto no nível da palavra.
### Arquitetura Dual-Autoregressive
O S2 é baseado em um transformer apenas decodificador, combinado com um codec de áudio RVQ (10 codebooks, ~21 Hz de taxa de quadros). A arquitetura Dual-AR divide a geração em duas etapas:
- **Slow AR** opera no eixo temporal e prevê o codebook semântico principal.
- **Fast AR** gera os 9 codebooks residuais restantes em cada passo de tempo, reconstruindo detalhes acústicos finos.
Esse desenho assimétrico (4B parâmetros no eixo temporal e 400M no eixo de profundidade) mantém a inferência eficiente sem sacrificar fidelidade de áudio.
### Alinhamento por Reforço
O S2 usa Group Relative Policy Optimization (GRPO) no pós-treinamento. Os mesmos modelos usados para filtrar e anotar dados de treino são reutilizados diretamente como modelos de recompensa no RL, eliminando o desalinhamento de distribuição entre os dados de pré-treinamento e os objetivos de pós-treinamento. O sinal de recompensa combina precisão semântica, aderência à instrução, preferência acústica e similaridade de timbre.
### Streaming em Produção com SGLang
Como a arquitetura Dual-AR é estruturalmente isomórfica a LLMs autoregressivos padrão, o S2 herda diretamente as otimizações nativas de serving do SGLang, incluindo continuous batching, paged KV cache, CUDA graph replay e prefix caching com RadixAttention.
Em uma única NVIDIA H200:
- **RTF (Real-Time Factor):** 0.195
- **Tempo até o primeiro áudio:** ~100 ms
- **Throughput:** mais de 3.000 acoustic tokens/s mantendo RTF abaixo de 0.5
### Suporte Multilíngue
O Fish Audio S2 oferece suporte a conversão de texto em fala multilíngue de alta qualidade sem a necessidade de fonemas ou processamento específico de idioma. Incluindo:
**Inglês, Chinês, Japonês, Coreano, Árabe, Alemão, Francês...**
**E MUITO MAIS!**
A lista está em constante expansão, verifique o [Fish Audio](https://fish.audio/) para os lançamentos mais recentes.
### Geração Nativa de Múltiplos Falantes
O Fish Audio S2 permite enviar um áudio de referência com vários falantes; o modelo processa as características de cada voz por meio do token `<|speaker:i|>`. Depois, você controla o comportamento do modelo com o token de ID do falante, permitindo incluir várias vozes em uma única geração. Assim, não é mais necessário subir um áudio de referência separado para cada falante.
### Geração de Múltiplos Turnos
Graças à extensão do contexto do modelo, nosso modelo agora pode usar informações anteriores para melhorar a expressividade e a naturalidade dos conteúdos gerados subsequentemente.
### Clonagem de Voz Rápida
O Fish Audio S2 suporta clonagem de voz precisa usando uma pequena amostra de referência (tipicamente de 10 a 30 segundos). O modelo captura o timbre, o estilo de fala e as tendências emocionais, produzindo vozes clonadas realistas e consistentes sem ajuste fino adicional.
Para usar o servidor SGLang, consulte [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md) .
---
## Créditos
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## Relatório Técnico
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/pt/inference.md
================================================
# Inferência
O modelo Fish Audio S2 requer uma grande quantidade de VRAM. Recomendamos o uso de uma GPU com pelo menos 24GB para inferência.
## Baixar Pesos
Primeiro, você precisa baixar os pesos do modelo:
```bash
hf download fishaudio/s2-pro --local-dir checkpoints/s2-pro
```
## Inferência por Linha de Comando
!!! note
Se você planeja deixar o modelo escolher aleatoriamente um timbre de voz, pode pular esta etapa.
### 1. Obter tokens VQ do áudio de referência
```bash
python fish_speech/models/dac/inference.py \
-i "test.wav" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
Você deve obter um `fake.npy` e um `fake.wav`.
### 2. Gerar tokens Semânticos a partir do texto:
```bash
python fish_speech/models/text2semantic/inference.py \
--text "O texto que você deseja converter" \
--prompt-text "Seu texto de referência" \
--prompt-tokens "fake.npy" \
# --compile
```
Este comando criará um arquivo `codes_N` no diretório de trabalho, onde N é um número inteiro começando em 0.
!!! note
Você pode querer usar `--compile` para fundir kernels CUDA para uma inferência mais rápida. No entanto, recomendamos usar nossa otimização de aceleração de inferência sglang.
No exemplo acima, `--compile` está comentado. Para usar a aceleração, remova o comentário; caso contrário, mantenha-o comentado.
!!! info
Para GPUs que não suportam bf16, você pode precisar usar o parâmetro `--half`.
### 3. Gerar fala a partir dos tokens semânticos:
```bash
python fish_speech/models/dac/inference.py \
-i "codes_0.npy" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
Depois disso, você obterá um arquivo `fake.wav`.
## Inferência WebUI
### 1. Gradio WebUI
Para manter a compatibilidade, mantemos a interface Gradio WebUI anterior.
```bash
python tools/run_webui.py # --compile se você precisar de aceleração
```
### 2. Awesome WebUI
A Awesome WebUI é uma interface web moderna baseada em TypeScript, oferecendo funcionalidades mais ricas e uma melhor experiência do usuário.
**Construir a WebUI:**
Você precisa ter o Node.js e o npm instalados em seu computador local ou servidor.
1. Entre no diretório `awesome_webui`:
```bash
cd awesome_webui
```
2. Instale as dependências:
```bash
npm install
```
3. Construa a WebUI:
```bash
npm run build
```
**Iniciar o Servidor Backend:**
Após a construção da WebUI, retorne ao diretório raiz do projeto e inicie o servidor API:
```bash
python tools/api_server.py --listen 0.0.0.0:8888 --compile
```
**Acesso:**
Após o servidor ser iniciado, você pode acessá-lo através do navegador no seguinte endereço:
`http://localhost:8888/ui`
================================================
FILE: docs/pt/install.md
================================================
## Requisitos
- Memória da GPU: 24GB (Inferência)
- Sistema: Linux, WSL
## Configuração do Sistema
O Fish Audio S2 suporta múltiplos métodos de instalação. Escolha o que melhor se adapta ao seu ambiente de desenvolvimento.
**Pré-requisitos**: Instale as dependências de sistema para processamento de áudio:
``` bash
apt install portaudio19-dev libsox-dev ffmpeg
```
### Conda
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# Instalação com GPU (escolha a sua versão do CUDA: cu126, cu128, cu129)
pip install -e .[cu129]
# Instalação apenas para CPU
pip install -e .[cpu]
# Instalação padrão (usa o índice padrão do PyTorch)
pip install -e .
# Se encontrar um erro durante a instalação devido ao pyaudio, considere usar o seguinte comando:
# conda install pyaudio
# De seguida, execute pip install -e . novamente
```
### UV
O UV oferece uma resolução e instalação de dependências mais rápidas:
```bash
# Instalação com GPU (escolha a sua versão do CUDA: cu126, cu128, cu129)
uv sync --python 3.12 --extra cu129
# Instalação apenas para CPU
uv sync --python 3.12 --extra cpu
```
### Suporte para Intel Arc XPU
Para utilizadores de GPUs Intel Arc, instale o suporte XPU da seguinte forma:
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# Instalar a biblioteca padrão C++ necessária
conda install libstdcxx -c conda-forge
# Instalar o PyTorch com suporte para Intel XPU
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
# Instalar o Fish Speech
pip install -e .
```
!!! warning
A opção `compile` não é suportada no Windows e macOS. Se desejar executar com compilação, terá de instalar o Triton manualmente.
## Configuração do Docker
O modelo da série Fish Audio S2 oferece múltiplas opções de implementação com Docker para satisfazer diferentes necessidades. Pode usar imagens pré-construídas do Docker Hub, construir localmente com o Docker Compose, ou construir manualmente imagens personalizadas.
Fornecemos imagens Docker para a WebUI e o servidor API, tanto para GPU (CUDA 12.6 por defeito) como para CPU. Se quiser construir localmente, siga as instruções abaixo. Se apenas quiser usar as imagens pré-construídas, siga diretamente o [guia de inferência](inference.md).
### Pré-requisitos
- Docker e Docker Compose instalados
- NVIDIA Docker runtime instalado (para suporte de GPU)
- Pelo menos 24GB de memória de GPU para inferência com CUDA
### Usar o Docker Compose
Para desenvolvimento ou personalização, pode usar o Docker Compose para construir e executar localmente:
```bash
# Primeiro, clone o repositório
git clone https://github.com/fishaudio/fish-speech.git
cd fish-speech
# Iniciar a WebUI com CUDA
docker compose --profile webui up
# Iniciar a WebUI com otimização de compilação
COMPILE=1 docker compose --profile webui up
# Iniciar o servidor API
docker compose --profile server up
# Iniciar o servidor API com otimização de compilação
COMPILE=1 docker compose --profile server up
# Implementação apenas com CPU
BACKEND=cpu docker compose --profile webui up
```
#### Variáveis de Ambiente para o Docker Compose
Pode personalizar a implementação usando variáveis de ambiente:
```bash
# Exemplo de ficheiro .env
BACKEND=cuda # ou cpu
COMPILE=1 # Ativar otimização de compilação
GRADIO_PORT=7860 # Porta da WebUI
API_PORT=8080 # Porta do servidor API
UV_VERSION=0.8.15 # Versão do gestor de pacotes UV
```
O comando irá construir a imagem e executar o contentor. Pode aceder à WebUI em `http://localhost:7860` e ao servidor API em `http://localhost:8080`.
### Construção Manual com Docker
Para utilizadores avançados que desejam personalizar o processo de construção:
```bash
# Construir imagem da WebUI com suporte CUDA
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target webui \
-t fish-speech-webui:cuda .
# Construir imagem do servidor API com suporte CUDA
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target server \
-t fish-speech-server:cuda .
# Construir imagem apenas para CPU (suporta multiplataforma)
docker build \
--platform linux/amd64,linux/arm64 \
-f docker/Dockerfile \
--build-arg BACKEND=cpu \
--target webui \
-t fish-speech-webui:cpu .
# Construir imagem de desenvolvimento
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--target dev \
-t fish-speech-dev:cuda .
```
#### Argumentos de Construção
- `BACKEND`: `cuda` ou `cpu` (padrão: `cuda`)
- `CUDA_VER`: Versão do CUDA (padrão: `12.6.0`)
- `UV_EXTRA`: Pacote extra do UV para CUDA (padrão: `cu126`)
- `UBUNTU_VER`: Versão do Ubuntu (padrão: `24.04`)
- `PY_VER`: Versão do Python (padrão: `3.12`)
### Montagem de Volumes
Ambos os métodos requerem a montagem dos seguintes diretórios:
- `./checkpoints:/app/checkpoints` - Diretório dos pesos do modelo
- `./references:/app/references` - Diretório dos ficheiros de áudio de referência
### Variáveis de Ambiente
- `COMPILE=1` - Ativa o `torch.compile` para uma inferência mais rápida (cerca de 10x)
- `GRADIO_SERVER_NAME=0.0.0.0` - Anfitrião do servidor WebUI
- `GRADIO_SERVER_PORT=7860` - Porta do servidor WebUI
- `API_SERVER_NAME=0.0.0.0` - Anfitrião do servidor API
- `API_SERVER_PORT=8080` - Porta do servidor API
!!! note
Os contentores Docker esperam que os pesos do modelo sejam montados em `/app/checkpoints`. Certifique-se de que descarregou os pesos do modelo necessários antes de iniciar os contentores.
!!! warning
O suporte para GPU requer o NVIDIA Docker runtime. Para implementações apenas com CPU, remova a flag `--gpus all` e use as imagens de CPU.
================================================
FILE: docs/requirements.txt
================================================
mkdocs-material
mkdocs-static-i18n[material]
mkdocs[i18n]
================================================
FILE: docs/stylesheets/extra.css
================================================
.md-grid {
max-width: 1440px;
}
================================================
FILE: docs/zh/finetune.md
================================================
# 微调
显然, 当你打开这个页面的时候, 你已经对预训练模型 zero-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好.
在目前版本,你只需要微调'LLAMA'部分即可.
## LLAMA 微调
### 1. 准备数据集
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 30.1-32.71.lab
│ └── 30.1-32.71.mp3
└── SPK2
├── 38.79-40.85.lab
└── 38.79-40.85.mp3
```
你需要将数据集转为以上格式, 并放到 `data` 下, 音频后缀可以为 `.mp3`, `.wav` 或 `.flac`, 标注文件后缀建议为 `.lab`.
!!! info
标注文件 `.lab` 仅需包含音频的转写文本,无需遵循特殊格式要求。例如,如果 `hi.mp3` 中的内容是“你好,再见。”,那么 `hi.lab` 文件中只需包含一行文本:“你好,再见。”
!!! warning
建议先对数据集进行响度匹配, 你可以使用 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) 来完成这一步骤.
```bash
fap loudness-norm data-raw data --clean
```
### 2. 批量提取语义 token
确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令:
```bash
huggingface-cli download fishaudio/s2-pro --local-dir checkpoints/s2-pro
```
随后可运行以下命令来提取语义 token:
```bash
python tools/vqgan/extract_vq.py data \
--num-workers 1 --batch-size 16 \
--config-name "modded_dac_vq" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
!!! note
你可以调整 `--num-workers` 和 `--batch-size` 来提高提取速度, 但是请注意不要超过你的显存限制.
该命令会在 `data` 目录下创建 `.npy` 文件, 如下所示:
```
.
├── SPK1
│ ├── 21.15-26.44.lab
│ ├── 21.15-26.44.mp3
│ ├── 21.15-26.44.npy
│ ├── 27.51-29.98.lab
│ ├── 27.51-29.98.mp3
│ ├── 27.51-29.98.npy
│ ├── 30.1-32.71.lab
│ ├── 30.1-32.71.mp3
│ └── 30.1-32.71.npy
└── SPK2
├── 38.79-40.85.lab
├── 38.79-40.85.mp3
└── 38.79-40.85.npy
```
### 3. 打包数据集为 protobuf
```bash
python tools/llama/build_dataset.py \
--input "data" \
--output "data/protos" \
--text-extension .lab \
--num-workers 16
```
命令执行完毕后, 你应该能在 `data` 目录下看到 `protos` 文件.
### 4. 最后, 使用 LoRA 进行微调
同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令:
```bash
huggingface-cli download fishaudio/s2-pro --local-dir checkpoints/s2-pro
```
最后, 你可以运行以下命令来启动微调:
```bash
python fish_speech/train.py --config-name text2semantic_finetune \
project=$project \
+lora@model.model.lora_config=r_8_alpha_16
```
!!! note
你可以通过修改 `fish_speech/configs/text2semantic_finetune.yaml` 来修改训练参数如 `batch_size`, `gradient_accumulation_steps` 等, 来适应你的显存.
!!! note
对于 Windows 用户, 你可以使用 `trainer.strategy.process_group_backend=gloo` 来避免 `nccl` 的问题.
训练结束后, 你可以参考 [推理](inference.md) 部分来测试你的模型.
!!! info
默认配置下, 基本只会学到说话人的发音方式, 而不包含音色, 你依然需要使用 prompt 来保证音色的稳定性.
如果你想要学到音色, 请将训练步数调大, 但这有可能会导致过拟合.
训练完成后, 你需要先将 LoRA 的权重转为普通权重, 然后再进行推理.
```bash
python tools/llama/merge_lora.py \
--lora-config r_8_alpha_16 \
--base-weight checkpoints/s2-pro \
--lora-weight results/$project/checkpoints/step_000000010.ckpt \
--output checkpoints/s2-pro-yth-lora/
```
!!! note
你也可以尝试其他的 checkpoint, 我们建议你使用最早的满足你要求的 checkpoint, 他们通常在 OOD 上表现更好.
================================================
FILE: docs/zh/index.md
================================================
!!! info "许可声明"
此代码库及其相关的模型权重均在 **FISH AUDIO RESEARCH LICENSE** 下发布。更多详情请参考 [LICENSE](https://github.com/fishaudio/fish-speech/blob/main/LICENSE)。
!!! warning "法律免责声明"
我们不对代码库的任何非法使用承担责任。请参考您当地关于 DMCA 和其他相关法律的法规。
## 快速开始
### 文档入口
这里是 Fish Audio S2 的官方文档,请按照说明轻松入门。
- [安装](https://speech.fish.audio/zh/install/)
- [命令行推理](https://speech.fish.audio/zh/inference/)
- [WebUI 推理](https://speech.fish.audio/zh/inference/)
- [服务端推理](https://speech.fish.audio/zh/server/)
- [Docker 部署](https://speech.fish.audio/zh/install/)
> [!IMPORTANT]
> **如需使用 SGLang Server,请参考 [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md)。**
### LLM Agent 指南
```
请先阅读 https://speech.fish.audio/zh/install/ ,并按文档安装和配置 Fish Audio S2。
```
## Fish Audio S2
**在开源与闭源方案中都处于领先水平的文本转语音系统**
Fish Audio S2 是由 [Fish Audio](https://fish.audio/) 开发的最新模型。S2 在约 50 种语言、超过 1000 万小时音频数据上完成训练,并结合强化学习对齐与双自回归架构,能够生成自然、真实且情感丰富的语音。
S2 支持通过自然语言标签(如 `[laugh]`、`[whispers]`、`[super happy]`)对韵律和情绪进行细粒度行内控制,同时原生支持多说话人和多轮生成。
请访问 [Fish Audio 网站](https://fish.audio/) 体验在线演示,并阅读[博客文章](https://fish.audio/blog/fish-audio-open-sources-s2/)和[技术报告](https://arxiv.org/abs/2603.08823)了解更多细节。
### 模型变体
| 模型 | 大小 | 可用性 | 描述 |
|------|------|-------------|-------------|
| S2-Pro | 4B 参数 | [HuggingFace](https://huggingface.co/fishaudio/s2-pro) | 功能齐全的旗舰模型,具有最高质量和稳定性 |
有关模型的更多详情,请参见[技术报告](https://arxiv.org/abs/2603.08823)。
## 基准测试结果
| 基准 | Fish Audio S2 |
|------|------|
| Seed-TTS Eval — WER(中文) | **0.54%**(总体最佳) |
| Seed-TTS Eval — WER(英文) | **0.99%**(总体最佳) |
| Audio Turing Test(含指令) | **0.515** 后验均值 |
| EmergentTTS-Eval — 胜率 | **81.88%**(总体最高) |
| Fish Instruction Benchmark — TAR | **93.3%** |
| Fish Instruction Benchmark — 质量 | **4.51 / 5.0** |
| 多语言(MiniMax Testset)— 最佳 WER | **24** 种语言中的 **11** 种 |
| 多语言(MiniMax Testset)— 最佳 SIM | **24** 种语言中的 **17** 种 |
在 Seed-TTS Eval 上,S2 在所有已评估模型(包括闭源系统)中实现了最低 WER:Qwen3-TTS(0.77/1.24)、MiniMax Speech-02(0.99/1.90)、Seed-TTS(1.12/2.25)。在 Audio Turing Test 上,S2 的 0.515 相比 Seed-TTS(0.417)提升 24%,相比 MiniMax-Speech(0.387)提升 33%。在 EmergentTTS-Eval 中,S2 在副语言学(91.61% 胜率)、疑问句(84.41%)和句法复杂度(83.39%)等维度表现尤为突出。
## 亮点
### 通过自然语言进行细粒度行内控制
Fish Audio S2 支持在文本中的特定词或短语位置直接嵌入自然语言指令,从而对语音生成进行局部控制。与依赖固定预设标签不同,S2 接受自由形式的文本描述,例如 [whisper in small voice]、[professional broadcast tone] 或 [pitch up],实现词级别的开放式表达控制。
### 双自回归架构(Dual-Autoregressive)
S2 基于仅解码器 Transformer,并结合 RVQ 音频编解码器(10 个码本,约 21 Hz 帧率)。Dual-AR 架构将生成拆分为两个阶段:
- **Slow AR** 沿时间轴运行,预测主语义码本。
- **Fast AR** 在每个时间步生成剩余 9 个残差码本,用于重建细粒度声学细节。
这种非对称设计(时间轴 4B 参数、深度轴 400M 参数)在保持音频保真度的同时,提高了推理效率。
### 强化学习对齐
S2 使用 Group Relative Policy Optimization(GRPO)进行后训练对齐。用于过滤和标注训练数据的同一批模型被直接复用为 RL 的奖励模型,从而避免了预训练数据分布与后训练目标之间的不匹配。奖励信号综合了语义准确性、指令遵循、声学偏好评分与音色相似度。
### 基于 SGLang 的生产级流式推理
由于 Dual-AR 架构在结构上与标准自回归 LLM 同构,S2 可以直接继承 SGLang 提供的 LLM 原生服务优化能力,包括连续批处理、分页 KV Cache、CUDA Graph Replay 与基于 RadixAttention 的前缀缓存。
在单张 NVIDIA H200 GPU 上:
- **实时因子(RTF):** 0.195
- **首音频延迟:** 约 100 ms
- **吞吐:** 在 RTF 低于 0.5 的情况下达到 3,000+ acoustic tokens/s
### 多语言支持
Fish Audio S2 支持高质量的多语言文本转语音,无需音素或特定语言的预处理。包括:
**英语、中文、日语、韩语、阿拉伯语、德语、法语...**
**以及更多!**
列表正在不断扩大,请查看 [Fish Audio](https://fish.audio/) 获取最新发布。
### 原生多说话人生成
Fish Audio S2 允许用户上传包含多个说话人的参考音频,模型将通过 `<|speaker:i|>` 令牌处理每个说话人的特征。之后您可以通过说话人 ID 令牌控制模型的表现,从而实现一次生成中包含多个说话人。再也不需要像以前那样针对每个说话人都单独上传参考音频与生成语音了。
### 多轮对话生成
得益于模型上下文的扩展,我们的模型现在可以借助上文的信息提高后续生成内容的表现力,从而提升内容的自然度。
### 快速语音克隆
Fish Audio S2 支持使用短参考样本(通常为 10-30 秒)进行准确的语音克隆。模型可以捕捉音色、说话风格和情感倾向,无需额外微调即可生成逼真且一致的克隆语音。
如需使用 SGLang Server,请参考 [SGLang-Omni README](https://github.com/sgl-project/sglang-omni/blob/main/sglang_omni/models/fishaudio_s2_pro/README.md) 。
---
## 致谢
- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
- [GPT VITS](https://github.com/innnky/gpt-vits)
- [MQTTS](https://github.com/b04901014/MQTTS)
- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
- [Qwen3](https://github.com/QwenLM/Qwen3)
## 技术报告
```bibtex
@misc{fish-speech-v1.4,
title={Fish-Speech: Leveraging Large Language Models for Advanced Multilingual Text-to-Speech Synthesis},
author={Shijia Liao and Yuxuan Wang and Tianyu Li and Yifan Cheng and Ruoyi Zhang and Rongzhi Zhou and Yijin Xing},
year={2024},
eprint={2411.01156},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2411.01156},
}
@misc{liao2026fishaudios2technical,
title={Fish Audio S2 Technical Report},
author={Shijia Liao and Yuxuan Wang and Songting Liu and Yifan Cheng and Ruoyi Zhang and Tianyu Li and Shidong Li and Yisheng Zheng and Xingwei Liu and Qingzheng Wang and Zhizhuo Zhou and Jiahua Liu and Xin Chen and Dawei Han},
year={2026},
eprint={2603.08823},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2603.08823},
}
```
================================================
FILE: docs/zh/inference.md
================================================
# 推理
Fish Audio S2 模型需要较大的显存,我们推荐您使用显存至少为 24GB 的 GPU 进行推理。
## 下载权重
首先您需要下载模型权重:
```bash
hf download fishaudio/s2-pro --local-dir checkpoints/s2-pro
```
## 命令行推理
!!! note
如果您计划让模型随机选择音色,可以跳过此步骤。
### 1. 从参考音频获取 VQ tokens
```bash
python fish_speech/models/dac/inference.py \
-i "test.wav" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
您应该会得到一个 `fake.npy` 和一个 `fake.wav`。
### 2. 从文本生成 Semantic tokens:
```bash
python fish_speech/models/text2semantic/inference.py \
--text "您想要转换的文本" \
--prompt-text "您的参考文本" \
--prompt-tokens "fake.npy" \
# --compile
```
此命令将在工作目录中创建一个 `codes_N.npy` 文件,其中 N 是从 0 开始的整数。
!!! note
您可以使用 `--compile` 融合 CUDA 内核以加快推理速度,但我们更推荐使用基于 SGLang 的推理加速方案。
上面的命令中 `--compile` 默认已被注释掉;如需启用该加速,取消注释即可。
!!! info
对于不支持 bf16 的 GPU,您可能需要使用 `--half` 参数。
### 3. 从语义令牌生成声音:
```bash
python fish_speech/models/dac/inference.py \
-i "codes_0.npy" \
--checkpoint-path "checkpoints/s2-pro/codec.pth"
```
之后您会得到一个 `fake.wav` 文件。
## WebUI 推理
### 1. Gradio WebUI
为了保持兼容,我们保留了以往的Gradio WebUI。
```bash
python tools/run_webui.py # --compile 如果你需要加速的话
```
### 2. Awesome WebUI
Awesome WebUI 是一个基于 TypeScript 开发的现代化 Web 界面,提供更丰富的功能和更好的交互体验。
**构建 WebUI:**
您需要先在本地或者服务器上安装 Node.js 和 npm。
1. 进入 `awesome_webui` 目录:
```bash
cd awesome_webui
```
2. 安装依赖:
```bash
npm install
```
3. 构建 WebUI:
```bash
npm run build
```
**启动后端服务器:**
WebUI 构建完成后,返回项目根目录,启动 API 服务器:
```bash
python tools/api_server.py --listen 0.0.0.0:8888 --compile
```
**访问:**
在服务器启动后,您可以通过浏览器访问以下地址体验:
`http://localhost:8888/ui`
================================================
FILE: docs/zh/install.md
================================================
## 系统要求
- GPU 显存:24GB(用于推理)
- 系统:Linux、WSL
## 系统设置
Fish Audio S2 支持多种安装方式。请选择最适合你当前开发环境的方案。
**前置依赖**:先安装音频处理所需的系统依赖:
```bash
apt install portaudio19-dev libsox-dev ffmpeg
```
### Conda
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# GPU 安装(选择 CUDA 版本:cu126、cu128、cu129)
pip install -e .[cu129]
# 仅 CPU 安装
pip install -e .[cpu]
# 默认安装(使用 PyTorch 默认索引)
pip install -e .
# 如果因 pyaudio 导致安装报错,可以先执行:
# conda install pyaudio
# 然后重新执行 pip install -e .
```
### UV
UV 可以更快地完成依赖解析与安装:
```bash
# GPU 安装(选择 CUDA 版本:cu126、cu128、cu129)
uv sync --python 3.12 --extra cu129
# 仅 CPU 安装
uv sync --python 3.12 --extra cpu
```
### Intel Arc XPU 支持
如果你使用 Intel Arc GPU,可按以下方式安装 XPU 支持:
```bash
conda create -n fish-speech python=3.12
conda activate fish-speech
# 安装必需的 C++ 标准库
conda install libstdcxx -c conda-forge
# 安装支持 Intel XPU 的 PyTorch
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
# 安装 Fish Speech
pip install -e .
```
!!! warning
`compile` 选项暂不支持 Windows 和 macOS。若你希望启用 compile,请手动安装 Triton。
## Docker 设置
Fish Audio S2 系列模型提供多种 Docker 部署方式,适配不同场景。你可以直接使用 Docker Hub 预构建镜像,也可以用 Docker Compose 本地构建,或手动构建自定义镜像。
我们提供 WebUI 与 API Server 的 GPU(默认 CUDA 12.6)和 CPU 镜像。你可以直接使用 Docker Hub 镜像,也可以在本地构建。如果你只想使用预构建镜像,请参考[推理指南](inference.md)。
### 前置条件
- 已安装 Docker 和 Docker Compose
- (GPU 场景)已安装 NVIDIA Docker runtime
- CUDA 推理建议至少 24GB 显存
### 使用 Docker Compose
如果你需要开发或自定义,推荐使用 Docker Compose 在本地构建并运行:
```bash
# 先克隆仓库
git clone https://github.com/fishaudio/fish-speech.git
cd fish-speech
# 使用 CUDA 启动 WebUI
docker compose --profile webui up
# 启用 compile 优化启动 WebUI
COMPILE=1 docker compose --profile webui up
# 启动 API Server
docker compose --profile server up
# 启用 compile 优化启动 API Server
COMPILE=1 docker compose --profile server up
# 仅 CPU 部署
BACKEND=cpu docker compose --profile webui up
```
#### Docker Compose 环境变量
你可以通过环境变量定制部署参数:
```bash
# .env 文件示例
BACKEND=cuda # 或 cpu
COMPILE=1 # 启用 compile 优化
GRADIO_PORT=7860 # WebUI 端口
API_PORT=8080 # API Server 端口
UV_VERSION=0.8.15 # UV 包管理器版本
```
命令执行后会自动构建镜像并启动容器。你可以通过 `http://localhost:7860` 访问 WebUI,通过 `http://localhost:8080` 访问 API Server。
### 手动 Docker 构建
如果你需要更细粒度的构建控制,可以手动构建:
```bash
# 构建支持 CUDA 的 WebUI 镜像
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target webui \
-t fish-speech-webui:cuda .
# 构建支持 CUDA 的 API Server 镜像
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--build-arg CUDA_VER=12.6.0 \
--build-arg UV_EXTRA=cu126 \
--target server \
-t fish-speech-server:cuda .
# 构建仅 CPU 镜像(支持多平台)
docker build \
--platform linux/amd64,linux/arm64 \
-f docker/Dockerfile \
--build-arg BACKEND=cpu \
--target webui \
-t fish-speech-webui:cpu .
# 构建开发镜像
docker build \
--platform linux/amd64 \
-f docker/Dockerfile \
--build-arg BACKEND=cuda \
--target dev \
-t fish-speech-dev:cuda .
```
#### 构建参数
- `BACKEND`:`cuda` 或 `cpu`(默认:`cuda`)
- `CUDA_VER`:CUDA 版本(默认:`12.6.0`)
- `UV_EXTRA`:UV 的 CUDA 扩展(默认:`cu126`)
- `UBUNTU_VER`:Ubuntu 版本(默认:`24.04`)
- `PY_VER`:Python 版本(默认:`3.12`)
### 卷挂载
两种方法都需要挂载以下目录:
- `./checkpoints:/app/checkpoints` - 模型权重目录
- `./references:/app/references` - 参考音频目录
### 环境变量
- `COMPILE=1` - 启用 `torch.compile`,可提升推理速度(约 10 倍)
- `GRADIO_SERVER_NAME=0.0.0.0` - WebUI 服务地址
- `GRADIO_SERVER_PORT=7860` - WebUI 服务端口
- `API_SERVER_NAME=0.0.0.0` - API 服务地址
- `API_SERVER_PORT=8080` - API 服务端口
!!! note
Docker 容器默认从 `/app/checkpoints` 读取模型权重。启动容器前请先下载好所需权重。
!!! warning
GPU 支持需要 NVIDIA Docker runtime。若仅使用 CPU,请使用 CPU 镜像,并移除 `docker run` 命令中的 `--gpus all` 参数(如有)。
================================================
FILE: entrypoint.sh
================================================
#!/bin/bash
# Container entrypoint: start the Gradio WebUI.
# CUDA_ENABLED defaults to "true"; any other value runs the WebUI on CPU.
CUDA_ENABLED=${CUDA_ENABLED:-true}

if [ "${CUDA_ENABLED}" = "true" ]; then
    DEVICE=""
else
    DEVICE="--device cpu"
fi

# DEVICE is deliberately left unquoted: "--device cpu" must split into two
# arguments, and an empty value must expand to no argument at all.
exec python tools/run_webui.py ${DEVICE}
================================================
FILE: fish_speech/callbacks/__init__.py
================================================
from .grad_norm import GradNormMonitor
__all__ = ["GradNormMonitor"]
================================================
FILE: fish_speech/callbacks/grad_norm.py
================================================
from typing import Iterable, Optional, Union

import lightning.pytorch as pl
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torch import Tensor, nn
from torch.utils._foreach_utils import (
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)
@torch.no_grad()
def grad_norm(
parameters: Union[Tensor, list[Tensor]],
norm_type: float = 2.0,
) -> float:
"""
Returns the norm of the gradients of the given parameters.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
norm_type (float): type of the used p-norm.
Returns:
Total norm of the parameter gradients (viewed as a single vector).
""" # noqa: E501
if isinstance(parameters, Tensor):
parameters = [parameters]
grads = [p.grad for p in parameters if p.grad is not None]
if len(grads) == 0:
return None
first_device = grads[0].device
grouped_grads: dict[
tuple[torch.device, torch.dtype], list[list[Tensor]]
] = _group_tensors_by_device_and_dtype(
[[g.detach() for g in grads]]
) # type: ignore[assignment]
norms = []
for (device, _), ([grads], _) in grouped_grads.items():
if _has_foreach_support(grads, device=device):
norms.extend(torch._foreach_norm(grads, norm_type))
else:
norms.extend([torch.norm(g, norm_type) for g in grads])
return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
class GradNormMonitor(Callback):
    """
    Callback that computes the gradient norm of the model parameters.

    After every backward pass it logs ``train/grad_norm`` for the whole
    LightningModule, or ``train/<sub_module>/grad_norm`` for each requested
    sub-module. Nothing is logged for a module without any gradients.
    """

    def __init__(
        self,
        norm_type: float = 2.0,
        logging_interval: str = "step",
        sub_module: Optional[Union[str, list[str]]] = None,
    ) -> None:
        """
        Args:
            norm_type (float): type of the used p-norm.
            logging_interval (str): "step" logs the value on every step; any
                other value (documented: "epoch") logs it with ``on_epoch=True``
                instead of ``on_step``.
            sub_module (str | list[str] | None): name(s) of child modules to
                monitor separately. Names are resolved with
                ``nn.Module.get_submodule``, so dotted paths such as
                ``"model.encoder"`` are supported. ``None`` monitors the whole
                LightningModule.
        """
        super().__init__()

        self.norm_type = norm_type
        self.logging_interval = logging_interval
        self.sub_module = sub_module

    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
        """
        Computes the gradient norm of the model parameters and logs it to the logger.

        Args:
            trainer (Trainer): The trainer object
            model (LightningModule): The current lightningModule

        Raises:
            AttributeError: if a configured ``sub_module`` name does not
                resolve to a child ``nn.Module``.
        """
        if self.sub_module is None:
            self.log_sub_module_grad_norm(model, model, "")
            return

        names = (
            [self.sub_module] if isinstance(self.sub_module, str) else self.sub_module
        )
        for name in names:
            # get_submodule behaves like getattr for plain names and also
            # walks dotted paths ("a.b.c").
            self.log_sub_module_grad_norm(
                model, model.get_submodule(name), f"/{name}"
            )

    def log_sub_module_grad_norm(
        self, lightning_model: LightningModule, model: nn.Module, path: str
    ) -> None:
        """
        Log the gradient norm of ``model``'s parameters through ``lightning_model``.

        Args:
            lightning_model: module whose ``log`` method is used.
            model: module whose parameter gradients are measured.
            path: suffix inserted into the metric key (``train{path}/grad_norm``).
        """
        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
        if grad_norm_val is None:
            # No parameter in this module has a gradient (e.g. frozen).
            return

        on_step = self.logging_interval == "step"
        lightning_model.log(
            f"train{path}/grad_norm",
            grad_norm_val,
            on_step=on_step,
            on_epoch=not on_step,
        )
================================================
FILE: fish_speech/configs/base.yaml
================================================
# Base configuration for training a model
paths:
run_dir: results/${project}
ckpt_dir: ${paths.run_dir}/checkpoints
hydra:
run:
dir: ${paths.run_dir}
# Lightning Trainer
trainer:
_target_: lightning.pytorch.trainer.Trainer
default_root_dir: ${paths.run_dir}
accelerator: gpu
num_nodes: 1
devices: auto
strategy:
_target_: lightning.pytorch.strategies.DDPStrategy
process_group_backend: nccl # Override this (e.g. with gloo) when training on Windows, where NCCL is unavailable
precision: bf16-mixed
# disable validation by epoch end
check_val_every_n_epoch: null
val_check_interval: 5000
max_steps: 100_000
# Use torch.backends.cudnn.benchmark to speed up training
benchmark: true
# Callbacks
callbacks:
model_checkpoint:
_target_: lightning.pytorch.callbacks.ModelCheckpoint
dirpath: ${paths.ckpt_dir}
filename: "step_{step:09d}"
save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
save_top_k: 5 # save 5 latest checkpoints
monitor: step # use step to monitor checkpoints
mode: max # save the latest checkpoint with the highest global_step
every_n_epochs: null # don't save checkpoints by epoch end
every_n_train_steps: 5000 # save checkpoints every 5000 steps
auto_insert_metric_name: false
model_summary:
_target_: lightning.pytorch.callbacks.ModelSummary
max_depth: 2 # the maximum depth of layer nesting that the summary will include
learning_rate_monitor:
_target_: lightning.pytorch.callbacks.LearningRateMonitor
logging_interval: step
log_momentum: false
grad_norm_monitor:
_target_: fish_speech.callbacks.GradNormMonitor
norm_type: 2
logging_interval: step
# Logger
logger:
tensorboard:
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
save_dir: "${paths.run_dir}/tensorboard/"
name: null
log_graph: false
default_hp_metric: true
prefix: ""
# wandb:
# _target_: lightning.pytorch.loggers.wandb.WandbLogger
# # name: "" # name of the run (normally generated by wandb)
# save_dir: "${paths.run_dir}"
# offline: False
# id: null # pass correct id to resume experiment!
# anonymous: null # enable anonymous logging
# project: "fish-speech"
# log_model: False # upload lightning ckpts
# prefix: "" # a string to put at the beginning of metric keys
# # entity: "" # set to name of your wandb team
# group: ""
# tags: ["vq", "hq", "finetune"]
# job_type: ""
# Loop
train: true
test: false
================================================
FILE: fish_speech/configs/lora/r_8_alpha_16.yaml
================================================
_target_: fish_speech.models.text2semantic.lora.LoraConfig
r: 8
lora_alpha: 16
lora_dropout: 0.01
================================================
FILE: fish_speech/configs/modded_dac_vq.yaml
================================================
_target_: fish_speech.models.dac.modded_dac.DAC
# Model setup
sample_rate: 44100
encoder_dim: 64
encoder_rates: [2, 4, 8, 8]
decoder_dim: 1536
decoder_rates: [8, 8, 4, 2]
encoder_transformer_layers: [0, 0, 0, 4]
decoder_transformer_layers: [4, 0, 0, 0]
transformer_general_config:
_target_: fish_speech.models.dac.modded_dac.ModelArgs
_partial_: true
block_size: 8192
n_local_heads: -1
head_dim: 64
rope_base: 10000
norm_eps: 1e-5
dropout_rate: 0.1
attn_dropout_rate: 0.1
channels_first: true
# Quantization
quantizer:
_target_: fish_speech.models.dac.rvq.DownsampleResidualVectorQuantize
input_dim: 1024
n_codebooks: 9
codebook_size: 1024
codebook_dim: 8
quantizer_dropout: 0.5
downsample_factor: [2, 2]
post_module: &transformer_module
_target_: fish_speech.models.dac.modded_dac.WindowLimitedTransformer
causal: true
window_size: 128 # empirically this does not seem to matter
input_dim: 1024
config: &transformer_config
_target_: fish_speech.models.dac.modded_dac.ModelArgs
block_size: 2048
n_layer: 8
n_head: 16
dim: 1024
intermediate_size: 3072
n_local_heads: -1
head_dim: 64
rope_base: 10000
norm_eps: 1e-5
dropout_rate: 0.1
attn_dropout_rate: 0.1
channels_first: true
pre_module: *transformer_module
semantic_codebook_size: 4096
================================================
FILE: fish_speech/configs/text2semantic_finetune.yaml
================================================
defaults:
- base
- _self_
project: text2semantic_finetune_dual_ar
max_length: 4096
pretrained_ckpt_path: checkpoints/openaudio-s1-mini
# Lightning Trainer
trainer:
accumulate_grad_batches: 1
gradient_clip_val: 1.0
gradient_clip_algorithm: "norm"
max_steps: 10000
precision: bf16-true
limit_val_batches: 10
val_check_interval: 100
# strategy:
# find_unused_parameters: true
# static_graph: true
# Tokenizer Configuration
tokenizer:
_target_: fish_speech.tokenizer.FishTokenizer
model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken
# Dataset Configuration
train_dataset:
_target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
proto_files:
- data/protos
tokenizer: ${tokenizer}
causal: true
max_length: ${max_length}
use_speaker: false
interactive_prob: 0.7
val_dataset:
_target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
proto_files:
- data/protos
tokenizer: ${tokenizer}
causal: true
max_length: ${max_length}
use_speaker: false
interactive_prob: 0.7
data:
_target_: fish_speech.datasets.semantic.SemanticDataModule
train_dataset: ${train_dataset}
val_dataset: ${val_dataset}
num_workers: 4
batch_size: 4
tokenizer: ${tokenizer}
max_length: ${max_length}
# Model Configuration
model:
_target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
model:
_target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
path: ${pretrained_ckpt_path}
load_weights: true
max_length: ${max_length}
lora_config: null
optimizer:
_target_: torch.optim.AdamW
_partial_: true
lr: 1e-4
weight_decay: 0
betas: [0.9, 0.95]
eps: 1e-5
lr_scheduler:
_target_: torch.optim.lr_scheduler.LambdaLR
_partial_: true
lr_lambda:
_target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
_partial_: true
num_warmup_steps: 10
# Callbacks
callbacks:
model_checkpoint:
every_n_train_steps: ${trainer.val_check_interval}
================================================
FILE: fish_speech/content_sequence.py
================================================
from dataclasses import dataclass, field
from typing import List, Literal, Union
import numpy as np
import torch
from fish_speech.tokenizer import (
IM_END_TOKEN,
MODALITY_TOKENS,
FishTokenizer,
)
def restore_ndarray(obj, to_tensor: bool = False):
    """
    Rebuild a numpy array from its serialized dict form and optionally wrap it in a tensor.

    Args:
        obj: either a dict with the keys ``"__ndarray__"``, ``"data"``,
            ``"dtype"`` and ``"shape"``, or any other object. Objects without
            the ``"__ndarray__"`` marker are passed through unchanged.
        to_tensor: when true and the (restored) object is a numpy array, return
            a ``torch.Tensor`` built from a copy of it.

    Returns:
        The restored array or tensor, or ``obj`` itself.
    """
    is_serialized_array = isinstance(obj, dict) and "__ndarray__" in obj
    if is_serialized_array:
        flat = np.frombuffer(obj["data"], dtype=obj["dtype"])
        obj = flat.reshape(obj["shape"])

    if not (to_tensor and isinstance(obj, np.ndarray)):
        return obj

    # Copy first so the tensor owns its memory (frombuffer views may be read-only).
    return torch.from_numpy(obj.copy())
@dataclass
class BasePart:
    """
    Common base class for the parts of a ContentSequence.

    ``cal_loss`` decides whether the tokens produced from this part receive
    real labels (True) or ``-100`` (False) in ``ContentSequence.encode``.
    """

    # Part discriminator; subclasses force it in their ``__post_init__``.
    type: Literal["text", "vq", "audio"] | None = None
    # Whether this part contributes to the training loss.
    cal_loss: bool = False
@dataclass(kw_only=True)
class VQPart(BasePart):
    """
    Part holding codec (VQ) codes.

    ``codes`` may be given as a tensor, a numpy array, or a serialized ndarray
    dict (keys ``"__ndarray__"``, ``"data"``, ``"dtype"``, ``"shape"``).
    ``__post_init__`` normalizes the last two forms to a ``torch.Tensor`` via
    ``restore_ndarray``. ``ContentSequence.encode`` reads ``codes[0]`` as the
    main token row, and ``encode_for_inference`` concatenates codes along
    dim 1, so the layout is presumably (num_codebooks, T) — TODO confirm.
    """

    # Plain class attribute (no annotation): it does not redefine the
    # inherited dataclass field, whose __init__ default stays None.
    type = "vq"
    codes: torch.Tensor

    def __post_init__(self: "VQPart"):
        # Force the discriminator, since the generated __init__ defaulted it to None.
        self.type = "vq"
        self.codes = restore_ndarray(self.codes, to_tensor=True)
@dataclass(kw_only=True)
class TextPart(BasePart):
    """
    Text content, given either as raw ``text`` or as pre-computed ``tokens``.

    ``ContentSequence.encode`` uses ``tokens`` when they are set and otherwise
    tokenizes ``text``.

    Raises:
        ValueError: if neither ``text`` nor ``tokens`` is provided.
    """

    # Plain class attribute; the inherited dataclass field default stays None.
    type = "text"
    text: str | None = None
    tokens: list[int] | None = None

    def __post_init__(self: "TextPart"):
        # Force the discriminator, since the generated __init__ defaulted it to None.
        self.type = "text"
        if self.text is None and self.tokens is None:
            raise ValueError("Either text or tokens must be provided")
@dataclass(kw_only=True)
class AudioPart(BasePart):
    """
    Part holding audio features.

    ``features`` may be a tensor, a numpy array, or a serialized ndarray dict.
    ``__post_init__`` normalizes it to a ``torch.Tensor`` via ``restore_ndarray``.

    NOTE(review): ``ContentSequence.encode`` only tokenizes TextPart and
    VQPart; encoding a sequence that contains an AudioPart raises ValueError.
    """

    # Plain class attribute; the inherited dataclass field default stays None.
    type = "audio"
    features: torch.Tensor

    def __post_init__(self: "AudioPart"):
        # Force the discriminator, since the generated __init__ defaulted it to None.
        self.type = "audio"
        self.features = restore_ndarray(self.features, to_tensor=True)
@dataclass(kw_only=True)
class EncodedMessage:
    """Tensors produced by ``ContentSequence.encode``, ready for the model."""

    # 1-D token ids (long) for the whole sequence, shifted if add_shift was used.
    tokens: torch.Tensor
    # Per-position labels; -100 marks positions excluded from the loss.
    labels: torch.Tensor
    # Boolean mask over ``tokens`` marking positions that came from VQ parts.
    vq_mask_tokens: torch.Tensor | None = None
    # Boolean mask over ``labels`` marking positions that came from VQ parts.
    vq_mask_labels: torch.Tensor | None = None
    # Code tensors, one per VQPart, in sequence order.
    vq_parts: list[torch.Tensor]
    # Bool tensor with one entry per VQPart: that part's ``cal_loss``.
    vq_require_losses: torch.Tensor | None = None
    # Audio feature tensors (currently never populated by ``encode``).
    audio_parts: list[torch.Tensor]
    # Boolean mask over ``tokens`` for audio-feature positions.
    audio_masks: torch.Tensor | None = None
    # Arbitrary metadata carried over from the ContentSequence.
    metadata: dict | None = None
@dataclass
class ContentSequence:
    """
    Flexible sequence of content parts that supports interleaved multimodal format.

    Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|>

    Parts can be passed as ``BasePart`` instances or as dicts carrying a
    ``"type"`` key ("text", "vq" or "audio"). Dicts are converted to the
    matching part class when the sequence is constructed.
    """

    parts: list[BasePart] = field(default_factory=list)
    modality: Literal["text", "voice", "interleave"] | None = None
    metadata: dict | None = None

    def __init__(
        self: "ContentSequence",
        parts: list[BasePart | dict] | None = None,
        modality: Literal["text", "voice", "interleave"] | None = None,
        metadata: dict | None = None,
    ):
        """
        Args:
            parts: initial parts (instances or dicts, see the class docstring).
            modality: if set, ``MODALITY_TOKENS[modality]`` is prepended as a
                TextPart unless the first part's text already starts with it.
            metadata: metadata forwarded to ``EncodedMessage``; defaults to {}.

        Raises:
            ValueError: if a dict part has an unsupported ``"type"``.
        """
        self.modality = modality
        self.metadata = metadata or {}

        fixed_parts = []
        for part in parts or []:
            if isinstance(part, dict):
                if part["type"] == "vq":
                    part = VQPart(**part)
                elif part["type"] == "audio":
                    part = AudioPart(**part)
                elif part["type"] == "text":
                    part = TextPart(**part)
                else:
                    raise ValueError(f"Unsupported part type: {part['type']}")
            fixed_parts.append(part)
        self.parts = fixed_parts

        # Prepend the modality marker unless the first part already begins with it.
        if self.modality:
            modality_token = MODALITY_TOKENS[self.modality]
            first = self.parts[0] if self.parts else None
            has_marker = (
                isinstance(first, TextPart)
                and first.text is not None
                and first.text.startswith(modality_token)
            )
            if not has_marker:
                self.parts.insert(0, TextPart(text=modality_token))

    def append(
        self: "ContentSequence",
        part_or_parts: Union[BasePart, List[BasePart]],
        add_end: bool = False,
        speaker: Union[str, int] | None = None,
    ):
        """
        Append a part or list of parts to the sequence.

        Args:
            part_or_parts: A single part or list of parts to add
            add_end: Whether to add the IM_END_TOKEN after these parts. The end
                token inherits ``cal_loss`` from the last part in the sequence
                (False if the sequence is still empty).
            speaker: Optional speaker identifier (name or ID); when given, a
                ``<|speaker:...|>`` TextPart is added before the parts
        """
        parts_to_add = (
            part_or_parts if isinstance(part_or_parts, list) else [part_or_parts]
        )

        if speaker is not None:
            self.parts.append(TextPart(text=f"<|speaker:{speaker}|>"))

        self.parts.extend(parts_to_add)

        if add_end:
            # Guard against an empty sequence (append([], add_end=True)),
            # which previously raised IndexError.
            last_cal_loss = self.parts[-1].cal_loss if self.parts else False
            self.parts.append(TextPart(text=IM_END_TOKEN, cal_loss=last_cal_loss))

    def encode(
        self: "ContentSequence",
        tokenizer: FishTokenizer,
        add_shift: bool = True,
        ignore_loss_tokens: list[str] | None = None,
    ) -> EncodedMessage:
        """
        Encode the sequence parts into tokens for the model.

        Args:
            tokenizer: The tokenizer to use
            add_shift: Whether to shift tokens for next-token prediction
                (inputs drop the last position, labels drop the first)
            ignore_loss_tokens: Token strings whose label is forced to -100

        Returns:
            EncodedMessage with tensors ready for the model

        Raises:
            ValueError: for parts other than TextPart / VQPart. NOTE(review):
                AudioPart is not tokenized yet, so it also lands here.
        """
        all_tokens = []
        all_labels = []

        # Multi-modal elements
        vq_parts = []
        vq_masks = []
        vq_require_losses = []
        audio_parts = []  # never filled: AudioPart is not tokenized yet
        audio_masks = []

        # Resolve the ignore-token strings to ids once, up front.
        ignore_loss_token_ids = [
            tokenizer.get_token_id(token) for token in ignore_loss_tokens or []
        ]

        for part in self.parts:
            if isinstance(part, TextPart):
                if part.tokens is None:
                    assert part.text is not None
                    # Special tokens (BOS/EOS) are disabled because the
                    # sequence is assembled manually.
                    tokens = tokenizer.encode(part.text, add_special_tokens=False)
                else:
                    tokens = part.tokens
                tokens = torch.tensor(tokens, dtype=torch.long)
            elif isinstance(part, VQPart):
                # Map codes to token ids arithmetically (code + semantic_begin_id),
                # which assumes the semantic tokens are contiguous in the vocab.
                curr_codes = part.codes.clone().to(torch.int)
                # Token ids are long to match embedding lookups.
                tokens = (curr_codes[0] + tokenizer.semantic_begin_id).to(torch.long)
                vq_parts.append(curr_codes)
                vq_require_losses.append(part.cal_loss)
            else:
                raise ValueError(f"Unsupported part type: {type(part)}")

            all_tokens.append(tokens)

            # Set masks for different part types
            if isinstance(part, VQPart):
                vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
                audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
            elif isinstance(part, AudioPart):
                # Currently unreachable (AudioPart raises above); kept for
                # when audio tokenization is added.
                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
                audio_mask = torch.ones_like(tokens, dtype=torch.bool)
                audio_mask[0] = False  # Skip start token
                audio_mask[-1] = False  # Skip end token
                audio_masks.append(audio_mask)
            else:
                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
                audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))

            # Set labels based on whether we want to calculate loss for this part
            if part.cal_loss and not isinstance(part, AudioPart):
                all_labels.append(tokens.clone())
            else:
                all_labels.append(torch.full_like(tokens, -100))

        # Concatenate all tensors
        if not all_tokens:
            tokens = torch.empty(0, dtype=torch.long)
            labels = torch.empty(0, dtype=torch.long)
            vq_masks = torch.empty(0, dtype=torch.bool)
            audio_masks = torch.empty(0, dtype=torch.bool)
        else:
            tokens = torch.cat(all_tokens, dim=0)
            labels = torch.cat(all_labels, dim=0)
            vq_masks = torch.cat(vq_masks, dim=0)
            audio_masks = torch.cat(audio_masks, dim=0)

        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)

        # Apply shift if needed for next-token prediction
        vq_mask_tokens = vq_masks
        vq_mask_labels = vq_masks

        if add_shift and len(tokens) > 0:
            tokens = tokens[:-1]
            labels = labels[1:]
            vq_mask_tokens = vq_mask_tokens[:-1]
            vq_mask_labels = vq_mask_labels[1:]
            audio_masks = audio_masks[:-1]

        # Ignore specified tokens
        for token_id in ignore_loss_token_ids:
            if token_id is not None:
                labels[labels == token_id] = -100

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            vq_mask_tokens=vq_mask_tokens,
            vq_mask_labels=vq_mask_labels,
            vq_require_losses=vq_require_losses,
            audio_parts=audio_parts,
            audio_masks=audio_masks,
            metadata=self.metadata,
        )

    def encode_for_inference(
        self: "ContentSequence",
        tokenizer: FishTokenizer,
        num_codebooks: int,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """
        Build the (num_codebooks + 1, T) prompt tensor for inference.

        Row 0 holds the token ids. Rows 1.. hold the VQ codes at VQ positions
        and zeros elsewhere.

        Returns:
            ``(values, audio_masks, audio_parts)``. The last two are None when
            the sequence has no audio parts.
        """
        encoded = self.encode(tokenizer, add_shift=False)
        tokens = encoded.tokens

        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.long)
        values[0] = tokens

        if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
            encoded.audio_parts is None or len(encoded.audio_parts) == 0
        ):
            return values, None, None

        audio_parts = None
        audio_masks = None

        if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
            vq_parts = encoded.vq_parts
            # Join the per-part codes along time so they line up with the
            # True positions of vq_mask_tokens: (C, total_semantic_tokens).
            if len(vq_parts) > 1:
                all_vq_codes = torch.cat(vq_parts, dim=1)
            else:
                all_vq_codes = vq_parts[0]

            # values[0] already holds the token id (semantic_begin_id + code);
            # values[1:] receive the raw codes.
            values[1:, encoded.vq_mask_tokens] = all_vq_codes.to(dtype=torch.long)

        if encoded.audio_parts is not None and len(encoded.audio_parts) > 0:
            audio_parts = torch.cat(encoded.audio_parts, dim=0)
            audio_masks = encoded.audio_masks[None, :]

        return values, audio_masks, audio_parts

    def visualize(
        self: "ContentSequence",
        tokenizer: FishTokenizer,
        ignore_loss_tokens: list[str] | None = None,
        merge_semantic_tokens: bool = False,
    ):
        """
        Visualize the encoded sequence with color-coded tokens.

        Blue/cyan tokens contribute to loss, green tokens do not. With
        ``merge_semantic_tokens`` a run of semantic tokens sharing one label is
        printed as a single ``[<|semantic|>xN]`` item.
        """
        encoded = self.encode(
            tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
        )

        # Colors for alternating tokens
        colors = {
            "blue": "\033[94m",  # Light blue
            "cyan": "\033[96m",  # Cyan
            "green": "\033[92m",  # Light green
            "dark_green": "\033[32m",  # Dark green
        }
        blue_idx = 0
        green_idx = 0

        def print_in_blue(x):
            nonlocal blue_idx
            color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
            print(f"{color}{x}\033[0m", end="")
            blue_idx += 1

        def print_in_green(x):
            nonlocal green_idx
            color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
            print(f"{color}{x}\033[0m", end="")
            green_idx += 1

        def print_semantic_token(x, count):
            val = f"[<|semantic|>x{count}]"
            if x == -100:
                print_in_green(val)
            else:
                print_in_blue(val)

        count_semantic_tokens = 0
        semantic_label = None

        for tok, lab in zip(encoded.tokens, encoded.labels):
            token_id = int(tok.item())

            if merge_semantic_tokens:
                if (
                    tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id
                    and (semantic_label is None or semantic_label == lab)
                ):
                    count_semantic_tokens += 1
                    semantic_label = lab
                    continue
                elif count_semantic_tokens > 0:
                    print_semantic_token(semantic_label, count_semantic_tokens)
                    count_semantic_tokens = 0
                    semantic_label = None

            val = tokenizer.decode([token_id])

            # Fall back to the raw id when decoding yields nothing printable.
            if not val:
                val = f"<{token_id}>"

            if lab == -100:
                print_in_green(val)
            else:
                print_in_blue(val)

        if merge_semantic_tokens and count_semantic_tokens > 0:
            print_semantic_token(semantic_label, count_semantic_tokens)

        print()
================================================
FILE: fish_speech/conversation.py
================================================
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Literal
import torch
from transformers import PreTrainedTokenizerFast
from fish_speech.content_sequence import (
AudioPart,
BasePart,
ContentSequence,
EncodedMessage,
TextPart,
VQPart,
)
from fish_speech.tokenizer import IM_END_TOKEN, IM_START_TOKEN, MODALITY_TOKENS
@dataclass(kw_only=True)
class Message:
    """
    One chat turn, flattened into parts by ``Conversation._build_content_sequence``.

    When enabled, the turn is wrapped as
    ``<IM_START_TOKEN>{role}\\n{modality token} ...parts... <IM_END_TOKEN>\\n``.
    """

    # Speaker role written right after the im_start token.
    role: Literal["system", "user", "assistant"]
    # Content parts of this turn.
    parts: list[BasePart] = field(default_factory=list)
    # Whether to emit the "<im_start>role\n" header part.
    add_im_start: bool = True
    # Whether to emit the trailing "<im_end>\n" part.
    add_im_end: bool = True
    # Loss flag inherited by parts whose own cal_loss is False, and by im_end.
    cal_loss: bool = False
    # Optional modality; its MODALITY_TOKENS entry is appended to the header.
    modality: Literal["text", "voice", "interleave"] | None = None
    # By default, ignore the loss of the auto-generated im_start token
    ignore_im_start_loss: bool = True
@dataclass
class Conversation:
    """Ordered list of chat ``Message`` objects that can be flattened into a ContentSequence."""

    messages: list[Message]

    def __init__(self: "Conversation", messages: list[Message] | None = None):
        self.messages = messages or []

    def _build_content_sequence(
        self: "Conversation",
        metadata: dict | None = None,
    ) -> ContentSequence:
        """
        Build a ContentSequence from all messages.

        Parts whose ``cal_loss`` is False are deep-copied and given the
        message's ``cal_loss``. Parts that already have ``cal_loss=True`` are
        used as-is.
        """
        all_parts = []
        for message in self.messages:
            # Add im_start
            if message.add_im_start:
                modality_token = (
                    MODALITY_TOKENS[message.modality] if message.modality else ""
                )
                all_parts.append(
                    TextPart(
                        text=f"{IM_START_TOKEN}{message.role}\n{modality_token}",
                        cal_loss=not message.ignore_im_start_loss,
                    )
                )

            # Add message parts; copy before overriding so callers' parts are not mutated.
            for part in message.parts:
                if not hasattr(part, "cal_loss") or part.cal_loss is False:
                    new_part = deepcopy(part)
                    new_part.cal_loss = message.cal_loss
                    all_parts.append(new_part)
                else:
                    all_parts.append(part)

            # Add im_end
            if message.add_im_end:
                all_parts.append(
                    TextPart(text=IM_END_TOKEN + "\n", cal_loss=message.cal_loss)
                )

        return ContentSequence(parts=all_parts, modality=None, metadata=metadata)

    def encode(
        self: "Conversation",
        tokenizer: object,
        add_shift: bool = True,
        ignore_loss_tokens: list[str] | None = None,
        metadata: dict | None = None,
        max_length: int | None = None,
    ) -> EncodedMessage:
        """
        Encode the whole conversation via ``ContentSequence.encode``.

        Args:
            tokenizer: tokenizer passed through to ``ContentSequence.encode``.
            add_shift: shift inputs/labels for next-token prediction.
            ignore_loss_tokens: token strings whose labels are forced to -100.
            metadata: metadata attached to the resulting EncodedMessage.
            max_length: optional upper bound on the encoded token count.
                ``ContentSequence.encode`` does not truncate, so this is
                checked here (it used to be forwarded as an unsupported
                keyword, which made every call fail with TypeError).

        Raises:
            ValueError: if ``max_length`` is set and the encoded sequence is longer.
        """
        content_seq = self._build_content_sequence(metadata=metadata)
        encoded = content_seq.encode(
            tokenizer,
            add_shift=add_shift,
            ignore_loss_tokens=list(ignore_loss_tokens or []),
        )

        if max_length is not None and len(encoded.tokens) > max_length:
            raise ValueError(
                f"Encoded conversation has {len(encoded.tokens)} tokens, "
                f"exceeding max_length={max_length}"
            )

        return encoded

    def encode_for_inference(
        self: "Conversation",
        tokenizer: object,
        num_codebooks: int,
        metadata: dict | None = None,
    ):
        """Build the inference prompt tensors (see ``ContentSequence.encode_for_inference``)."""
        content_seq = self._build_content_sequence(metadata=metadata)
        return content_seq.encode_for_inference(tokenizer, num_codebooks=num_codebooks)

    def visualize(
        self: "Conversation",
        tokenizer: PreTrainedTokenizerFast,
        ignore_loss_tokens: list[str] | None = None,
        merge_semantic_tokens: bool = False,
        merge_audio_tokens: bool = False,
        use_color: bool = True,
    ):
        """
        Visualize the encoded sequence with color-coded tokens.

        Blue/cyan tokens contribute to loss, green tokens do not.

        NOTE: ``merge_audio_tokens`` and ``use_color`` are accepted for API
        compatibility but currently have no effect; output is always colored.
        """
        content_seq = self._build_content_sequence()
        content_seq.visualize(
            tokenizer,
            ignore_loss_tokens=list(ignore_loss_tokens or []),
            merge_semantic_tokens=merge_semantic_tokens,
        )

    def append(self: "Conversation", message: Message):
        """Append ``message`` to the end of the conversation."""
        self.messages.append(message)

    def to_content_sequence(
        self: "Conversation",
        metadata: dict | None = None,
    ) -> ContentSequence:
        """
        Convert the Conversation to a ContentSequence.

        This method builds a ContentSequence from all messages,
        handling cal_loss inheritance from message to part level.

        Args:
            metadata: Optional metadata to include in the ContentSequence

        Returns:
            ContentSequence with all messages converted to parts
        """
        return self._build_content_sequence(metadata=metadata)
if __name__ == "__main__":
    # Manual smoke test: build a two-message conversation and print its token
    # visualization. Requires a local tokenizer at checkpoints/agent-0.6b-debug.
    # NOTE(review): ContentSequence.encode reads tokenizer.semantic_begin_id for
    # the VQPart below, which a plain PreTrainedTokenizerFast may not provide —
    # this demo presumably needs a FishTokenizer; verify before relying on it.
    # Test the new implementation with the same API
    message0 = Message(
        role="user",
        parts=[
            TextPart(text="Hello, how are you?"),
            VQPart(codes=torch.zeros((4, 10))),
        ],
        cal_loss=False,
    )

    message1 = Message(
        role="assistant",
        parts=[TextPart(text="I'm fine, thank you.")],
        cal_loss=True,
    )
    conversation = Conversation([message0, message1])
    tokenizer = PreTrainedTokenizerFast.from_pretrained("checkpoints/agent-0.6b-debug")

    # Test with enhanced visualization from ContentSequence
    print("Basic visualization:")
    conversation.visualize(tokenizer)

    print("\nWith merged semantic tokens:")
    conversation.visualize(tokenizer, merge_semantic_tokens=True)

    # NOTE(review): Conversation.visualize does not forward use_color, so this
    # output is still colored.
    print("\nWithout colors:")
    conversation.visualize(tokenizer, use_color=False)
================================================
FILE: fish_speech/datasets/concat_repeat.py
================================================
import bisect
import random
from typing import Iterable
from torch.utils.data import Dataset, IterableDataset
class ConcatRepeatDataset(Dataset):
    """
    Concatenate map-style datasets, virtually repeating each one a number of times.

    Dataset ``i`` contributes ``len(datasets[i]) * repeats[i]`` indices. Index
    ``k`` inside that range maps to ``datasets[i][k % len(datasets[i])]``.
    Negative indices count from the end, as with sequences.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        """Return the running totals of ``len(dataset) * repeat`` over the datasets."""
        sizes, total = [], 0
        for dataset, repeat in zip(sequence, repeats):
            total += len(dataset) * repeat
            sizes.append(total)
        return sizes

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        """
        Args:
            datasets: map-style datasets (must support ``len`` and indexing).
            repeats: repeat count for each dataset, in the same order.
        """
        super().__init__()

        self.datasets = list(datasets)
        self.repeats = repeats

        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"

        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"

        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """
        Return the item at virtual index ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        # Normalize negative indices first; bisect on a negative value would
        # otherwise select dataset 0 and modulo would silently wrap around.
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(
                f"index {idx} out of range for ConcatRepeatDataset of length {total}"
            )

        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        start = self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0

        dataset = self.datasets[dataset_idx]
        return dataset[(idx - start) % len(dataset)]
================================================
FILE: fish_speech/datasets/protos/text-data.proto
================================================
syntax = "proto3";
package text_data;
// One row of semantic (VQ) code ids.
// The Python loaders turn a list of these into a 2-D code tensor via
// `[x.values for x in semantics]` and take the list length as the number of
// codebooks, so each message holds one codebook's codes over time.
message Semantics {
repeated uint32 values = 1;
}
// One utterance: candidate transcriptions plus its semantic codes.
// The loaders pick one entry of `texts` at random for each encoded sample.
// Field number 2 is not used; presumably retired — do not reuse it without
// checking existing data files.
message Sentence {
repeated string texts = 1;
repeated Semantics semantics = 3;
}
// A group of sentences, stored length-prefixed in .proto/.protos files and
// read back by read_pb_stream(). `source` and `name` are copied into
// SampledData; a commented-out loader call suggests `name` identified the
// speaker — confirm against the data producer. Field number 3 is not used.
message TextData {
string source = 1;
string name = 2;
repeated Sentence sentences = 4;
}
// A subset of one TextData group's sentences, as produced by the iterable
// dataset's sample_data(); `source`/`name` are copied from that group.
message SampledData {
string source = 1;
string name = 2;
repeated Sentence samples = 3;
}
================================================
FILE: fish_speech/datasets/protos/text_data_pb2.py
================================================
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: text-data.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
# NOTE: protoc output for text-data.proto (see file header). Regenerate from the
# .proto instead of hand-editing, so the serialized descriptor and the offsets
# below stay consistent with each other.
_sym_db = _symbol_database.Default()
# Register the serialized FileDescriptorProto for text-data.proto in the default pool.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
)
# Materialize the message classes (Semantics, Sentence, TextData, SampledData)
# as globals of this module.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
# Pure-Python descriptor path only: byte offsets of each message's DescriptorProto
# inside the serialized file above (e.g. Semantics occupies bytes 30-57).
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_SEMANTICS"]._serialized_start = 30
_globals["_SEMANTICS"]._serialized_end = 57
_globals["_SENTENCE"]._serialized_start = 59
_globals["_SENTENCE"]._serialized_end = 125
_globals["_TEXTDATA"]._serialized_start = 127
_globals["_TEXTDATA"]._serialized_end = 207
_globals["_SAMPLEDDATA"]._serialized_start = 209
_globals["_SAMPLEDDATA"]._serialized_end = 290
# @@protoc_insertion_point(module_scope)
================================================
FILE: fish_speech/datasets/protos/text_data_stream.py
================================================
import struct
from .text_data_pb2 import TextData
def read_pb_stream(f):
    """Yield ``TextData`` messages from a length-prefixed binary stream.

    Each record is a 4-byte ``struct`` ``"I"`` length header (native byte
    order) followed by that many bytes of serialized ``TextData``, i.e. the
    format written by ``write_pb_stream``.

    Args:
        f: binary file-like object opened for reading.

    Yields:
        Parsed ``TextData`` messages in stream order.

    Raises:
        EOFError: if the stream ends inside a header or a payload.
    """
    while True:
        header = f.read(4)
        if not header:
            break  # clean end of stream

        if len(header) < 4:
            # struct.unpack would otherwise fail with an opaque struct.error.
            raise EOFError(
                f"Truncated record header: expected 4 bytes, got {len(header)}"
            )
        (size,) = struct.unpack("I", header)

        payload = f.read(size)
        if len(payload) < size:
            # Parsing a short payload would silently yield a partial/corrupt message.
            raise EOFError(
                f"Truncated record payload: expected {size} bytes, got {len(payload)}"
            )

        text_data = TextData()
        text_data.ParseFromString(payload)
        yield text_data
def write_pb_stream(f, text_data):
    """Write ``text_data`` to binary stream ``f`` as one length-prefixed record.

    The record is a 4-byte ``struct`` ``"I"`` length header followed by the
    serialized message bytes.
    """
    payload = text_data.SerializeToString()
    for chunk in (struct.pack("I", len(payload)), payload):
        f.write(chunk)
def pack_pb_stream(text_data):
    """Return ``text_data`` as bytes: 4-byte ``"I"`` length header + serialized message."""
    payload = text_data.SerializeToString()
    header = struct.pack("I", len(payload))
    return header + payload
def split_pb_stream(f):
    """Yield each raw length-prefixed record (header + payload bytes) from ``f``.

    The records are not parsed, so they can be re-written or redistributed
    as-is.

    Args:
        f: binary file-like object opened for reading.

    Yields:
        ``bytes`` objects, each a 4-byte ``"I"`` header followed by its payload.

    Raises:
        EOFError: if the stream ends inside a header or a payload (previously a
            truncated payload was yielded silently as a corrupt record).
    """
    while head := f.read(4):
        if len(head) < 4:
            raise EOFError(
                f"Truncated record header: expected 4 bytes, got {len(head)}"
            )
        (size,) = struct.unpack("I", head)
        buf = f.read(size)
        if len(buf) < size:
            raise EOFError(
                f"Truncated record payload: expected {size} bytes, got {len(buf)}"
            )
        yield head + buf
================================================
FILE: fish_speech/datasets/semantic.py
================================================
import random
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from random import Random
from typing import Optional, Union
import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
from fish_speech.content_sequence import ContentSequence, TextPart, VQPart
CODEBOOK_PAD_TOKEN_ID = 0
from fish_speech.datasets.protos.text_data_pb2 import SampledData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
from fish_speech.text.clean import clean_text
from fish_speech.tokenizer import FishTokenizer
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand
log = RankedLogger(__name__, rank_zero_only=True)
def split_by_rank_worker(files):
    """Return the part of ``files`` owned by this DDP rank and DataLoader worker.

    When there are fewer files than ``world_size * num_workers`` consumers, the
    list is first repeated so that every (rank, worker) pair receives at least
    one entry. The list is then strided by DDP rank and afterwards by worker id.

    Args:
        files: list of items (typically file paths) to distribute.

    Returns:
        This process/worker's subset. An empty ``files`` is returned unchanged
        instead of raising ``ZeroDivisionError``.
    """
    if not files:
        return files

    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
class AutoTextSemanticInstructionIterableDataset(IterableDataset):
    """
    Endless stream of text-to-semantic training samples read from local protobuf shards.

    Each yielded item is a dict ``{"tokens": LongTensor, "labels": LongTensor}``,
    both shaped ``(num_codebooks + 1, T)``. A sample is built as follows:

    1. Pick one group (``TextData``) with probability proportional to its
       number of sentences.
    2. Take up to ``max_length // 20`` (at least 1) sentences from it: a
       contiguous run when ``causal`` is true, otherwise a random draw with
       replacement.
    3. Encode each sentence separately and concatenate the results along the
       time axis. Each encoding has three parts: an instruction text, a user
       turn with the cleaned text (or ``<|skip_text|>``), and an assistant turn
       with ``<|voice|>`` followed by the sentence's VQ codes.

    ``interactive_prob`` and ``use_speaker`` are stored, and ``interactive_prob``
    is range-checked, but neither is used by the current encoding.
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: FishTokenizer = None,
        use_speaker: bool | float = True,
        causal: bool = True,
        num_codebooks: Optional[int] = None,
        skip_text_prob: float = 0.0,
    ):
        """
        Args:
            proto_files: proto buf files (or directories / brace patterns) of local data
            seed: random seed used for file and group shuffling
            interactive_prob: probability to use interactive mode (currently unused)
            max_length: max length of the text; also bounds how many sentences are sampled
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt (currently unused)
            causal: use causal sampling when using local data, disable will lead to random sampling
            num_codebooks: number of codebooks, if None, it will be automatically detected
            skip_text_prob: probability to replace the text with <|skip_text|> (audio only)
        """

        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causal = causal
        self.num_codebooks = num_codebooks
        self.skip_text_prob = skip_text_prob

        # Loaded lazily (inside the DataLoader worker) by init_mock_data_server().
        self.groups = None

    def __iter__(self):
        while True:
            sample = self.augment()
            # augment() returns None for an empty draw; the collator cannot
            # batch None, so skip such draws instead of yielding them.
            if sample is not None:
                yield sample

    def init_mock_data_server(self):
        """Load this rank/worker's share of the proto files into memory (once).

        Raises:
            ValueError: if a path is neither a file nor a directory, or if the
                assigned shard contains no sentences at all.
        """
        if self.groups is not None:
            return

        # Expand the proto files
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        expanded_proto_files = sorted(expanded_proto_files)
        # Same seed everywhere, so all ranks/workers agree on the file order
        # before split_by_rank_worker() picks this process's subset.
        Random(self.seed).shuffle(expanded_proto_files)

        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        groups = []
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                groups.extend(read_pb_stream(f))

        log.info(f"Read total {len(groups)} groups of data")

        # Shuffle the lines
        Random(self.seed).shuffle(groups)
        group_weights = [len(i.sentences) for i in groups]

        # random.choices() would otherwise fail later with an obscure
        # IndexError (no groups) or "Total of weights must be greater than zero".
        if not any(group_weights):
            raise ValueError(
                f"No sentences found in the {len(shard_proto_files)} proto file(s) "
                "assigned to this rank/worker"
            )

        # Publish both together so a failed load never leaves half-initialized state.
        self.groups = groups
        self.group_weights = group_weights

    def sample_data(self):
        """Draw one group and return a ``SampledData`` with some of its sentences."""
        if self.groups is None:
            self.init_mock_data_server()

        # Shuffle unique lines, estimate that each sample is at least 20 tokens.
        # Clamp to 1: with max_length < 20 the draw would otherwise always be empty.
        num_samples = max(1, self.max_length // 20)

        # choice group based on their number of samples
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        if self.causal:
            # Sample in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            # Random draw with replacement.
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        # speaker: Optional[str] = None, # speaker is now handled by tokens
        skip_text: bool = False,
    ):
        """Encode one sample with ``ContentSequence``.

        Args:
            sentences: texts, joined with spaces into the user turn.
            semantics: list whose first element is a sequence of ``Semantics``
                messages, one per codebook; only ``semantics[0]`` is used.
            skip_text: replace the user text with ``<|skip_text|>``.

        Returns:
            ``(tokens, labels)`` long tensors of shape ``(num_codebooks + 1, T)``.
            Row 0 holds text ids and rows 1.. hold VQ codes at VQ positions
            (CODEBOOK_PAD_TOKEN_ID elsewhere). Codebook label rows are -100
            except at VQ label positions and the final position.
        """
        seq = ContentSequence()
        seq.append(TextPart(text="Speak out the provided text."))

        # User's turn
        cated_sentences = " ".join(sentences)
        if skip_text:
            cated_sentences = "<|skip_text|>"

        seq.append(
            TextPart(text=f"<|speaker:user|> {cated_sentences}"),
            add_end=True,
        )

        # Assistant's turn
        vq_codes = [x.values for x in semantics[0]]
        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
        # Attach cal_loss=True directly to the VQPart, which is more precise
        # than the previous approach.
        vq_part = VQPart(codes=vq_codes_tensor, cal_loss=True)
        # Append the parts together; add_end=True also appends the end token
        # (<|im_end|> according to the original comment).
        seq.append(
            [TextPart(text="<|speaker:assistant|> <|voice|>"), vq_part],
            add_end=True,
        )

        encoded = seq.encode(
            tokenizer=self.tokenizer,
        )

        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        tokens_raw = encoded.tokens
        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
        tokens[0] = tokens_raw

        vq_parts = encoded.vq_parts
        vq_parts = [part.to(tokens.device) for part in vq_parts]
        vq_parts = torch.cat(vq_parts, dim=1)
        tokens[1:, encoded.vq_mask_tokens] = vq_parts

        labels_raw = encoded.labels
        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
        labels[0, :] = labels_raw
        labels[1:, encoded.vq_mask_labels] = vq_parts
        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID

        tokens = tokens.long()
        labels = labels.long()

        # Verify the padding is correct, and the last token is eos
        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()

        return tokens, labels

    def augment(self):
        """Build one training sample, or return None if the draw was empty."""
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        all_tokens, all_labels = [], []
        # Plain iteration replaces the former list copy + pop(0) loop (O(n^2)).
        for sentence in response.samples:
            text = clean_text(random.choice(sentence.texts))

            tokens, labels = self.pack_sentences(
                sentences=[text],
                semantics=[sentence.semantics],
                # speaker=response.name if use_speaker else None,
                skip_text=random.random() < self.skip_text_prob,
            )

            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        data = {"tokens": tokens, "labels": labels}

        return data
class AutoTextSemanticInstructionDataset(Dataset):
    """
    Map-style dataset of pre-encoded text-to-semantic samples from local protobuf shards.

    Every sentence of every group in this rank's shard is encoded once, at
    construction time, into ``{"tokens": LongTensor, "labels": LongTensor}``.
    Both tensors are shaped ``(num_codebooks + 1, T)``. Each sentence is encoded
    as a system/user/assistant ``Conversation``: a fixed instruction, the
    cleaned text (or ``<|skip_text|>``), and ``<|voice|>`` followed by the
    sentence's VQ codes. The encoded samples are then shuffled with ``seed``.

    ``interactive_prob``, ``use_speaker``, ``max_length`` and ``causal`` are
    stored, and ``interactive_prob`` is range-checked, but they are not used
    when building samples.
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: FishTokenizer = None,
        use_speaker: bool | float = True,
        causal: bool = True,
        num_codebooks: Optional[int] = None,
        skip_text_prob: float = 0.0,
    ):
        """
        Args:
            proto_files: proto buf files (or directories / brace patterns) of local data
            seed: random seed for file order and sample shuffling
            interactive_prob: probability to use interactive mode (currently unused)
            max_length: max length of the text (currently unused here)
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt (currently unused)
            causal: causal sampling flag (currently unused here)
            num_codebooks: number of codebooks, if None, it will be automatically detected
            skip_text_prob: probability to replace the text with <|skip_text|> (audio only)
        """
        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causal = causal
        self.num_codebooks = num_codebooks
        self.skip_text_prob = skip_text_prob

        self.data = []
        self._init_data()

    def _init_data(self):
        """Read this rank's proto shard and pre-encode every sentence into ``self.data``."""
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        groups = []
        # NOTE(review): this runs from __init__ (main process, so there is no
        # DataLoader worker info) and therefore shards by DDP rank only. If the
        # trainer also wraps this map-style dataset in a DistributedSampler, each
        # rank would train on only a fraction of its own shard — confirm.
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        for group in groups:
            if len(group.sentences) == 0:
                continue

            for sentence in group.sentences:
                text = clean_text(random.choice(sentence.texts))
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    skip_text=random.random() < self.skip_text_prob,
                )
                self.data.append({"tokens": tokens, "labels": labels})

        random.Random(self.seed).shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        skip_text: bool = False,
    ):
        """Encode one sample as a system/user/assistant ``Conversation``.

        Args:
            sentences: texts, joined with spaces into the user message.
            semantics: list whose first element is a sequence of ``Semantics``
                messages, one per codebook; only ``semantics[0]`` is used.
            skip_text: replace the user text with ``<|skip_text|>``.

        Returns:
            ``(tokens, labels)`` long tensors of shape ``(num_codebooks + 1, T)``.
            Row 0 holds text ids and rows 1.. hold VQ codes at VQ positions
            (CODEBOOK_PAD_TOKEN_ID elsewhere). Codebook label rows are -100
            except at VQ label positions and the final position.
        """
        # This module's top-level imports do not provide Message/Conversation,
        # so referencing them here raised NameError. Import them locally.
        # NOTE(review): module path assumed to be fish_speech.conversation — confirm.
        from fish_speech.conversation import Conversation, Message

        messages = [
            Message(
                role="system",
                parts=[TextPart(text="Speak out the provided text.")],
            )
        ]

        cated_sentences = " ".join(sentences)
        if skip_text:
            cated_sentences = "<|skip_text|>"

        messages.append(
            Message(
                role="user",
                parts=[TextPart(text=cated_sentences)],
            )
        )

        vq_codes = [x.values for x in semantics[0]]
        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
        vqpart = VQPart(codes=vq_codes_tensor)
        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>"), vqpart],
                cal_loss=True,
            )
        )

        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        conversation = Conversation(messages=messages)
        encoded = conversation.encode(
            tokenizer=self.tokenizer,
        )

        tokens_raw = encoded.tokens
        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
        tokens[0] = tokens_raw

        vq_parts = encoded.vq_parts
        vq_parts = [part.to(tokens.device) for part in vq_parts]
        vq_parts = torch.cat(vq_parts, dim=1)
        tokens[1:, encoded.vq_mask_tokens] = vq_parts

        labels_raw = encoded.labels
        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
        labels[0, :] = labels_raw
        labels[1:, encoded.vq_mask_labels] = vq_parts
        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID

        tokens = tokens.long()
        labels = labels.long()

        # Codebook rows must be padding outside VQ positions, and the last label column padding.
        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()

        return tokens, labels
class InterleaveDataset(IterableDataset):
    """Randomly interleave samples drawn from several iterable datasets.

    At every step a source index is drawn with ``probabilities`` from a NumPy
    generator seeded with ``seed``, and that source's next sample is yielded.
    An exhausted source is restarted with a fresh iterator, so the stream is
    endless.
    """

    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        """
        Args:
            datasets: source datasets (anything ``iter()`` accepts).
            probabilities: selection probability per dataset; NumPy requires
                the same length as ``datasets`` and a sum of 1.
            seed: seed for the source-selection generator.
        """
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        """Yield samples forever.

        Raises:
            RuntimeError: if a selected dataset yields nothing even after a restart.
        """
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Random choice one
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)

            try:
                sample = next(dataset_iterators[dataset_idx])
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                try:
                    sample = next(dataset_iterators[dataset_idx])
                except StopIteration:
                    # Letting StopIteration escape a generator turns it into an
                    # opaque "generator raised StopIteration" (PEP 479).
                    raise RuntimeError(
                        f"Dataset at index {dataset_idx} yielded no samples"
                    ) from None

            yield sample
@dataclass
class TextDataCollator:
    """Truncate/pad token and label tensors from the datasets above into a batch.

    Each example is a dict with ``tokens`` and ``labels`` tensors of shape
    ``(rows, T)``; the datasets in this module produce ``rows = num_codebooks + 1``
    (row 0 text, remaining rows codebooks). When examples also carry
    ``negative_tokens`` / ``negative_labels``, those negatives are appended
    after all positives in the batch.
    """

    tokenizer: FishTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        # `examples` is a list of per-sample dicts. The previous check,
        # `"negative_tokens" in examples`, tested list membership and was
        # therefore always False; inspect the keys of a sample instead.
        if examples and "negative_tokens" in examples[0]:
            positive_examples = [
                {"tokens": i["tokens"], "labels": i["labels"]} for i in examples
            ]
            negative_examples = [
                {"tokens": i["negative_tokens"], "labels": i["negative_labels"]}
                for i in examples
            ]
            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        """Stack ``examples`` into fixed-length batch tensors.

        Sequences are cut to ``min(longest example, max_length)``. Shorter ones
        are right-padded: the text row with ``<|end_of_text|>``, codebook rows
        with CODEBOOK_PAD_TOKEN_ID, and labels with -100.

        Returns:
            dict with ``inputs`` ``(B, rows, L)``, ``labels`` ``(B, rows, L)``
            and ``attention_masks`` ``(B, L)`` bool, where True marks padding
            positions and False marks real tokens.
        """
        tokens, attention_masks, labels = [], [], []

        # Calculate the max length
        max_tokens_length = max(
            (example[tokens_key].size(1) for example in examples), default=0
        )
        max_tokens_length = min(max_tokens_length, self.max_length)
        # Hoisted out of the loop: the pad id does not depend on the example.
        pad_token_id = self.tokenizer.get_token_id("<|end_of_text|>")

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=pad_token_id,
                )
                # Codebook rows are padded with the codebook pad id, not the text pad.
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
class SemanticDataModule(LightningDataModule):
    """Lightning data module that batches the semantic datasets with TextDataCollator."""

    def __init__(
        self,
        train_dataset: Union[
            AutoTextSemanticInstructionDataset,
            AutoTextSemanticInstructionIterableDataset,
            InterleaveDataset,
        ],
        val_dataset: Union[
            AutoTextSemanticInstructionDataset,
            AutoTextSemanticInstructionIterableDataset,
            InterleaveDataset,
        ],
        batch_size: int = 32,
        tokenizer: FishTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        """
        Args:
            train_dataset: dataset for training batches.
            val_dataset: dataset for validation batches.
            batch_size: samples per batch.
            tokenizer: tokenizer used by the collator for the text pad id.
            max_length: collator truncation length.
            num_workers: DataLoader worker processes (0 loads in the main process).
        """
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
            persistent_workers=self.num_workers > 0,
        )
if __name__ == "__main__":
    # Manual check: build the map-style dataset from local protos and dump samples.
    ds = AutoTextSemanticInstructionDataset(
        ["data/protos"],
        tokenizer=FishTokenizer("checkpoints/fish-speech-1.5/tokenizer.tiktoken"),
        use_speaker=False,
        interactive_prob=1.0,
        skip_text_prob=0.5,
    )

    # Print up to 100 encoded samples; bounded by len(ds) so small datasets
    # do not raise IndexError.
    for i in range(min(100, len(ds))):
        print(ds[i])
================================================
FILE: fish_speech/datasets/vqgan.py
================================================
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import librosa
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from fish_speech.utils import RankedLogger
logger = RankedLogger(__name__, rank_zero_only=False)
class VQGANDataset(Dataset):
    """Map-style dataset of mono audio clips listed in a text file.

    Each non-blank line of ``filelist`` is a path relative to the filelist's
    directory. An item is ``{"audio": 1-D tensor}`` loaded at ``sample_rate``.
    The clip is randomly cropped to ``slice_frames * hop_length`` samples when
    it is longer than that, and peak-normalized if it exceeds [-1, 1]. Empty or
    unreadable files yield ``None``, which ``VQGANCollator`` drops.
    """

    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        filelist = Path(filelist)
        root = filelist.parent

        self.files = [
            root / line.strip()
            for line in filelist.read_text(encoding="utf-8").splitlines()
            if line.strip()
        ]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def _random_crop(self, audio):
        """Return a uniformly random window of ``slice_frames * hop_length`` samples.

        ``audio`` is returned unchanged when slicing is disabled or the clip is
        not longer than the window.
        """
        if self.slice_frames is None:
            return audio

        segment = self.slice_frames * self.hop_length
        if audio.shape[0] <= segment:
            return audio

        # np.random.randint excludes `high`; +1 lets the window end exactly on
        # the last sample (previously that start position was never chosen).
        start = np.random.randint(0, audio.shape[0] - segment + 1)
        return audio[start : start + segment]

    def get_item(self, idx):
        """Load, crop and normalize clip ``idx``; return None if it ends up empty."""
        file = self.files[idx]

        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)

        # Slice audio and features
        audio = self._random_crop(audio)

        if len(audio) == 0:
            return None

        # Rescale only clips whose peak exceeds 1.0.
        max_value = np.abs(audio).max()
        if max_value > 1.0:
            audio = audio / max_value

        return {
            "audio": torch.from_numpy(audio),
        }

    def __getitem__(self, idx):
        # Best-effort: a bad file is logged and skipped (None) instead of
        # aborting the whole epoch.
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None
@dataclass
class VQGANCollator:
    """Collate ``{"audio": 1-D tensor}`` items into a zero-padded batch.

    ``None`` items (failed loads) are dropped first.

    Returns:
        ``{"audios": (B, max_len) tensor, "audio_lengths": (B,) tensor}`` where
        ``audio_lengths`` are the unpadded lengths.

    Raises:
        RuntimeError: if no valid item remains in the batch.
    """

    def __call__(self, batch):
        batch = [x for x in batch if x is not None]
        if not batch:
            raise RuntimeError("VQGANCollator received a batch with no valid audio items")

        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        # Plain int so the pad widths below are ints rather than 0-d tensors.
        audio_maxlen = int(audio_lengths.max())

        # Right-pad every clip with zeros to the longest clip in the batch.
        audios = [
            torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            for x in batch
        ]

        return {
            "audios": torch.stack(audios),
            "audio_lengths": audio_lengths,
        }
class VQGANDataModule(LightningDataModule):
    """Lightning data module wrapping train/val ``VQGANDataset`` instances.

    ``val_batch_size`` falls back to ``batch_size`` when it is None (or otherwise falsy).
    """

    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
            # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
            persistent_workers=self.num_workers > 0,
        )
if __name__ == "__main__":
    # Manual check: load one batch and show its shapes.
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    # VQGANCollator only returns "audios" and "audio_lengths"; the former
    # "features"/"feature_lengths" prints raised KeyError.
    for batch in dataloader:
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break
================================================
FILE: fish_speech/i18n/README.md
================================================
## i18n Folder Attribution
The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
### fish_speech/i18n/core.py
**Related code from RVC:**
[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
**Initial commit:**
add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
**Initial author:**
[@L4Ph](https://github.com/L4Ph)
### fish_speech/i18n/scan.py
**Related code from RVC:**
[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
**Initial commit:**
File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
**Initial author:**
[@towzeur](https://github.com/towzeur)
We appreciate the contributions of the RVC project and its authors.
================================================
FILE: fish_speech/i18n/__init__.py
================================================
from .core import i18n
__all__ = ["i18n"]
================================================
FILE: fish_speech/i18n/core.py
================================================
import json
import locale
from pathlib import Path
I18N_FILE_PATH = Path(__file__).parent / "locale"
DEFAULT_LANGUAGE = "en_US"
def load_language_list(language):
    """Read ``<I18N_FILE_PATH>/<language>.json`` and return its parsed contents."""
    locale_file = I18N_FILE_PATH / f"{language}.json"
    with open(locale_file, "r", encoding="utf-8") as handle:
        return json.load(handle)
class I18nAuto:
    """Select a UI language and translate keys through its JSON table.

    The language comes from a ``.locale`` file in the current working directory
    if present, otherwise from the system default locale. If no
    ``<language>.json`` exists under ``I18N_FILE_PATH``, ``DEFAULT_LANGUAGE``
    is used. Keys missing from the table are returned unchanged.
    """

    def __init__(self):
        override_file = Path(".locale")
        if not override_file.exists():
            # getlocale can't identify the system's language ((None, None))
            # NOTE(review): locale.getdefaultlocale() is deprecated since Python 3.11.
            language = locale.getdefaultlocale()[0]
        else:
            with open(override_file, "r", encoding="utf-8") as handle:
                language = handle.read().strip()

        if not (I18N_FILE_PATH / f"{language}.json").exists():
            language = DEFAULT_LANGUAGE

        self.language = language
        self.language_map = load_language_list(language)

    def __call__(self, key):
        """Return the translation of ``key``, or ``key`` itself if untranslated."""
        translations = self.language_map
        return translations.get(key, key)

    def __repr__(self):
        prefix = "Use Language: "
        return prefix + self.language
# Shared translator instance, created (and its language resolved) at import time;
# re-exported by the package's __init__.
i18n = I18nAuto()
================================================
FILE: fish_speech/i18n/locale/en_US.json
================================================
{
"16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
"Accumulate Gradient Batches": "Accumulate Gradient Batches",
"Add to Processing Area": "Add to Processing Area",
"Added path successfully!": "Added path successfully!",
"Advanced Config": "Advanced Config",
"Base LLAMA Model": "Base LLAMA Model",
"Batch Inference": "Batch Inference",
"Batch Size": "Batch Size",
"Changing with the Model Path": "Changing with the Model Path",
"Chinese": "Chinese",
"Compile Model": "Compile Model",
"Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
"Copy": "Copy",
"Data Preprocessing": "Data Preprocessing",
"Data Preprocessing Path": "Data Preprocessing Path",
"Data Source": "Data Source",
"Decoder Model Config": "Decoder Model Config",
"Decoder Model Path": "Decoder Model Path",
"Disabled": "Disabled",
"Enable Reference Audio": "Enable Reference Audio",
"English": "English",
"Error Message": "Error Message",
"File Preprocessing": "File Preprocessing",
"Generate": "Generate",
"Generated Audio": "Generated Audio",
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
"Infer interface is closed": "Infer interface is closed",
"Inference Configuration": "Inference Configuration",
"Inference Server Configuration": "Inference Server Configuration",
"Inference Server Error": "Inference Server Error",
"Inferring interface is launched at {}": "Inferring interface is launched at {}",
"Initial Learning Rate": "Initial Learning Rate",
"Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
"Input Text": "Input Text",
"Invalid path: {}": "Invalid path: {}",
"It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
"Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
"Japanese": "Japanese",
"LLAMA Configuration": "LLAMA Configuration",
"LLAMA Model Config": "LLAMA Model Config",
"LLAMA Model Path": "LLAMA Model Path",
"Labeling Device": "Labeling Device",
"LoRA Model to be merged": "LoRA Model to be merged",
"Maximum Audio Duration": "Maximum Audio Duration",
"Maximum Length per Sample": "Maximum Length per Sample",
"Maximum Training Steps": "Maximum Training Steps",
"Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
"Merge": "Merge",
"Merge LoRA": "Merge LoRA",
"Merge successfully": "Merge successfully",
"Minimum Audio Duration": "Minimum Audio Duration",
"Model Output Path": "Model Output Path",
"Model Size": "Model Size",
"Move": "Move",
"Move files successfully": "Move files successfully",
"No audio generated, please check the input text.": "No audio generated, please check the input text.",
"No selected options": "No selected options",
"Number of Workers": "Number of Workers",
"Open Inference Server": "Open Inference Server",
"Open Labeler WebUI": "Open Labeler WebUI",
"Open Tensorboard": "Open Tensorboard",
"Opened labeler in browser": "Opened labeler in browser",
"Optional Label Language": "Optional Label Language",
"Optional online ver": "Optional online ver",
"Output Path": "Output Path",
"Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
"Precision": "Precision",
"Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
"Put your text here.": "Put your text here.",
"Reference Audio": "Reference Audio",
"Reference Text": "Reference Text",
"Related code and weights are released under FISH AUDIO RESEARCH LICENSE.": "Related code and weights are released under FISH AUDIO RESEARCH LICENSE.",
"Remove Selected Data": "Remove Selected Data",
"Removed path successfully!": "Removed path successfully!",
"Repetition Penalty": "Repetition Penalty",
"Save model every n steps": "Save model every n steps",
"Select LLAMA ckpt": "Select LLAMA ckpt",
"Select VITS ckpt": "Select VITS ckpt",
"Select VQGAN ckpt": "Select VQGAN ckpt",
"Select source file processing method": "Select source file processing method",
"Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
"Selected: {}": "Selected: {}",
"Speaker": "Speaker",
"Speaker is identified by the folder name": "Speaker is identified by the folder name",
"Start Training": "Start Training",
"Streaming Audio": "Streaming Audio",
"Streaming Generate": "Streaming Generate",
"Tensorboard Host": "Tensorboard Host",
"Tensorboard Log Path": "Tensorboard Log Path",
"Tensorboard Port": "Tensorboard Port",
"Tensorboard interface is closed": "Tensorboard interface is closed",
"Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
"Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
"Training Configuration": "Training Configuration",
"Training Error": "Training Error",
"Training stopped": "Training stopped",
"Type name of the speaker": "Type name of the speaker",
"Type the path or select from the dropdown": "Type the path or select from the dropdown",
"Use LoRA": "Use LoRA",
"Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
"Use filelist": "Use filelist",
"Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
"VITS Configuration": "VITS Configuration",
"VQGAN Configuration": "VQGAN Configuration",
"Validation Batch Size": "Validation Batch Size",
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
"WebUI Host": "WebUI Host",
"WebUI Port": "WebUI Port",
"Whisper Model": "Whisper Model",
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).",
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
"latest": "latest",
"new": "new",
"Realtime Transform Text": "Realtime Transform Text",
"Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
"Text Normalization": "Text Normalization",
"Select Example Audio": "Select Example Audio"
}
================================================
FILE: fish_speech/i18n/locale/es_ES.json
================================================
{
"16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
"Accumulate Gradient Batches": "Acumular lotes de gradientes",
"Add to Processing Area": "Agregar al Área de Procesamiento",
"Added path successfully!": "¡Ruta agregada exitosamente!",
"Advanced Config": "Configuración Avanzada",
"Base LLAMA Model": "Modelo Base LLAMA",
"Batch Inference": "Inferencia por Lote",
"Batch Size": "Tamaño del Lote",
"Changing with the Model Path": "Cambiando con la Ruta del Modelo",
"Chinese": "Chino",
"Compile Model": "Compilar Modelo",
"Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
"Copy": "Copiar",
"Data Preprocessing": "Preprocesamiento de Datos",
"Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
"Data Source": "Fuente de Datos",
"Decoder Model Config": "Configuración del modelo decodificador",
"Decoder Model Path": "Ruta del modelo decodificador",
"Disabled": "Desactivado",
"Enable Reference Audio": "Habilitar Audio de Referencia",
"English": "Inglés",
"Error Message": "Mensaje de Error",
"File Preprocessing": "Preprocesamiento de Archivos",
"Generate": "Generar",
"Generated Audio": "Audio Generado",
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
"Infer interface is closed": "La interfaz de inferencia está cerrada",
"Inference Configuration": "Configuración de Inferencia",
"Inference Server Configuration": "Configuración del Servidor de Inferencia",
"Inference Server Error": "Error del Servidor de Inferencia",
"Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
"Initial Learning Rate": "Tasa de Aprendizaje Inicial",
"Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
"Input Text": "Texto de Entrada",
"Invalid path: {}": "Ruta inválida: {}",
"It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
"Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
"Japanese": "Japonés",
"LLAMA Configuration": "Configuración de LLAMA",
"LLAMA Model Config": "Configuración del Modelo LLAMA",
"LLAMA Model Path": "Ruta del Modelo LLAMA",
"Labeling Device": "Dispositivo de Etiquetado",
"LoRA Model to be merged": "Modelo LoRA a fusionar",
"Maximum Audio Duration": "Duración máxima de audio",
"Maximum Length per Sample": "Longitud Máxima por Muestra",
"Maximum Training Steps": "Pasos Máximos de Entrenamiento",
"Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
"Merge": "Fusionar",
"Merge LoRA": "Fusionar LoRA",
"Merge successfully": "Fusionado exitosamente",
"Minimum Audio Duration": "Duración mínima de audio",
"Model Output Path": "Ruta de Salida del Modelo",
"Model Size": "Tamaño del Modelo",
"Move": "Mover",
"Move files successfully": "Archivos movidos exitosamente",
"No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
"No selected options": "No hay opciones seleccionadas",
"Number of Workers": "Número de Trabajadores",
"Open Inference Server": "Abrir Servidor de Inferencia",
"Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
"Open Tensorboard": "Abrir Tensorboard",
"Opened labeler in browser": "Se abrió el etiquetador en el navegador",
"Optional Label Language": "Idioma de Etiquetado Opcional",
"Optional online ver": "Ver en línea opcional",
"Output Path": "Ruta de Salida",
"Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
"Precision": "Precisión",
"Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
"Put your text here.": "Ponga su texto aquí.",
"Reference Audio": "Audio de Referencia",
"Reference Text": "Texto de Referencia",
"Related code and weights are released under FISH AUDIO RESEARCH LICENSE.": "El código relacionado y los pesos se publican bajo la FISH AUDIO RESEARCH LICENSE.",
"Remove Selected Data": "Eliminar Datos Seleccionados",
"Removed path successfully!": "¡Ruta eliminada exitosamente!",
"Repetition Penalty": "Penalización por Repetición",
"Save model every n steps": "Guardar modelo cada n pasos",
"Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
"Select VITS ckpt": "Seleccionar punto de control VITS",
"Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
"Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
"Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
"Selected: {}": "Seleccionado: {}",
"Speaker": "Hablante",
"Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
"Start Training": "Iniciar Entrenamiento",
"Streaming Audio": "transmisión de audio",
"Streaming Generate": "síntesis en flujo",
"Tensorboard Host": "Host de Tensorboard",
"Tensorboard Log Path": "Ruta de Registro de Tensorboard",
"Tensorboard Port": "Puerto de Tensorboard",
"Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
"Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
"Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
"Training Configuration": "Configuración de Entrenamiento",
"Training Error": "Error de Entrenamiento",
"Training stopped": "Entrenamiento detenido",
"Type name of the speaker": "Escriba el nombre del hablante",
"Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
"Use LoRA": "Usar LoRA",
"Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
"Use filelist": "Usar lista de archivos",
"Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
"VITS Configuration": "Configuración de VITS",
"VQGAN Configuration": "Configuración de VQGAN",
"Validation Batch Size": "Tamaño del Lote de Validación",
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
"WebUI Host": "Host de WebUI",
"WebUI Port": "Puerto de WebUI",
"Whisper Model": "Modelo Whisper",
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1.5).",
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
"latest": "más reciente",
"new": "nuevo",
"Realtime Transform Text": "Transformación de Texto en Tiempo Real",
"Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
"Text Normalization": "Normalización de Texto",
"Select Example Audio": "Selecionar áudio de exemplo"
}
================================================
FILE: fish_speech/i18n/locale/ja_JP.json
================================================
{
"16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
"5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
"Accumulate Gradient Batches": "勾配バッチの累積",
"Add to Processing Area": "処理エリアに追加",
"Added path successfully!": "パスの追加に成功しました!",
"Advanced Config": "詳細設定",
"Base LLAMA Model": "基本LLAMAモデル",
"Batch Inference": "バッチ推論",
"Batch Size": "バッチサイズ",
"Changing with the Model Path": "モデルのパスに伴って変化する",
"Chinese": "中国語",
"Compile Model": "モデルのコンパイル",
"Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
"Copy": "コピー",
"Data Preprocessing": "データ前処理",
"Data Preprocessing Path": "データ前処理パス",
"Data Source": "データソース",
"Decoder Model Config": "デコーダーモデルの構成",
"Decoder Model Path": "デコーダーモデルのパス",
"Disabled": "無効",
"Enable Reference Audio": "リファレンスオーディオを有効にする",
"English": "英語",
"Error Message": "エラーメッセージ",
"File Preprocessing": "文書前处理",
"Generate": "生成",
"Generated Audio": "生成されたオーディオ",
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
"Infer interface is closed": "推論インターフェースが閉じられています",
"Inference Configuration": "推論設定",
"Inference Server Configuration": "推論サーバー設定",
"Inference Server Error": "推論サーバーエラー",
"Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
"Initial Learning Rate": "初期学習率",
"Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
"Input Text": "入力テキスト",
"Invalid path: {}": "無効なパス: {}",
"It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
"Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
"Japanese": "日本語",
"LLAMA Configuration": "LLAMA設定",
"LLAMA Model Config": "LLAMAモデル設定",
"LLAMA Model Path": "LLAMAモデルパス",
"Labeling Device": "ラベリングデバイス",
"LoRA Model to be merged": "マージするLoRAモデル",
"Maximum Audio Duration": "最大オーディオの長さ",
"Maximum Length per Sample": "サンプルあたりの最大長",
"Maximum Training Steps": "最大トレーニングステップ数",
"Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
"Merge": "マージ",
"Merge LoRA": "LoRAのマージ",
"Merge successfully": "マージに成功しました",
"Minimum Audio Duration": "最小オーディオの長さ",
"Model Output Path": "モデル出力パス",
"Model Size": "モデルサイズ",
"Move": "移動",
"Move files successfully": "ファイルの移動に成功しました",
"No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
"No selected options": "選択されたオプションはありません",
"Number of Workers": "ワーカー数",
"Open Inference Server": "推論サーバーを開く",
"Open Labeler WebUI": "ラベラーWebUIを開く",
"Open Tensorboard": "Tensorboardを開く",
"Opened labeler in browser": "ブラウザでラベラーを開きました",
"Optional Label Language": "オプションのラベル言語",
"Optional online ver": "オプションのオンラインバージョン",
"Output Path": "出力パス",
"Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
"Precision": "精度",
"Probability of applying Speaker Condition": "話者条件を適用する確率",
"Put your text here.": "ここにテキストを入力してください。",
"Reference Audio": "リファレンスオーディオ",
"Reference Text": "リファレンステキスト",
"Related code and weights are released under FISH AUDIO RESEARCH LICENSE.": "関連コードと重みはFISH AUDIO RESEARCH LICENSEの下でリリースされます。",
"Remove Selected Data": "選択したデータを削除",
"Removed path successfully!": "パスの削除に成功しました!",
"Repetition Penalty": "反復ペナルティ",
"Save model every n steps": "nステップごとにモデルを保存",
"Select LLAMA ckpt": " LLAMA チェックポイントを選択",
"Select VITS ckpt": "VITS チェックポイントを選択",
"Select VQGAN ckpt": "VQGAN チェックポイントを選択",
"Select source file processing method": "ソースファイルの処理方法を選択",
"Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
"Selected: {}": "選択済み: {}",
"Speaker": "話者",
"Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
"Start Training": "トレーニング開始",
"Streaming Audio": "ストリーミングオーディオ",
"Streaming Generate": "ストリーミング合成",
"Tensorboard Host": "Tensorboardホスト",
"Tensorboard Log Path": "Tensorboardログパス",
"Tensorboard Port": "Tensorboardポート",
"Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
"Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
"Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
"Training Configuration": "トレーニング設定",
"Training Error": "トレーニングエラー",
"Training stopped": "トレーニングが停止しました",
"Type name of the speaker": "話者の名前を入力",
"Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
"Use LoRA": "LoRAを使用",
"Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
"Use filelist": "ファイルリストを使用",
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
"VITS Configuration": "VITS の構成",
"VQGAN Configuration": "VQGAN の構成",
"Validation Batch Size": "検証バッチサイズ",
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
"WebUI Host": "WebUIホスト",
"WebUI Port": "WebUIポート",
"Whisper Model": "Whisperモデル",
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1.5)にあります。",
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
"latest": "最新",
"new": "新規",
"Realtime Transform Text": "リアルタイム変換テキスト",
"Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
"Text Normalization": "テキスト正規化",
"Select Example Audio": "サンプル音声を選択"
}
================================================
FILE: fish_speech/i18n/locale/ko_KR.json
================================================
{
"16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
"5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
"Accumulate Gradient Batches": "그라디언트 배치 누적",
"Add to Processing Area": "처리 영역에 추가",
"Added path successfully!": "경로가 성공적으로 추가되었습니다!",
"Advanced Config": "고급 설정",
"Base LLAMA Model": "기본 LLAMA 모델",
"Batch Inference": "배치 추론",
"Batch Size": "배치 크기",
"Changing with the Model Path": "모델 경로에 따라 변경 중",
"Chinese": "중국어",
"Compile Model": "모델 컴파일",
"Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
"Copy": "복사",
"Data Preprocessing": "데이터 전처리",
"Data Preprocessing Path": "데이터 전처리 경로",
"Data Source": "데이터 소스",
"Decoder Model Config": "디코더 모델 설정",
"Decoder Model Path": "디코더 모델 경로",
"Disabled": "비활성화 됨",
"Enable Reference Audio": "참고 음성 활성화",
"English": "영어",
"Error Message": "오류 메시지",
"File Preprocessing": "파일 전처리",
"Generate": "생성",
"Generated Audio": "생성된 오디오",
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
"Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
"Inference Configuration": "추론 설정",
"Inference Server Configuration": "추론 서버 설정",
"Inference Server Error": "추론 서버 오류",
"Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
"Initial Learning Rate": "초기 학습률",
"Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
"Input Text": "입력 텍스트",
"Invalid path: {}": "유효하지 않은 경로: {}",
"It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
"Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
"Japanese": "일본어",
"LLAMA Configuration": "LLAMA 설정",
"LLAMA Model Config": "LLAMA 모델 설정",
"LLAMA Model Path": "LLAMA 모델 경로",
"Labeling Device": "라벨링 장치",
"LoRA Model to be merged": "병합할 LoRA 모델",
"Maximum Audio Duration": "최대 오디오 길이",
"Maximum Length per Sample": "샘플당 최대 길이",
"Maximum Training Steps": "최대 학습 단계",
"Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
"Merge": "병합",
"Merge LoRA": "LoRA 병합",
"Merge successfully": "성공적으로 병합 되었습니다.",
"Minimum Audio Duration": "최소 오디오 길이",
"Model Output Path": "모델 출력 경로",
"Model Size": "모델 크기",
"Move": "이동",
"Move files successfully": "파일이 성공적으로 이동되었습니다.",
"No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
"No selected options": "옵션이 선택되지 않았습니다.",
"Number of Workers": "작업자 수",
"Open Inference Server": "추론 서버 열기",
"Open Labeler WebUI": "라벨러 WebUI 열기",
"Open Tensorboard": "Tensorboard 열기",
"Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
"Optional Label Language": "선택적 라벨 언어",
"Optional online ver": "온라인 버전 선택",
"Output Path": "출력 경로",
"Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
"Precision": "정밀도",
"Probability of applying Speaker Condition": "화자 조건 적용 확률",
"Put your text here.": "여기에 텍스트를 입력하세요.",
"Reference Audio": "참고 오디오",
"Reference Text": "참고 텍스트",
"Related code and weights are released under FISH AUDIO RESEARCH LICENSE.": "관련 코드 및 가중치는 FISH AUDIO RESEARCH LICENSE 하에 배포됩니다.",
"Remove Selected Data": "선택한 데이터 제거",
"Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
"Repetition Penalty": "반복 패널티",
"Save model every n steps": "n 단계마다 모델 저장",
"Select LLAMA ckpt": "LLAMA ckpt 선택",
"Select VITS ckpt": "VITS ckpt 선택",
"Select VQGAN ckpt": "VQGAN ckpt 선택",
"Select source file processing method": "소스 파일 처리 방법 선택",
"Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
"Selected: {}": "선택됨: {}",
"Speaker": "화자",
"Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
"Start Training": "학습 시작",
"Streaming Audio": "스트리밍 오디오",
"Streaming Generate": "스트리밍 생성",
"Tensorboard Host": "Tensorboard 호스트",
"Tensorboard Log Path": "Tensorboard 로그 경로",
"Tensorboard Port": "Tensorboard 포트",
"Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
"Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
"Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
"Training Configuration": "학습 설정",
"Training Error": "학습 오류",
"Training stopped": "학습이 중지되었습니다.",
"Type name of the speaker": "화자의 이름을 입력하세요.",
"Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
"Use LoRA": "LoRA 사용",
"Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
"Use filelist": "파일 목록 사용",
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
"VITS Configuration": "VITS 설정",
"VQGAN Configuration": "VQGAN 설정",
"Validation Batch Size": "검증 배치 크기",
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
"WebUI Host": "WebUI 호스트",
"WebUI Port": "WebUI 포트",
"Whisper Model": "Whisper 모델",
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1.5)에서 확인하실 수 있습니다.",
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
"latest": "최신",
"new": "새로운",
"Realtime Transform Text": "실시간 텍스트 변환",
"Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
"Text Normalization": "텍스트 정규화",
"Select Example Audio": "예시 오디오 선택"
}
================================================
FILE: fish_speech/i18n/locale/pt_BR.json
================================================
{
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
"Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
"Add to Processing Area": "Adicionar à Área de Processamento",
"Added path successfully!": "Caminho adicionado com sucesso!",
"Advanced Config": "Configuração Avançada",
"Base LLAMA Model": "Modelo LLAMA Base",
"Batch Inference": "Inferência em Lote",
"Batch Size": "Tamanho do Lote",
"Changing with the Model Path": "Alterando com o Caminho do Modelo",
"Compile Model": "Compilar Modelo",
"Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
"Copy": "Copiar",
"Data Preprocessing": "Pré-processamento de Dados",
"Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
"Data Source": "Fonte de Dados",
"Decoder Model Config": "Configuração do Modelo Decodificador",
"Decoder Model Path": "Caminho do Modelo Decodificador",
"Disabled": "Desativado",
"Enable Initial Prompt": "Habilitar Prompt Inicial",
"Enable Reference Audio": "Habilitar Áudio de Referência",
"English": "Inglês",
"Japanese": "Japonês",
"Chinese": "Chinês",
"Portuguese": "Português",
"Spanish": "Espanhol",
"Error Message": "Mensagem de Erro",
"Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
"File Preprocessing": "Pré-processamento de Arquivos",
"Generate": "Gerar",
"Generated Audio": "Áudio Gerado",
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
"Infer interface is closed": "A interface de inferência foi fechada",
"Inference Configuration": "Configuração de Inferência",
"Inference Server Configuration": "Configuração do Servidor de Inferência",
"Inference Server Error": "Erro do Servidor de Inferência",
"Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
"Initial Learning Rate": "Taxa de Aprendizagem Inicial",
"Initial Prompt": "Prompt Inicial",
"Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
"Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
"Input Text": "Texto de Entrada",
"Invalid path: {}": "Caminho inválido: {}",
"It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
"Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
"LLAMA Configuration": "Configuração do LLAMA",
"LLAMA Model Config": "Configuração do Modelo LLAMA",
"LLAMA Model Path": "Caminho do Modelo LLAMA",
"Labeling Device": "Dispositivo de Rotulagem",
"LoRA Model to be merged": "Modelo LoRA para mesclagem",
"Maximum Length per Sample": "Comprimento Máximo por Amostra",
"Maximum Training Steps": "Etapas Máximas de Treinamento",
"Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
"Merge": "Mesclar",
"Merge LoRA": "Mesclar LoRA",
"Merge successfully": "Mesclado com sucesso",
"Model Output Path": "Caminho de Saída do Modelo",
"Model Quantization": "Quantização do Modelo",
"Model Size": "Tamanho do Modelo",
"Move": "Mover",
"Move files successfully": "Arquivos movidos com sucesso",
"No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
"No selected options": "Nenhuma opção selecionada",
"Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
"Number of Workers": "Número de Processos",
"Open Inference Server": "Abrir Servidor de Inferência",
"Open Labeler WebUI": "Abrir WebUI de Rotulagem",
"Open Tensorboard": "Abrir Tensorboard",
"Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
"Optional Label Language": "Idioma do Rótulo (Opcional)",
"Optional online ver": "Versão online (opcional)",
"Output Path": "Caminho de Saída",
"Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
"Post-quantification Precision": "Precisão Pós-quantização",
"Precision": "Precisão",
"Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
"Put your text here.": "Insira seu texto aqui.",
"Quantify": "Quantizar",
"Quantify successfully": "Quantizado com sucesso",
"Realtime Transform Text": "Transformar Texto em Tempo Real",
"Reference Audio": "Áudio de Referência",
"Reference Text": "Texto de Referência",
"warning": "Aviso",
"Pre-processing begins...": "O pré-processamento começou!",
"Related code and weights are released under FISH AUDIO RESEARCH LICENSE.": "O código relacionado e os pesos são licenciados sob a FISH AUDIO RESEARCH LICENSE.",
"Remove Selected Data": "Remover Dados Selecionados",
"Removed path successfully!": "Caminho removido com sucesso!",
"Repetition Penalty": "Penalidade de Repetição",
"Save model every n steps": "Salvar modelo a cada n etapas",
"Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
"Select source file processing method": "Escolha como processar o arquivo de origem",
"Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
"Selected: {}": "Selecionado: {}",
"Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
"Start Training": "Iniciar Treinamento",
"Streaming Audio": "Áudio em Streaming",
"Streaming Generate": "Geração em Streaming",
"Tensorboard Host": "Host do Tensorboard",
"Tensorboard Log Path": "Caminho de Log do Tensorboard",
"Tensorboard Port": "Porta do Tensorboard",
"Tensorboard interface is closed": "A interface do Tensorboard está fechada",
"Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
"Text Normalization": "Normalização de Texto",
"Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
"The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
"Training Configuration": "Configuração de Treinamento",
"Training Error": "Erro de Treinamento",
"Training stopped": "Treinamento interrompido!",
"Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
"Use LoRA": "Usar LoRA",
"Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
"Use filelist": "Usar lista de arquivos",
"VQGAN Configuration": "Configuração do VQGAN",
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
"WebUI Host": "Host da WebUI",
"WebUI Port": "Porta da WebUI",
"Whisper Model": "Modelo Whisper",
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1.5).",
"auto": "automático",
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
"latest": "mais recente",
"new": "novo",
"This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
"You don't need to train this model!": "Não é necessário treinar este modelo!",
"Yes": "Sim",
"No": "Não",
"version:": "versão:",
"author:": "autor:"
}
================================================
FILE: fish_speech/i18n/locale/zh_CN.json
================================================
{
"16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
"Accumulate Gradient Batches": "梯度累积批次",
"Add to Processing Area": "加入处理区",
"Added path successfully!": "添加路径成功!",
"Advanced Config": "高级参数",
"Base LLAMA Model": "基础 LLAMA 模型",
"Batch Inference": "批量推理",
"Batch Size": "批次大小",
"Changing with the Model Path": "随模型路径变化",
"Chinese": "中文",
"Compile Model": "编译模型",
"Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
"Copy": "复制",
"Data Preprocessing": "数据预处理",
"Data Preprocessing Path": "数据预处理路径",
"Data Source": "数据源",
"Decoder Model Config": "解码器模型配置",
"Decoder Model Path": "解码器模型路径",
"Disabled": "禁用",
"Enable Reference Audio": "启用参考音频",
"English": "英文",
"Error Message": "错误信息",
"File Preprocessing": "文件预处理",
"Generate": "生成",
"Generated Audio": "音频",
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
"Infer interface is closed": "推理界面已关闭",
"Inference Configuration": "推理配置",
"Inference Server Configuration": "推理服务器配置",
"Inference Server Error": "推理服务器错误",
"Inferring interface is launched at {}": "推理界面已在 {} 上启动",
"Initial Learning Rate": "初始学习率",
"Input Audio & Source Path for Transcription": "输入音频和转录源路径",
"Input Text": "输入文本",
"Invalid path: {}": "无效路径: {}",
"It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
"Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
"Japanese": "日文",
"LLAMA Configuration": "LLAMA 配置",
"LLAMA Model Config": "LLAMA 模型配置",
"LLAMA Model Path": "LLAMA 模型路径",
"Labeling Device": "标注加速设备",
"LoRA Model to be merged": "要合并的 LoRA 模型",
"Maximum Audio Duration": "最大音频时长",
"Maximum Length per Sample": "每个样本的最大长度",
"Maximum Training Steps": "最大训练步数",
"Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
"Merge": "合并",
"Merge LoRA": "合并 LoRA",
"Merge successfully": "合并成功",
"Minimum Audio Duration": "最小音频时长",
"Model Output Path": "模型输出路径",
"Model Size": "模型规模",
"Move": "移动",
"Move files successfully": "移动文件成功",
"No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
"No selected options": "没有选择的选项",
"Number of Workers": "数据加载进程数",
"Open Inference Server": "打开推理服务器",
"Open Labeler WebUI": "打开标注工具",
"Open Tensorboard": "打开 Tensorboard",
"Opened labeler in browser": "在浏览器中打开标注工具",
"Optional Label Language": "[可选] 标注语言",
"Optional online ver": "[可选] 使用在线版",
"Output Path": "输出路径",
"Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
"Precision": "精度",
"Probability of applying Speaker Condition": "应用说话人条件的概率",
"Put your text here.": "在此处输入文本.",
"Reference Audio": "参考音频",
"Reference Text": "参考文本",
"Related code and weights are released under FISH AUDIO RESEARCH LICENSE.": "相关代码和权重使用 FISH AUDIO RESEARCH LICENSE 许可证发布.",
"Remove Selected Data": "移除选中数据",
"Removed path successfully!": "移除路径成功!",
"Repetition Penalty": "重复惩罚",
"Save model every n steps": "每 n 步保存模型",
"Select LLAMA ckpt": "选择 LLAMA 检查点",
"Select VITS ckpt": "选择 VITS 检查点",
"Select VQGAN ckpt": "选择 VQGAN 检查点",
"Select source file processing method": "选择源文件处理方法",
"Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
"Selected: {}": "已选择: {}",
"Speaker": "说话人",
"Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
"Start Training": "开始训练",
"Streaming Audio": "流式音频",
"Streaming Generate": "流式合成",
"Tensorboard Host": "Tensorboard 监听地址",
"Tensorboard Log Path": "Tensorboard 日志路径",
"Tensorboard Port": "Tensorboard 端口",
"Tensorboard interface is closed": "Tensorboard 界面已关闭",
"Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
"Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
"Training Configuration": "训练配置",
"Training Error": "训练错误",
"Training stopped": "训练已停止",
"Type name of the speaker": "输入说话人的名称",
"Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
"Use LoRA": "使用 LoRA",
"Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
"Use filelist": "使用文件列表",
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
"VITS Configuration": "VITS 配置",
"VQGAN Configuration": "VQGAN 配置",
"Validation Batch Size": "验证批次大小",
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
"WebUI Host": "WebUI 监听地址",
"WebUI Port": "WebUI 端口",
"Whisper Model": "Whisper 模型",
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.5) 找到模型.",
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
"latest": "最近的检查点",
"new": "创建新的检查点",
"Realtime Transform Text": "实时规范化文本",
"Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
"Text Normalization": "文本规范化",
"Select Example Audio": "选择参考音频"
}
================================================
FILE: fish_speech/i18n/scan.py
================================================
import ast
import glob
import json
from collections import OrderedDict
from pathlib import Path
from loguru import logger
from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
def extract_i18n_strings(node):
    """Recursively collect the string literals passed to ``i18n(...)`` calls.

    Walks the AST rooted at ``node`` (pre-order, via ``ast.iter_child_nodes``)
    and, for every call whose callee is the bare name ``i18n``, records each
    positional argument that is a plain string literal. Non-literal arguments
    (names, f-strings, nested calls, bytes) are skipped because their text is
    not known statically; nested ``i18n`` calls are still found by recursion.

    Args:
        node: Any ``ast.AST`` node, typically the module from ``ast.parse``.

    Returns:
        list[str]: The literal strings in traversal order, duplicates included.
    """
    i18n_strings = []
    if (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "i18n"
    ):
        for arg in node.args:
            # ``ast.Str`` / ``.s`` were deprecated in Python 3.8 and removed in
            # 3.12; string literals are ``ast.Constant`` nodes with a str value.
            if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
                i18n_strings.append(arg.value)
    for child_node in ast.iter_child_nodes(node):
        i18n_strings.extend(extract_i18n_strings(child_node))
    return i18n_strings
# Scan the "fish_speech" and "tools" folders recursively for .py files, parse
# each candidate into an AST and collect every literal passed to i18n(...).
strings = []
folders = ["fish_speech", "tools"]
for folder in folders:
    for py_file in Path(folder).rglob("*.py"):
        code = py_file.read_text(encoding="utf-8")
        # Cheap textual pre-filter: only parse files that can contain a call.
        if "i18n(" in code:
            tree = ast.parse(code)
            i18n_strings = extract_i18n_strings(tree)
            logger.info(f"Found {len(i18n_strings)} i18n strings in {py_file}")
            strings.extend(i18n_strings)

code_keys = set(strings)
logger.info(f"Total unique: {len(code_keys)}")

# The default-language file defines the key set for every locale.
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
with open(standard_file, "r", encoding="utf-8") as f:
    standard_data = json.load(f, object_pairs_hook=OrderedDict)
standard_keys = set(standard_data.keys())

# Report keys no longer referenced by code and keys referenced but absent.
# Sorted so the log output is stable between runs.
unused_keys = standard_keys - code_keys
logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
for unused_key in sorted(unused_keys):
    logger.info(f"\t{unused_key}")

missing_keys = code_keys - standard_keys
logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
for missing_key in sorted(missing_keys):
    logger.info(f"\t{missing_key}")

# Rewrite the default-language file as an identity mapping of the keys found
# in code (unused keys are dropped, missing keys are added).
code_keys_dict = OrderedDict((s, s) for s in strings)
with open(standard_file, "w", encoding="utf-8") as f:
    json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
    f.write("\n")
logger.info(f"Updated {standard_file}")

# Every other JSON file in the locale directory is a translation to sync.
languages = [f for f in I18N_FILE_PATH.glob("*.json") if f.stem != DEFAULT_LANGUAGE]

# Reload the file just written so key order matches what is on disk.
with open(standard_file, "r", encoding="utf-8") as f:
    standard_data = json.load(f, object_pairs_hook=OrderedDict)

for lang_file in languages:
    with open(lang_file, "r", encoding="utf-8") as f:
        lang_data = json.load(f, object_pairs_hook=OrderedDict)

    diff = set(standard_data.keys()) - set(lang_data.keys())
    miss = set(lang_data.keys()) - set(standard_data.keys())

    # Untranslated keys get a "#!" marker prefix so they are easy to find.
    for key in sorted(diff):
        lang_data[key] = "#!" + key
        logger.info(f"Added missing key: {key} to {lang_file}")

    # Drop keys that no longer exist in the default-language file.
    for key in sorted(miss):
        del lang_data[key]
        logger.info(f"Del extra key: {key} from {lang_file}")

    # lang_data now holds exactly the standard keys, so rebuild it in the
    # standard file's order in one pass (a list.index() lookup inside a sort
    # key would be O(n^2)).
    lang_data = OrderedDict((key, lang_data[key]) for key in standard_data)

    with open(lang_file, "w", encoding="utf-8") as f:
        json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
        f.write("\n")
    logger.info(f"Updated {lang_file}")

logger.info("Done")
================================================
FILE: fish_speech/inference_engine/__init__.py
================================================
import gc
import queue
from typing import Generator
import numpy as np
import torch
from loguru import logger
from fish_speech.inference_engine.reference_loader import ReferenceLoader
from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
from fish_speech.inference_engine.vq_manager import VQManager
from fish_speech.models.dac.modded_dac import DAC
from fish_speech.models.text2semantic.inference import (
GenerateRequest,
GenerateResponse,
WrappedGenerateResponse,
)
from fish_speech.utils import autocast_exclude_mps, set_seed
from fish_speech.utils.schema import ServeTTSRequest
class TTSInferenceEngine(ReferenceLoader, VQManager):
    """Text-to-speech engine tying a LLAMA worker to the DAC decoder.

    Reference loading/caching comes from ``ReferenceLoader`` and VQ
    encoding/decoding from ``VQManager``. Semantic tokens are requested by
    putting ``GenerateRequest`` objects on ``llama_queue``; replies come back
    on a per-request response queue.
    """

    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: DAC,
        precision: torch.dtype,
        compile: bool,
    ) -> None:
        """
        Args:
            llama_queue: Queue that ``GenerateRequest`` items are put on for
                the LLAMA worker.
            decoder_model: DAC codec used to decode VQ codes into audio.
            precision: dtype used for autocast while decoding.
            compile: Forwarded as the ``compile`` flag of each LLAMA request.
        """
        super().__init__()
        self.llama_queue = llama_queue
        self.decoder_model = decoder_model
        self.precision = precision
        self.compile = compile

    @torch.inference_mode()
    def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
        """
        Main inference function:
        - Loads the reference audio and text.
        - Calls the LLAMA model for inference.
        - Decodes the VQ tokens to audio.

        Yields, in order:
        - a ``"header"`` result carrying a WAV header (streaming requests only);
        - one ``"segment"`` result per decoded chunk (streaming requests only);
        - exactly one terminal result: ``"error"`` if the LLAMA worker
          reported a failure or no audio was produced, otherwise ``"final"``
          with all segments concatenated.

        Raises:
            TypeError: If a non-error reply does not carry a ``GenerateResponse``.
        """
        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []
        # Load the reference audio and text based on id or hash
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )

        # Set the random seed if provided
        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")

        # Get the symbolic tokens from the LLAMA model
        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)

        # Get the sample rate from the decoder model
        if hasattr(self.decoder_model, "spec_transform"):
            sample_rate = self.decoder_model.spec_transform.sample_rate
        else:
            sample_rate = self.decoder_model.sample_rate

        # If streaming, send the header
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(
                    sample_rate,
                    np.array(wav_chunk_header(sample_rate=sample_rate)),
                ),
                error=None,
            )

        segments = []
        failed = False
        while True:
            # Get the response from the LLAMA model (blocks until available)
            wrapped_result: WrappedGenerateResponse = response_queue.get()
            if wrapped_result.status == "error":
                yield InferenceResult(
                    code="error",
                    audio=None,
                    error=(
                        wrapped_result.response
                        if isinstance(wrapped_result.response, Exception)
                        else Exception("Unknown error")
                    ),
                )
                failed = True
                break

            # Check the response type
            if not isinstance(wrapped_result.response, GenerateResponse):
                raise TypeError(
                    f"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
                )

            result: GenerateResponse = wrapped_result.response
            # A response whose action is "next" ends the loop; any other
            # action carries codes to decode into an audio segment.
            if result.action != "next":
                segment = self.get_audio_segment(result)
                if req.streaming:  # Used only by the API server
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
                        error=None,
                    )
                segments.append(segment)
            else:
                break

        # Clean up the memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        # The failure has already been reported; do not follow it with a
        # second "error" or a partial "final" result.
        if failed:
            return None

        # Edge case: no audio generated
        if len(segments) == 0:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            # Streaming or not, return the final audio
            audio = np.concatenate(segments, axis=0)
            yield InferenceResult(
                code="final",
                audio=(sample_rate, audio),
                error=None,
            )
        return None

    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """
        Send a request to the LLAMA model to generate the symbolic tokens.

        Returns:
            A fresh queue on which the LLAMA worker's replies for this request
            are delivered.
        """
        # Prepare the request
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=req.text,
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )
        # Create a queue to get the response
        response_queue = queue.Queue()
        # Send the request to the LLAMA model
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )
        return response_queue

    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """
        Decode the VQ tokens in ``result.codes`` to audio.

        Returns:
            The decoded audio as a float32 numpy array on the CPU.
        """
        # Don't use autocast on MPS devices
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            # Decode the symbolic tokens to audio
            segment = self.decode_vq_tokens(codes=result.codes)
        # Convert the audio to numpy
        return segment.float().cpu().numpy()
================================================
FILE: fish_speech/inference_engine/reference_loader.py
================================================
import io
from hashlib import sha256
from pathlib import Path
from typing import Callable, Literal, Tuple
import torch
import torchaudio
from loguru import logger
from fish_speech.models.dac.modded_dac import DAC
from fish_speech.utils.file import (
AUDIO_EXTENSIONS,
audio_to_bytes,
list_files,
read_ref_text,
)
from fish_speech.utils.schema import ServeReferenceAudio
class ReferenceLoader:
    def __init__(self) -> None:
        """
        Component of the TTSInferenceEngine class.
        Loads and manages the cache for the reference audio and text.
        """
        # id -> (prompt_tokens list, prompt_texts list)
        self.ref_by_id: dict = {}
        # SHA-256 hex digest of the raw audio bytes -> (prompt_token, text)
        self.ref_by_hash: dict = {}

        # Make Pylance happy (attribut/method not defined...)
        self.decoder_model: DAC
        self.encode_reference: Callable

        # Define the torchaudio backend: ffmpeg when available, else soundfile
        backends = torchaudio.list_audio_backends()
        if "ffmpeg" in backends:
            self.backend = "ffmpeg"
        else:
            self.backend = "soundfile"

    @staticmethod
    def _reference_dir(id: str) -> Path:
        """Map a reference ID to its directory under ``references/``.

        Raises:
            ValueError: If ``id`` is empty/``"."``, absolute (has a root or
                drive), or contains a ``..`` component, i.e. it could address
                ``references/`` itself or a location outside it.
        """
        relative = Path(id)
        if not relative.parts or relative.anchor or ".." in relative.parts:
            raise ValueError(f"Invalid reference ID: {id!r}")
        return Path("references") / relative

    def load_by_id(
        self,
        id: str,
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """Load (and cache) the references stored under ``references/<id>``.

        Each audio file found (recursively) in the folder is encoded and paired
        with the text read from the ``.lab`` file with the same stem.

        Args:
            id: Reference ID, a path relative to ``references/``.
            use_cache: ``"off"`` forces re-encoding even when ``id`` is cached.

        Returns:
            ``(prompt_tokens, prompt_texts)``, two lists in matching order.

        Raises:
            ValueError: If ``id`` is not a safe path inside ``references/``.
        """
        # Validate before touching the filesystem: the ID comes from the
        # request and ``mkdir`` below would otherwise create arbitrary paths.
        ref_folder = self._reference_dir(id)
        # NOTE: an unknown ID gets an empty folder created; presumably that
        # yields empty prompt lists (no reference) rather than an error.
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        if use_cache == "off" or id not in self.ref_by_id:
            # If the references are not already loaded, encode them
            prompt_tokens = [
                self.encode_reference(
                    reference_audio=audio_to_bytes(str(ref_audio)),
                    enable_reference_audio=True,
                )
                for ref_audio in ref_audios
            ]
            prompt_texts = [
                read_ref_text(str(ref_audio.with_suffix(".lab")))
                for ref_audio in ref_audios
            ]
            self.ref_by_id[id] = (prompt_tokens, prompt_texts)
        else:
            # Reuse already encoded references
            logger.info("Use same references")
            prompt_tokens, prompt_texts = self.ref_by_id[id]

        return prompt_tokens, prompt_texts

    def load_by_hash(
        self,
        references: list[ServeReferenceAudio],
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """Encode inline references, caching encodings by audio SHA-256.

        Args:
            references: Reference audio bytes with their transcripts.
            use_cache: ``"off"`` forces re-encoding of every reference.

        Returns:
            ``(prompt_tokens, prompt_texts)``, one entry per reference.
        """
        audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]

        cache_used = False
        prompt_tokens, prompt_texts = [], []
        for audio_hash, ref in zip(audio_hashes, references):
            if use_cache == "off" or audio_hash not in self.ref_by_hash:
                # If the reference is not already loaded, encode it
                token = self.encode_reference(
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                self.ref_by_hash[audio_hash] = (token, ref.text)
            else:
                # Reuse the already encoded audio
                token, _ = self.ref_by_hash[audio_hash]
                cache_used = True
            prompt_tokens.append(token)
            # Always use the transcript sent with this request: the cache key
            # covers the audio only, so a cached text may belong to an earlier
            # request that sent the same audio with a different transcript.
            prompt_texts.append(ref.text)

        if cache_used:
            logger.info("Use same references")

        return prompt_tokens, prompt_texts

    def load_audio(self, reference_audio: bytes | str, sr: int):
        """
        Load the audio data from a file or bytes.

        The waveform is averaged down to mono and resampled to ``sr``.

        Args:
            reference_audio: Encoded audio bytes, or a path to an audio file.
            sr: Target sample rate.

        Returns:
            The squeezed mono waveform as a numpy array.
        """
        # Dispatch on type rather than probing the filesystem: ``Path`` rejects
        # ``bytes`` and ``BytesIO`` rejects ``str``, so a length/exists
        # heuristic misclassifies short byte strings and long paths.
        if isinstance(reference_audio, (bytes, bytearray, memoryview)):
            reference_audio = io.BytesIO(reference_audio)

        waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)

        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if original_sr != sr:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=sr
            )
            waveform = resampler(waveform)

        audio = waveform.squeeze().numpy()
        return audio

    def list_reference_ids(self) -> list[str]:
        """
        List all valid reference IDs (subdirectory names containing valid audio and .lab files).

        Returns:
            list[str]: Sorted list of valid reference IDs
        """
        ref_base_path = Path("references")
        if not ref_base_path.exists():
            return []

        valid_ids = []
        for ref_dir in ref_base_path.iterdir():
            if not ref_dir.is_dir():
                continue

            audio_files = list_files(
                ref_dir, AUDIO_EXTENSIONS, recursive=False, sort=False
            )
            # Valid when at least one audio file has a sibling .lab transcript
            if any(audio_file.with_suffix(".lab").exists() for audio_file in audio_files):
                valid_ids.append(ref_dir.name)

        return sorted(valid_ids)

    def add_reference(self, id: str, wav_file_path: str, reference_text: str) -> None:
        """
        Add a new reference voice by creating a new directory and copying files.

        Args:
            id: Reference ID (directory name)
            wav_file_path: Path to the audio file to copy
            reference_text: Text content for the .lab file

        Raises:
            ValueError: If the ID or the audio file extension is invalid
            FileExistsError: If the reference ID already exists
            FileNotFoundError: If the audio file doesn't exist
            OSError: If file operations fail
        """
        import re
        import shutil

        # Validate ID format (also rules out path separators and "..")
        if not re.match(r"^[a-zA-Z0-9\-_ ]+$", id):
            raise ValueError(
                "Reference ID contains invalid characters. Only alphanumeric, hyphens, underscores, and spaces are allowed."
            )

        if len(id) > 255:
            raise ValueError(
                "Reference ID is too long. Maximum length is 255 characters."
            )

        # Check if reference already exists
        ref_dir = Path("references") / id
        if ref_dir.exists():
            raise FileExistsError(f"Reference ID '{id}' already exists")

        # Check if audio file exists
        audio_path = Path(wav_file_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {wav_file_path}")

        # Validate audio file extension
        if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
            raise ValueError(
                f"Unsupported audio format: {audio_path.suffix}. Supported formats: {', '.join(AUDIO_EXTENSIONS)}"
            )

        created = False
        try:
            # Create reference directory; fails if it appeared concurrently
            ref_dir.mkdir(parents=True, exist_ok=False)
            created = True

            # Keep the original extension for the copied audio
            target_audio_path = ref_dir / f"sample{audio_path.suffix}"
            shutil.copy2(audio_path, target_audio_path)

            # Create .lab file
            lab_path = ref_dir / "sample.lab"
            with open(lab_path, "w", encoding="utf-8") as f:
                f.write(reference_text)

            # Clear cache for this ID if it exists
            if id in self.ref_by_id:
                del self.ref_by_id[id]

            logger.info(f"Successfully added reference voice with ID: {id}")

        except Exception:
            # Only roll back a directory this call created; otherwise a
            # concurrent add of the same ID would have its files deleted.
            if created and ref_dir.exists():
                shutil.rmtree(ref_dir)
            raise

    def delete_reference(self, id: str) -> None:
        """
        Delete a reference voice by removing its directory and files.

        Args:
            id: Reference ID (directory name) to delete

        Raises:
            ValueError: If the ID could address ``references/`` itself or a
                location outside it
            FileNotFoundError: If the reference ID doesn't exist
            OSError: If file operations fail
        """
        import shutil

        # Validate first: without this, IDs such as "../x" or "" would make
        # rmtree delete directories outside (or all of) references/.
        ref_dir = self._reference_dir(id)
        if not ref_dir.exists():
            raise FileNotFoundError(f"Reference ID '{id}' does not exist")

        try:
            # Remove the entire reference directory
            shutil.rmtree(ref_dir)

            # Clear cache for this ID if it exists
            if id in self.ref_by_id:
                del self.ref_by_id[id]

            logger.info(f"Successfully deleted reference voice with ID: {id}")

        except Exception as e:
            logger.error(f"Failed to delete reference '{id}': {e}")
            raise OSError(f"Failed to delete reference '{id}': {e}") from e
================================================
FILE: fish_speech/inference_engine/utils.py
================================================
import io
import wave
from dataclasses import dataclass
from typing import Literal, Optional, Tuple
import numpy as np
@dataclass
class InferenceResult:
code: Literal["header", "segment", "error", "final"]
audio: Optional[Tuple[int, np.ndarray]]
error: Optional[Exception]
def wav_chunk_header(
    sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
) -> bytes:
    """Build a PCM WAV header that declares zero frames of audio data.

    Args:
        sample_rate: Frame rate written into the header.
        bit_depth: Bits per sample; converted to a byte width (``// 8``).
        channels: Number of interleaved channels.

    Returns:
        The raw header bytes produced by the ``wave`` module.
    """
    with io.BytesIO() as stream:
        # ``wave`` only closes file objects it opened itself, so ``stream``
        # stays readable after the writer finalises the header on exit.
        with wave.open(stream, "wb") as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(bit_depth // 8)
            writer.setframerate(sample_rate)
        return stream.getvalue()
================================================
FILE: fish_speech/inference_engine/vq_manager.py
================================================
from typing import Callable
import torch
from loguru import logger
from fish_speech.models.dac.modded_dac import DAC
class VQManager:
    """Mixin converting between reference audio and DAC VQ codes.

    The composing class is expected to provide ``decoder_model`` and
    ``load_audio``; ``__init__`` merely declares them for static analysis.
    """

    def __init__(self):
        # Bare annotations (no assignment) so type checkers such as Pylance
        # know these members exist on the composing class.
        self.decoder_model: DAC
        self.load_audio: Callable

    def decode_vq_tokens(self, codes):
        """Decode one VQ code tensor into audio.

        A leading batch axis is added before ``from_indices``; the first batch
        item is returned with its size-1 dimensions squeezed out.

        Raises:
            ValueError: If ``decoder_model`` is not a ``DAC`` instance.
        """
        logger.info(f"VQ features: {codes.shape}")

        model = self.decoder_model
        if not isinstance(model, DAC):
            raise ValueError(f"Unknown model type: {type(model)}")

        decoded = model.from_indices(codes[None])
        return decoded[0].squeeze()

    def encode_reference(self, reference_audio, enable_reference_audio):
        """Encode reference audio into VQ prompt tokens.

        Args:
            reference_audio: Value handed to ``self.load_audio`` (bytes or path).
            enable_reference_audio: When falsy, no encoding happens.

        Returns:
            The first item of the first element returned by
            ``decoder_model.encode``, or ``None`` when encoding is disabled or
            ``reference_audio`` is ``None``.

        Raises:
            ValueError: If ``decoder_model`` is not a ``DAC`` instance.
        """
        if not enable_reference_audio or reference_audio is None:
            logger.info("No reference audio provided")
            return None

        model = self.decoder_model
        # Prefer the spectrogram transform's rate when the model exposes one.
        sample_rate = (
            model.spec_transform.sample_rate
            if hasattr(model, "spec_transform")
            else model.sample_rate
        )

        waveform = self.load_audio(reference_audio, sample_rate)
        # Add batch and channel axes in front of the samples axis.
        audios = torch.from_numpy(waveform).to(model.device)[None, None, :]
        n_samples = audios.shape[2]
        audio_lengths = torch.tensor(
            [n_samples], device=model.device, dtype=torch.long
        )
        logger.info(f"Loaded audio with {n_samples / sample_rate:.2f} seconds")

        # VQ Encoder
        if not isinstance(model, DAC):
            raise ValueError(f"Unknown model type: {type(model)}")

        prompt_tokens = model.encode(audios, audio_lengths)[0][0]
        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
        return prompt_tokens
================================================
FILE: fish_speech/models/dac/__init__.py
================================================
================================================
FILE: fish_speech/models/dac/inference.py
================================================
from pathlib import Path
import click
import hydra
import numpy as np
import pyrootutils
import soundfile as sf
import torch
import torchaudio
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from fish_speech.utils.file import AUDIO_EXTENSIONS

# Register an "eval" resolver so configs can compute values via ${eval:...}.
# SECURITY NOTE: this evaluates config strings with Python's eval — only load
# trusted configs.
# register_new_resolver raises if "eval" is already registered (e.g. if some
# other module registered it first), so only register it when missing.
if not OmegaConf.has_resolver("eval"):
    OmegaConf.register_new_resolver("eval", eval)
def load_model(config_name, checkpoint_path, device="cuda"):
    """Build the codec model from a Hydra config and load its weights.

    Args:
        config_name: Name of the config to compose from ``../../configs``
            (relative to this module).
        checkpoint_path: Checkpoint file passed to ``torch.load``.
        device: Used as ``map_location`` and as the final ``.to()`` target.

    Returns:
        The instantiated model with weights loaded non-strictly, in eval mode
        and moved to ``device``.
    """
    # Reset any global Hydra state left over from an earlier call so that
    # ``initialize`` can run again.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../configs"):
        cfg = compose(config_name=config_name)

    model = instantiate(cfg)
    checkpoint = torch.load(
        checkpoint_path, map_location=device, mmap=True, weights_only=True
    )

    # Some checkpoints nest the weights under a "state_dict" key; unwrap it.
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # When any key mentions "generator", keep only the "generator." entries
    # and strip that prefix; every other entry is discarded.
    prefix = "generator."
    if any("generator" in name for name in checkpoint):
        checkpoint = {
            name.replace(prefix, ""): value
            for name, value in checkpoint.items()
            if prefix in name
        }

    # strict=False tolerates missing/unexpected keys; the returned report is
    # logged below.
    load_result = model.load_state_dict(checkpoint, strict=False, assign=True)
    model.eval()
    model.to(device)
    logger.info(f"Loaded model: {load_result}")

    return model
@torch.no_grad()
@click.command()
@click.option(
    "--input-path",
    "-i",
    default="test.wav",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
)
@click.option("--config-name", default="modded_dac_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/openaudio-s1-mini/codec.pth",
)
@click.option(
    "--device",
    "-d",
    default="cuda",
)
def main(input_path, output_path, config_name, checkpoint_path, device):
    """Round-trip audio through the codec, or decode precomputed codes.

    * Audio input (suffix in ``AUDIO_EXTENSIONS``): downmix to mono, resample
      to the model rate, encode, and save the codes to ``output_path`` with a
      ``.npy`` suffix.
    * ``.npy`` input: load a 2-D code array.

    In both cases the codes are then decoded and the waveform is written to
    ``output_path``.

    Raises:
        ValueError: if the input suffix is neither audio nor ``.npy``.
    """
    model = load_model(config_name, checkpoint_path, device=device)

    if input_path.suffix in AUDIO_EXTENSIONS:
        logger.info(f"Processing in-place reconstruction of {input_path}")

        # Load audio; channels are on axis 0, so average them down to mono.
        audio, sr = torchaudio.load(str(input_path))
        if audio.shape[0] > 1:
            audio = audio.mean(0, keepdim=True)
        audio = torchaudio.functional.resample(audio, sr, model.sample_rate)

        audios = audio[None].to(device)  # add a batch axis
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=device, dtype=torch.long
        )
        indices, _ = model.encode(audios, audio_lengths)

        if indices.ndim == 3:
            indices = indices[0]

        logger.info(f"Generated indices of shape {indices.shape}")

        # Save indices
        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
    elif input_path.suffix == ".npy":
        logger.info(f"Processing precomputed indices from {input_path}")
        indices = np.load(input_path)
        indices = torch.from_numpy(indices).to(device).long()
        assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
    else:
        raise ValueError(f"Unknown input type: {input_path}")

    # Restore: decode the codes back to a waveform.
    if indices.ndim == 2:
        indices = indices.unsqueeze(0)

    fake_audios = model.from_indices(indices)
    audio_time = fake_audios.shape[-1] / model.sample_rate
    # Codes are laid out (batch, codebooks, frames); the frame count is the
    # LAST axis. shape[1] would be the number of codebooks.
    num_frames = indices.shape[-1]
    frames_per_second = num_frames / audio_time if audio_time > 0 else float("nan")

    logger.info(
        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {num_frames} features, features/second: {frames_per_second:.2f}"
    )

    # Save audio
    fake_audio = fake_audios[0, 0].float().cpu().numpy()
    sf.write(output_path, fake_audio, model.sample_rate)
    logger.info(f"Saved audio to {output_path}")


if __name__ == "__main__":
    main()
================================================
FILE: fish_speech/models/dac/modded_dac.py
================================================
import math
import typing as tp
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from dac.model.base import CodecMixin
from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
@dataclass
class VQResult:
    """Bundle of tensors produced by a vector-quantization step.

    Field meanings follow ``DAC.encode``'s documentation: ``z`` is the quantized
    continuous representation, ``codes`` the codebook indices, and ``latents``
    the projected latents before quantization. ``semantic_distill_z`` is
    optional and defaults to ``None``.
    """

    z: torch.Tensor
    codes: torch.Tensor
    latents: torch.Tensor
    codebook_loss: torch.Tensor
    commitment_loss: torch.Tensor
    # Optional extra output; None when not produced.
    semantic_distill_z: Optional[torch.Tensor] = None
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k`` (``n`` itself if already one)."""
    remainder = n % k
    if remainder:
        return n + k - remainder
    return n
@dataclass
class ModelArgs:
    """Hyper-parameters for the codec's transformer layers.

    Post-init normalization:
      * ``n_local_heads == -1`` is replaced by ``n_head``.
      * ``intermediate_size=None`` is derived from ``dim`` as
        ``int(2 * 4 * dim / 3)`` rounded up to a multiple of 256.
      * ``pos_embed_type`` must be ``"rope"`` or ``"conformer"`` (asserted).
    """

    block_size: int = 2048
    n_layer: int = 8
    n_head: int = 8
    dim: int = 512
    intermediate_size: Optional[int] = 1536
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1
    channels_first: bool = True  # to be compatible with conv1d input/output
    pos_embed_type: str = "rope"  # can be "rope" or "conformer"
    max_relative_position: int = 128  # for conformer-style relative position embedding
    window_size: int = 512  # for window limited attention

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head

        if self.intermediate_size is None:
            self.intermediate_size = find_multiple(int(2 * (4 * self.dim) / 3), 256)

        assert self.pos_embed_type in (
            "rope",
            "conformer",
        ), "pos_embed_type must be either 'rope' or 'conformer'"
class KVCache(nn.Module):
    """Fixed-capacity key/value cache for incremental attention.

    Both buffers have shape ``(max_batch_size, n_heads, max_seq_length, head_dim)``
    and start zero-filled.
    """

    def __init__(
        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        for buffer_name in ("k_cache", "v_cache"):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Store ``k_val``/``v_val`` at ``input_pos`` and return the filled prefix.

        ``input_pos`` is 1-D with one entry per step along axis 2 of ``k_val``
        (``[B, H, S, D]``). Returns views of both caches covering positions
        ``0 .. input_pos.max()`` inclusive.
        """
        assert input_pos.shape[0] == k_val.shape[2]

        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val

        stop = input_pos.max() + 1
        return self.k_cache[:, :, :stop, :], self.v_cache[:, :, :stop, :]

    def clear_cache(self, prompt_len):
        """Zero every cached position from ``prompt_len`` onward (keeps the prompt)."""
        for cache in (self.k_cache, self.v_cache):
            cache[:, :, prompt_len:, :].zero_()
class Transformer(nn.Module):
    """Stack of ``TransformerBlock`` layers followed by a final ``RMSNorm``.

    Positions are encoded with RoPE (``freqs_cis`` buffer) when
    ``config.pos_embed_type == "rope"``; for ``"conformer"`` the attention
    layers add learned relative-position scores instead.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        # Only compute RoPE frequencies if using RoPE (table covers 327680 positions).
        if config.pos_embed_type == "rope":
            freqs_cis = precompute_freqs_cis(
                327680, self.config.head_dim, self.config.rope_base
            )
            self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        else:
            self.register_buffer("freqs_cis", None)

        # Default causal mask, used only when ``forward`` receives no mask.
        # NOTE(review): a 32768 x 32768 bool tensor is ~1 GiB per instance and
        # limits default-mask positions to < 32768; consider building it lazily.
        causal_mask = torch.tril(torch.ones(32768, 32768, dtype=torch.bool))
        self.register_buffer("causal_mask", causal_mask, persistent=False)

        self.max_batch_size = -1
        self.max_seq_length = -1
        self.use_kv_cache = False

    def setup_caches(self, max_batch_size, max_seq_length):
        """Attach a ``KVCache`` to every attention layer (inference only).

        ``max_seq_length`` is rounded up to a multiple of 8. The cache's per-head
        width is ``config.head_dim``, the same width ``Attention`` uses for its
        key/value projections. (``dim // n_head`` only coincides with it for
        some configs, and the cache update fails on a mismatch.)
        """
        head_dim = self.config.head_dim
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.norm.weight.dtype
        device = self.norm.weight.device

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_length,
                self.config.n_local_heads,
                head_dim,
                dtype,
            ).to(device)

        self.use_kv_cache = True

    def forward(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run all layers on ``x`` (``[B, S, dim]``) at positions ``input_pos``.

        ``mask`` is a boolean attention mask (True = may attend). When omitted,
        rows/columns of the default causal mask are selected by ``input_pos``.
        With an active KV cache in eval mode, the columns cover every cached
        slot up to ``input_pos.max()``.
        """
        if self.config.pos_embed_type == "rope":
            assert (
                self.freqs_cis is not None
            ), "RoPE frequencies must be initialized for RoPE positional embedding"
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None

        if mask is None:  # no mask supplied: fall back to the default causal mask
            if not self.training and self.use_kv_cache:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., : input_pos.max() + 1]
            else:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]

        for layer in self.layers:
            x = layer(x, input_pos, freqs_cis, mask)

        return self.norm(x)
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention, then gated feed-forward.

    Each branch is normalized with ``RMSNorm`` on the way in and scaled by an
    in-place ``LayerScale`` before being added back to the residual stream.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_layer_scale = LayerScale(config.dim, inplace=True)
        self.ffn_layer_scale = LayerScale(config.dim, inplace=True)

    def forward(
        self,
        x: Tensor,
        input_pos: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
    ) -> Tensor:
        attn_branch = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        hidden = x + self.attention_layer_scale(attn_branch)

        ffn_branch = self.feed_forward(self.ffn_norm(hidden))
        return hidden + self.ffn_layer_scale(ffn_branch)
class Attention(nn.Module):
    """Multi-head self-attention with optional grouped key/value heads.

    ``n_head`` query heads share ``n_local_heads`` key/value heads (each KV head
    is repeated ``n_head // n_local_heads`` times). Positions are encoded with
    RoPE, or, for ``pos_embed_type == "conformer"``, with learned relative
    position embeddings added to the attention logits.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate
        self.pos_embed_type = config.pos_embed_type

        # Relative position table for conformer-style attention: one row per
        # clipped offset in [-max_relative_position, max_relative_position].
        if self.pos_embed_type == "conformer":
            self.max_relative_position = config.max_relative_position
            num_pos_embeddings = 2 * config.max_relative_position + 1
            self.rel_pos_embeddings = nn.Parameter(
                torch.zeros(num_pos_embeddings, self.head_dim)
            )
            nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)

    def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
        """Relative-position logits: query ``i`` dotted with the embedding of ``i - j``.

        q: [B, H, S, D]; returns [B, H, S, S].
        """
        positions = torch.arange(seqlen, device=q.device)
        relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0)  # [S, S]
        relative_positions = torch.clamp(
            relative_positions + self.max_relative_position,
            0,
            2 * self.max_relative_position,
        )
        rel_embeddings = self.rel_pos_embeddings[relative_positions]  # [S, S, D]

        q = q.transpose(1, 2)  # [B, S, H, D]
        rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1))  # [B, S, H, S]
        return rel_logits.transpose(1, 2)  # [B, H, S, S]

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attend over ``x`` (``[B, S, dim]``) and return ``[B, S, dim]``.

        ``mask`` is boolean with True meaning "may attend". It is inverted for
        ``masked_fill`` in the conformer path and passed as ``attn_mask`` to
        ``scaled_dot_product_attention`` otherwise.
        """
        bsz, seqlen, _ = x.shape

        # The fused projection emits n_head query heads followed by
        # n_local_heads key heads and n_local_heads value heads. The query
        # chunk is wider than the KV chunks under grouped-query attention.
        q_size = self.n_head * self.head_dim
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
        context_seqlen = seqlen

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        if self.pos_embed_type == "rope":
            q = apply_rotary_emb(q, freqs_cis)
            k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand shared KV heads to one per query head.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.pos_embed_type == "conformer":
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale

            # NOTE(review): rel_scores is [S, S]; with a KV cache the key axis
            # can be longer than S — confirm the conformer path is not used
            # with caching.
            rel_scores = self._compute_conformer_pos_scores(q, seqlen)
            scores = scores + rel_scores

            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))

            attn = F.softmax(scores, dim=-1)
            if self.attn_dropout_rate > 0 and self.training:
                attn = F.dropout(attn, p=self.attn_dropout_rate)

            y = torch.matmul(attn, v)
        else:
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_dropout_rate if self.training else 0.0,
                attn_mask=mask,
            )

        y = (
            y.transpose(1, 2)
            .contiguous()
            .view(bsz, seqlen, self.head_dim * self.n_head)
        )
        return self.wo(y)
class FeedForward(nn.Module):
    """SiLU-gated feed-forward: ``w2(dropout(silu(w1 x) * w3 x))``."""

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(self.dropout(gated))
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned gain.

    The statistic is computed in float32; the normalized tensor is cast back to
    the input dtype before multiplying by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = (x * x).mean(dim=-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()

    def forward(self, x: Tensor) -> Tensor:
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
class LayerScale(nn.Module):
    """Learnable per-channel scale on the last axis, initialized to ``init_values``.

    With ``inplace=True`` the input tensor itself is multiplied and returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-2,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class WindowLimitedTransformer(Transformer):
    """
    Transformer with window limited attention, causal.

    Inputs are channel-first when ``config.channels_first`` is set (they are
    transposed to ``(B, T, D)`` internally and back on output). Optional
    linear projections map ``input_dim`` to/from ``config.dim``. With
    ``window_size`` set, position ``t`` attends to positions
    ``t - window_size + 1 .. t``; otherwise a full causal (or, for
    ``causal=False``, unrestricted) mask is used.
    """

    def __init__(
        self,
        config: ModelArgs,
        input_dim: int = 512,
        window_size: Optional[int] = None,
        causal: bool = True,
        look_ahead_conv: nn.Module = None,
    ):
        super().__init__(config)
        self.window_size = window_size
        self.causal = causal
        self.channels_first = config.channels_first
        self.look_ahead_conv = (
            look_ahead_conv if look_ahead_conv is not None else nn.Identity()
        )
        self.input_proj = (
            nn.Linear(input_dim, config.dim)
            if input_dim != config.dim
            else nn.Identity()
        )
        self.output_proj = (
            nn.Linear(config.dim, input_dim)
            if input_dim != config.dim
            else nn.Identity()
        )

    def make_window_limited_mask(
        self,
        max_length: int,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Make mask to form window limited attention.

        Returns a ``(1, 1, L, L)`` bool mask (True = may attend). ``x_lens`` is
        accepted for signature symmetry but not used.

        Raises:
            NotImplementedError: for non-causal models.
        """
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
            row_indices = torch.arange(max_length).view(-1, 1)
            window_size = self.window_size or max_length
            valid_range = (row_indices - window_size + 1).clamp(min=0)
            column_indices = torch.arange(max_length)
            mask = (column_indices >= valid_range) & mask.bool()
        else:
            raise NotImplementedError
        mask = mask.bool()[None, None]
        return mask

    def make_mask(
        self,
        max_length: int,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Make ordinary mask if window size is not specified.

        Returns a ``(1, 1, L, L)`` bool mask (True = may attend) when ``x_lens``
        is None. When per-sample lengths are given, keys at positions
        ``>= x_lens[b]`` are additionally masked out for sample ``b`` and the
        result is ``(B, 1, L, L)``.
        """
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
        else:
            mask = torch.ones(max_length, max_length)
        mask = mask.bool()[None, None]  # (1, 1, L, L)

        if x_lens is not None:
            lengths = torch.as_tensor(x_lens).cpu().reshape(-1, 1)  # (B, 1)
            key_is_valid = torch.arange(max_length)[None, :] < lengths  # (B, L)
            mask = mask & key_is_valid[:, None, None, :]  # (B, 1, L, L)

        return mask

    def forward(
        self,
        x: Tensor,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        if self.channels_first:
            x = x.transpose(1, 2)
        x = self.input_proj(x)  # (B, T, D)
        x = self.look_ahead_conv(x)
        input_pos = torch.arange(x.shape[1], device=x.device)

        # construct mask to form window limited attention
        max_length = x.shape[1]
        if self.window_size is not None:
            mask = self.make_window_limited_mask(max_length, x_lens)
        else:
            mask = self.make_mask(max_length, x_lens)
        mask = mask.to(x.device)

        x = super().forward(x, input_pos, mask)
        x = self.output_proj(x)  # (B, T, D)
        if self.channels_first:
            x = x.transpose(1, 2)
        return x
def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    """Build the RoPE rotation table.

    Returns a ``(seq_len, n_elem // 2, 2)`` tensor whose last axis holds
    ``(cos, sin)`` of ``position * base ** (-2i / n_elem)``, cast to ``dtype``.
    """
    even_channels = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float()
    inv_freq = 1.0 / (base ** (even_channels / n_elem))

    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)

    unit_phasors = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack((unit_phasors.real, unit_phasors.imag), dim=-1)
    return table.to(dtype=dtype)
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of ``x`` by the RoPE factors.

    ``x`` is ``[B, S, H, D]`` (as produced in ``Attention.forward``) and
    ``freqs_cis`` must reshape to ``(S, D // 2, 2)`` holding ``(cos, sin)``.
    Math is done in float32; the result is cast back to ``x``'s dtype.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rotation = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    x_first, x_second = pairs[..., 0], pairs[..., 1]
    cos, sin = rotation[..., 0], rotation[..., 1]

    rotated = torch.stack(
        (x_first * cos - x_second * sin, x_second * cos + x_first * sin),
        dim=-1,
    )
    return rotated.flatten(3).type_as(x)
def init_weights(m):
    """Initialize ``nn.Conv1d`` modules; intended for use with ``nn.Module.apply``.

    Weights get a truncated normal (std 0.02) and biases are zeroed. Convs
    built with ``bias=False`` are handled, and other module types are left
    untouched.

    NOTE(review): for weight-normalized convs ``m.weight`` may be derived from
    other parameters, in which case this in-place init may not persist —
    confirm if initialization from scratch matters.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Conv1d(bias=False) stores ``bias = None``; constant_ would fail on it.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Trim ``(left, right)`` samples from the last axis of ``x``.

    Both amounts must be non-negative and together no longer than the input;
    zero padding on either side is handled (no negative-index slicing pitfall).
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert (left + right) <= x.shape[-1]
    stop = x.shape[-1] - right
    return x[..., left:stop]
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Right-padding needed so the last strided conv window is complete.

    Given ``padding_total`` samples of padding already planned, returns how
    many extra samples make the padded length cover a whole number of strides,
    so no trailing input samples are dropped by the convolution.
    """
    length = x.shape[-1]
    num_frames = math.ceil((length - kernel_size + padding_total) / stride + 1)
    target_length = (num_frames - 1) * stride + kernel_size - padding_total
    return target_length - length
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.

    ``mode`` is forwarded to ``F.pad`` ("constant", "reflect", ...). The
    legacy name "zeros", which ``F.pad`` rejects, is accepted as an alias for
    "constant". For "reflect", if the input is not longer than the requested
    padding, extra zeros are appended on the right before reflecting and
    trimmed off again afterwards.
    """
    if mode == "zeros":
        mode = "constant"
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
class CausalConvNet(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation=1,
stride=1,
groups=1,
padding=None,
):
super(CausalConvNet, self).__init__()
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
)
self.stride = stride
self.kernel_size = (kernel_size - 1) * dilation + 1
self.dilation = dilation
self.padding = self.kernel_size - self.stride
def forward(self, x):
pad = self.padding
extra_padding = get_extra_padding_for_conv1d(
x, self.kernel_size, self.stride, pad
)
x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
return self.conv(x).contiguous()
def weight_norm(self, name="weight", dim=0):
self.conv = weight_norm(self.conv, name=name, dim=dim)
return self
def remove_weight_norm(self):
self.conv = remove_parametrizations(self.conv)
return self
class CausalTransConvNet(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
):
super(CausalTransConvNet, self).__init__()
self.conv = nn.ConvTranspose1d(
in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
)
self.stride = stride
self.kernel_size = kernel_size
def forward(self, x):
x = self.conv(x)
pad = self.kernel_size - self.stride
padding_right = math.ceil(pad)
padding_left = pad - padding_right
x = unpad1d(x, (padding_left, padding_right))
return x.contiguous()
def weight_norm(self, name="weight", dim=0):
self.conv = weight_norm(self.conv, name=name, dim=dim)
return self
def remove_weight_norm(self):
self.conv = remove_parametrizations(self.conv)
return self
def CausalWNConv1d(*args, **kwargs):
    """Build a ``CausalConvNet`` (same arguments) with weight normalization applied."""
    causal_conv = CausalConvNet(*args, **kwargs)
    return causal_conv.weight_norm()
def CausalWNConvTranspose1d(*args, **kwargs):
    """Build a ``CausalTransConvNet`` (same arguments) with weight normalization applied."""
    causal_tconv = CausalTransConvNet(*args, **kwargs)
    return causal_tconv.weight_norm()
class ResidualUnit(nn.Module):
    """Snake -> dilated k=7 conv -> Snake -> 1x1 conv, added to a skip path.

    If the branch output is shorter than the input, the skip path is cropped
    to match: from the end for causal units, symmetrically otherwise.
    """

    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        same_pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=same_pad),
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=1),
        )
        self.causal = causal

    def forward(self, x):
        branch = self.block(x)
        surplus = x.shape[-1] - branch.shape[-1]
        if surplus <= 0:
            return x + branch
        if self.causal:
            skip = x[..., :-surplus]
        else:
            skip = x[..., surplus // 2 : -surplus // 2]
        return skip + branch
class EncoderBlock(nn.Module):
    """One encoder stage operating on ``dim // 2`` input channels.

    Three residual units (dilations 1, 3, 9) are followed by Snake and a
    strided conv that doubles the channels to ``dim`` and downsamples by
    ``stride``. Optionally a window-limited transformer with ``n_t_layer``
    layers comes last (identity when ``n_t_layer == 0``).
    """

    def __init__(
        self,
        dim: int = 16,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d

        if n_t_layer == 0:
            transformer_module = nn.Identity()
        else:
            transformer_module = WindowLimitedTransformer(
                causal=causal,
                input_dim=dim,
                window_size=getattr(transformer_general_config, "window_size", 512),
                config=transformer_general_config(
                    n_layer=n_t_layer,
                    n_head=dim // 64,
                    dim=dim,
                    intermediate_size=dim * 3,
                ),
            )

        half_dim = dim // 2
        self.block = nn.Sequential(
            *(
                ResidualUnit(half_dim, dilation=d, causal=causal)
                for d in (1, 3, 9)
            ),
            Snake1d(half_dim),
            conv_class(
                half_dim,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            transformer_module,
        )

    def forward(self, x):
        return self.block(x)
class Encoder(nn.Module):
    """Waveform encoder: stem conv, downsampling ``EncoderBlock``s, Snake + output conv.

    Every block doubles the channel count and downsamples by its stride.
    ``strides`` and ``n_transformer_layers`` are consumed pairwise with
    ``zip``. The final conv projects to ``d_latent`` channels, and ``enc_dim``
    records the channel count before that projection.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
        n_transformer_layers: list = [0, 0, 4, 4],
        transformer_general_config: ModelArgs = None,
        causal: bool = False,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d

        stages = [conv_class(1, d_model, kernel_size=7, padding=3)]

        for stride, n_t_layer in zip(strides, n_transformer_layers):
            d_model *= 2
            stages.append(
                EncoderBlock(
                    d_model,
                    stride=stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            )

        stages.extend(
            [
                Snake1d(d_model),
                conv_class(d_model, d_latent, kernel_size=3, padding=1),
            ]
        )

        # Wrap the stages into nn.Sequential
        self.block = nn.Sequential(*stages)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
class DecoderBlock(nn.Module):
    """One decoder stage: Snake -> transposed conv (upsample by ``stride``,
    ``input_dim`` -> ``output_dim`` channels) -> residual units with dilations
    1, 3 and 9.

    ``n_t_layer`` and ``transformer_general_config`` are kept for signature and
    config compatibility. The transformer stage is currently disabled, so no
    transformer is constructed. Building one that is never placed in
    ``self.block`` would only cost memory (each ``Transformer`` allocates a
    large causal-mask buffer) and init time, without affecting the output or
    the state dict.
    """

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d

        self.block = nn.Sequential(
            Snake1d(input_dim),
            conv_trans_class(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            ResidualUnit(output_dim, dilation=1, causal=causal),
            ResidualUnit(output_dim, dilation=3, causal=causal),
            ResidualUnit(output_dim, dilation=9, causal=causal),
        )

    def forward(self, x):
        return self.block(x)
class Decoder(nn.Module):
    """Waveform decoder: stem conv, upsampling ``DecoderBlock``s, Snake + conv + tanh.

    Stage ``i`` maps ``channels // 2**i`` to ``channels // 2**(i+1)`` channels
    and upsamples by ``rates[i]``. ``rates`` and ``n_transformer_layers`` are
    consumed pairwise with ``zip``. The output has ``d_out`` channels squashed
    by ``tanh``.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        causal: bool = False,
        n_transformer_layers: list = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d

        # Add first conv layer
        layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]

        # Width feeding the output head. It stays ``channels`` when no
        # upsampling stage is built, so an empty ``rates`` no longer leaves
        # ``output_dim`` unbound.
        output_dim = channels

        # Add upsampling + MRF blocks
        for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [
                DecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            ]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            conv_class(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class DAC(BaseModel, CodecMixin):
    """Convolutional audio codec: ``Encoder`` -> ``quantizer`` -> ``Decoder``.

    ``hop_length`` is the product of ``encoder_rates``. ``encode`` right-pads
    audio to a multiple of ``frame_length = 4 * hop_length``. A custom decoder
    can be supplied via ``overwrite_decoder``.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        quantizer: torch.nn.Module = None,
        sample_rate: int = 44100,
        causal: bool = True,
        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
        overwrite_decoder: torch.nn.Module = None,
        transformer_general_config=None,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        # Default latent width: encoder channels after every stage has doubled them.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(
            encoder_dim,
            encoder_rates,
            latent_dim,
            causal=causal,
            n_transformer_layers=encoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )

        self.quantizer = quantizer

        if overwrite_decoder is not None:
            self.decoder = overwrite_decoder
        else:
            self.decoder = Decoder(
                latent_dim,
                decoder_dim,
                decoder_rates,
                causal=causal,
                n_transformer_layers=decoder_transformer_layers,
                transformer_general_config=transformer_general_config,
            )

        self.apply(init_weights)

        # ``get_delay`` is provided by CodecMixin.
        self.delay = self.get_delay()

        self.frame_length = self.hop_length * 4

    def preprocess(self, audio_data, sample_rate):
        """Right-pad ``audio_data`` to a multiple of ``hop_length``.

        ``sample_rate=None`` means ``self.sample_rate``; any other rate must
        match it (asserted).
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        audio_lengths: torch.Tensor = None,
        n_quantizers: int = None,
        **kwargs,
    ):
        """Encode audio into quantizer codes.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T] or Tensor[B x T]
            Audio to encode. A 2-D input gets a channel axis inserted. The
            audio is right-padded to a multiple of ``self.frame_length``.
        audio_lengths : Tensor, optional
            Number of valid samples per item. Defaults to the padded length.
        n_quantizers : int, optional
            Forwarded to the quantizer. By default None (all quantizers).
        **kwargs
            Forwarded to the quantizer.

        Returns
        -------
        indices : Tensor
            ``codes`` attribute of the quantizer output.
        indices_lens : Tensor
            ``ceil(audio_lengths / frame_length)`` as long.
        """
        # pad to multiple of self.frame_length
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        if audio_lengths is None:
            audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)

        z = self.encoder(audio_data)
        vq_results = self.quantizer(z, n_quantizers, **kwargs)
        indices = vq_results.codes
        indices_lens = torch.ceil(audio_lengths / self.frame_length).long()

        return indices, indices_lens

    def from_indices(self, indices: torch.Tensor):
        """Decode quantizer codes to audio (``quantizer.decode`` then the decoder)."""
        z = self.quantizer.decode(indices)
        return self.decoder(z)

    def decode(self, z: torch.Tensor):
        """Decode continuous latents to audio.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input.

        Returns
        -------
        Tensor
            Decoder output (audio).
        """
        return self.decoder(z)

    def forward(
        self,
        audio_data: torch.Tensor,
        template: torch.Tensor = None,
        mask: torch.Tensor = None,
        sample_rate: int = None,
        n_quantizers: int = None,
        **kwargs,
    ):
        """Encode then decode ``audio_data``.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode.
        template, mask : Tensor, optional
            Accepted for interface compatibility; unused.
        sample_rate : int, optional
            Sample rate of audio data in Hz. By default None, which means
            ``self.sample_rate``.
        n_quantizers : int, optional
            Number of quantizers to use. By default None (all quantizers).

        Returns
        -------
        audio : Tensor
            Reconstruction cropped to the input length.
        vq_results
            The value returned by ``encode`` (``(indices, indices_lens)``).
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        # n_quantizers must be passed by keyword: encode's second positional
        # parameter is ``audio_lengths``.
        vq_results = self.encode(audio_data, n_quantizers=n_quantizers, **kwargs)

        if isinstance(vq_results, tuple):
            # ``encode`` returns integer codes; they must go through the
            # quantizer's decode before the decoder can consume them.
            x = self.from_indices(vq_results[0])
        else:
            x = self.decode(vq_results.z)

        return x[..., :length], vq_results
if __name__ == "__main__":
    import hydra
    import numpy as np
    import soundfile as sf
    import torch
    from omegaconf import OmegaConf

    # Paths for a standalone decode of previously saved codes.
    config_path = "fish_speech/configs/modded_dac_vq.yaml"
    checkpoint_path = "checkpoints/s2-pro/codec.pth"
    codes_path = "./output/codes_0.npy"  # path to your codes file
    output_path = "reconstructed_from_codes.wav"

    with torch.inference_mode():
        # 1. Build the model and load its weights. weights_only=True restricts
        #    unpickling to tensors/plain containers, the same way
        #    fish_speech/models/dac/inference.py loads codec checkpoints.
        model = hydra.utils.instantiate(OmegaConf.load(config_path))
        new_sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        # strict=False silently ignores mismatched keys, so report them.
        load_result = model.load_state_dict(new_sd, strict=False)
        print(f"Loaded checkpoint: {load_result}")
        model.cuda()
        model.eval()

        # 2. Load external codes (.npy), expected as [num_codebooks, seq_len]
        #    or [1, num_codebooks, seq_len].
        codes_np = np.load(codes_path)
        codes_tensor = torch.from_numpy(codes_np).to(torch.long).cuda()

        # Add a batch dimension if missing -> [1, num_codebooks, seq_len].
        if codes_tensor.ndim == 2:
            codes_tensor = codes_tensor.unsqueeze(0)

        print(f"Loaded codes shape: {codes_tensor.shape}")

        # 3. Reconstruct audio directly from the codes (decoding).
        fake_audio = model.from_indices(codes_tensor)

        # 4. Post-process and save. The output is typically [B, C, T]; squeeze
        #    drops singleton axes, and .float() makes it numpy/soundfile friendly.
        audio_np = fake_audio.squeeze().float().cpu().numpy()

        # For multi-channel output, transpose to (samples, channels) as
        # soundfile expects.
        if audio_np.ndim == 2:
            audio_np = audio_np.T

        # Write at the model's own rate instead of a hard-coded 44100.
        sf.write(output_path, audio_np, model.sample_rate)
        print(f"重建完成。音频已保存至: {output_path}")
================================================
FILE: fish_speech/models/dac/rvq.py
================================================
import math
import typing as tp
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from dac.nn.quantize import ResidualVectorQuantize
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Trim ``(left, right)`` samples from the last axis of ``x``.

    Both amounts must be non-negative and together no longer than the input;
    zero padding on either side is handled (no negative-index slicing pitfall).
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert (left + right) <= x.shape[-1]
    stop = x.shape[-1] - right
    return x[..., left:stop]
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Right-padding needed so the last strided conv window is complete.

    Given ``padding_total`` samples of padding already planned, returns how
    many extra samples make the padded length cover a whole number of strides,
    so no trailing input samples are dropped by the convolution.
    """
    length = x.shape[-1]
    num_frames = math.ceil((length - kernel_size + padding_total) / stride + 1)
    target_length = (num_frames - 1) * stride + kernel_size - padding_total
    return target_length - length
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.

    ``mode`` is forwarded to ``F.pad`` ("constant", "reflect", ...). The
    legacy name "zeros", which ``F.pad`` rejects, is accepted as an alias for
    "constant". For "reflect", if the input is not longer than the requested
    padding, extra zeros are appended on the right before reflecting and
    trimmed off again afterwards.
    """
    if mode == "zeros":
        mode = "constant"
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
class CausalConvNet(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation=1,
stride=1,
groups=1,
padding=None,
):
super(CausalConvNet, self).__init__()
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
)
self.stride = stride
self.kernel_size = (kernel_size - 1) * dilation + 1
self.dilation = dilation
self.padding = self.kernel_size - self.stride
def forward(self, x):
pad = self.padding
extra_padding = get_extra_padding_for_conv1d(
x, self.kernel_size, self.stride, pad
)
x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
return self.conv(x).contiguous()
def weight_norm(self, name="weight", dim=0):
self.conv = weight_norm(self.conv, name=name, dim=dim)
return self
def remove_weight_norm(self):
self.conv = remove_parametrizations(self.conv)
return self
class CausalTransConvNet(nn.Module):
    """Transposed 1-D convolution trimmed on the right to stay causal.

    ``nn.ConvTranspose1d`` (no padding) maps ``L`` frames to
    ``(L - 1) * stride + effective_kernel`` frames, where
    ``effective_kernel = dilation * (kernel_size - 1) + 1``. The surplus
    ``effective_kernel - stride`` frames are cut from the end, so the output
    has ``L * stride`` frames. ``padding`` is accepted for signature
    compatibility but ignored.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
    ):
        super(CausalTransConvNet, self).__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.stride = stride
        # Effective (dilated) kernel size, as in CausalConvNet. Using the raw
        # kernel size under-trimmed the output whenever dilation > 1.
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation

    def forward(self, x):
        x = self.conv(x)
        # Everything surplus is trailing, so nothing is removed on the left.
        trim = self.kernel_size - self.stride
        x = unpad1d(x, (0, trim))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv)
        return self
# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt block built on a causal depthwise convolution.

    Pipeline: causal depthwise conv -> permute (N, C, L) -> (N, L, C) ->
    LayerNorm -> Linear -> GELU -> Linear -> optional layer scale ->
    permute back -> optional residual add.

    The pointwise (1x1) convolutions are implemented as ``nn.Linear`` on the
    channels-last view. The upstream implementation notes this is slightly
    faster in PyTorch than the equivalent channels-first 1x1 Conv formulation.

    Args:
        dim (int): Number of input (and output) channels.
        layer_scale_init_value (float): Init value for Layer Scale; a value
            ``<= 0`` disables layer scale. Default: 1e-6.
        mlp_ratio (float): Ratio of MLP hidden dim to ``dim``. Default: 4.0.
        kernel_size (int): Kernel size for the depthwise conv. Default: 7.
        dilation (int): Dilation for the depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        convnet_type = CausalConvNet
        self.dwconv = convnet_type(
            dim,
            dim,
            kernel_size=kernel_size,
            groups=dim,
            dilation=dilation,
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim))
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(self, x, apply_residual: bool = True):
        """Apply the block to ``x`` of shape (N, C, L).

        When ``apply_residual`` is False, only the branch output is returned,
        without adding the input back.
        """
        residual = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        if apply_residual:
            x = residual + x
        return x
@dataclass
class VQResult:
    """Output bundle returned by ``DownsampleResidualVectorQuantize.forward``."""
    # Quantized representation after post_module + upsampling, resized to the input length.
    z: torch.Tensor
    # Semantic codes concatenated with the residual codes along dim 1.
    codes: torch.Tensor
    # Semantic latents concatenated with the residual latents along dim 1.
    latents: torch.Tensor
    # Sum of the semantic and residual codebook losses.
    codebook_loss: torch.Tensor
    # Sum of the semantic and residual commitment losses.
    commitment_loss: torch.Tensor
    # Semantic-distillation output; not populated by the forward in this module.
    semantic_distill_z: torch.Tensor | None = None
class DownsampleResidualVectorQuantize(nn.Module):
def __init__(
self,
input_dim: int = 1024,
n_codebooks: int = 9,
codebook_dim: int = 8,
quantizer_dropout: float = 0.5,
codebook_size: int = 1024,
semantic_codebook_size: int = 4096,
downsample_factor: tuple[int] = (2, 2),
downsample_dims: tuple[int] | None = None,
pre_module: nn.Module | None = None,
post_module: nn.Module | None = None,
semantic_predictor_module: nn.Module | None = None,
):
super().__init__()
if downsample_dims is None:
downsample_dims = [input_dim for _ in range(len(downsample_factor))]
all_dims = (input_dim,) + tuple(downsample_dims)
self.semantic_quantizer = ResidualVectorQuantize(
input_dim=input_dim,
n_codebooks=1,
codebook_size=semantic_codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=0.0,
)
self.quantizer = ResidualVectorQuantize(
input_dim=input_dim,
n_codebooks=n_codebooks,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
quantizer_dropout=quantizer_dropout,
)
self.downsample_factor = downsample_factor
self.downsample_dims = downsample_dims
convnet_type = CausalConvNet
transconvnet_type = CausalTransConvNet
self.downsample = nn.Sequential(
*[
nn.Sequential(
convnet_type(
all_dims[idx],
all_dims[idx + 1],
kernel_size=factor,
stride=factor,
),
ConvNeXtBlock(dim=all_dims[idx + 1]),
)
for idx, factor in enumerate(downsample_factor)
]
)
self.upsample = nn.Sequential(
*[
nn.Sequential(
transconvnet_type(
all_dims[idx + 1],
all_dims[idx],
kernel_size=factor,
stride=factor,
),
ConvNeXtBlock(dim=all_dims[idx]),
)
for idx, factor in reversed(list(enumerate(downsample_factor)))
]
)
self.apply(self._init_weights)
self.pre_module = (
pre_module if pre_module is not None else nn.Identity()
) # leave for transformer, LSTM or Mamba or something else
self.post_module = post_module if post_module is not None else nn.Identity()
self.semantic_predictor_module = (
semantic_predictor_module
if semantic_predictor_module is not None
else nn.Identity()
)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(
self, z, n_quantizers: int = None, semantic_len: torch.Tensor = None, **kwargs
):
# z: (B, D, T)
original_shape = z.shape
if semantic_len is None:
semantic_len = torch.LongTensor([z.shape[-1]])
z = self.downsample(z)
z = self.pre_module(z) # B, T, D
(
semantic_z,
semantic_codes,
semantic_latents,
semantic_commitment_loss,
semantic_codebook_loss,
) = self.semantic_quantizer(z)
residual_z = z - semantic_z
residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
residual_z, n_quantizers=n_quantizers
)
z = semantic_z + residual_z
commitment_loss = commitment_loss + semantic_commitment_loss
codebook_loss = codebook_loss + semantic_codebook_loss
codes = torch.cat([semantic_codes, codes], dim=1)
latents = torch.cat([semantic_latents, latents], dim=1)
z = self.post_module(z)
z = self.upsample(z)
# z: (B, D, T)
# semantic distillation (disabled here since only used in training)
# semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT # wav2vec target is B, T, D
# Pad or crop z to match original shape
diff = original_shape[-1] - z.shape[-1]
right = 0
left = abs(diff) - right
if diff > 0:
z = F.pad(z, (left, right))
elif diff < 0:
z = z[..., left:]
results = VQResult(
z=z,
codes=codes,
latents=latents,
commitment_loss=commitment_loss,
codebook_loss=codebook_loss,
)
return results
# def encode(self, z):
# z = self.downsample(z)
# z = self.pre_module(z)
# _, indices, _, _, _ = self.quantizer(z.mT)
# indices = rearrange(indices, "g b l r -> b (g r) l")
# return indices
#
def decode(self, indices: torch.Tensor):
# indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
indices[:, 0] = torch.clamp(
indices[:, 0], max=self.semantic_quantizer.codebook_size - 1
)
indices[:, 1:] = torch.clamp(
indices[:, 1:], max=self.quantizer.codebook_size - 1
)
z_q_semantic = self.semantic_quantizer.from_codes(indices[:, :1])[0]
z_q_residual = self.quantizer.from_codes(indices[:, 1:])[0]
z_q = z_q_semantic + z_q_residual
z_q = self.post_module(z_q)
z_q = self.upsample(z_q)
return z_q
# def from_latents(self, latents: torch.Tensor):
# z_q, z_p, codes = super().from_latents(latents)
# z_q = self.upsample(z_q)
# return z_q, z_p, codes
if __name__ == "__main__":
    # Smoke test: quantize a random feature map, then check that the first 40
    # output frames match those of a run on just the first 40 input frames
    # (i.e. the down/upsampling path is causal).
    quantizer = DownsampleResidualVectorQuantize(
        input_dim=512,
        n_codebooks=8,
        codebook_dim=8,
        codebook_size=1024,
        quantizer_dropout=0.5,
        downsample_factor=[2, 2],
    )
    quantizer.eval()

    features = torch.randn(2, 512, 442)
    full_run = quantizer(features)
    print(quantizer)
    print(full_run.latents.shape, full_run.codes.shape, full_run.z.shape)

    prefix_run = quantizer(features[:, :, :40])
    print(prefix_run.latents.shape, prefix_run.codes.shape, prefix_run.z.shape)

    assert torch.allclose(full_run.z[:, :, :40], prefix_run.z, atol=1e-8)
    print("Success")
================================================
FILE: fish_speech/models/text2semantic/__init__.py
================================================
================================================
FILE: fish_speech/models/text2semantic/inference.py
================================================
import os
import queue
import re
import threading
import time
import traceback
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Literal, Optional, Tuple, Union
import click
import numpy as np
import torch
import torch._inductor.config
from loguru import logger
from tqdm import tqdm
from fish_speech.content_sequence import (
TextPart,
VQPart,
)
from fish_speech.conversation import Conversation, Message
from fish_speech.tokenizer import IM_END_TOKEN
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
if hasattr(torch._inductor.config, "fx_graph_cache"):
torch._inductor.config.fx_graph_cache = True
from torch.nn.attention import SDPBackend, sdpa_kernel
from fish_speech.models.text2semantic.llama import (
BaseTransformer,
DualARTransformer,
NaiveTransformer,
)
def multinomial_sample_one_no_sync(probs_sort):
    """Draw one index per row from ``probs_sort`` without a host sync.

    Uses the exponential-race trick: ``argmax(p / E)`` with ``E ~ Exp(1)``
    selects index ``i`` with probability proportional to ``p[i]``, avoiding
    ``torch.multinomial``.

    Returns:
        int32 indices with the last dimension kept (size 1).
    """
    exp_noise = torch.rand_like(probs_sort).log().neg()
    winners = (probs_sort / exp_noise).argmax(dim=-1, keepdim=True)
    return winners.to(dtype=torch.int)
# Repetition Aware Sampling (RAS): decode_n_tokens keeps a rolling window of the
# last RAS_WIN_SIZE generated tokens. When a newly sampled main token is a
# semantic token already present in that window, decode_one_token_ar replaces it
# with a sample drawn using RAS_HIGH_TEMP / RAS_HIGH_TOP_P.
RAS_WIN_SIZE = 10  # window for Repetition Aware Sampling
RAS_HIGH_TEMP = 1.0  # temperature of the RAS fallback sample
RAS_HIGH_TOP_P = 0.9  # top-p of the RAS fallback sample
def logits_to_probs(
    logits,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,  # a plain Python int: it only builds a static position mask
) -> torch.Tensor:
    """Turn logits into a sampling distribution with top-k / top-p filtering.

    A token is discarded when it is outside the ``top_k`` highest logits, or
    when its cumulative probability exceeds ``top_p``. Cumulative probability
    is measured in descending order, including the token itself. The single
    most likely token of each row is always kept.

    Filtering uses the untempered distribution. ``temperature`` (clipped to
    ``>= 1e-5``) is applied to the surviving logits afterwards.

    Args:
        logits: ``(..., vocab)`` tensor. Leading dimensions are independent
            rows; the sampler in this module passes a 1-D tensor.
        temperature: scalar tensor.
        top_p: scalar tensor in (0, 1].
        top_k: maximum number of candidate tokens.

    Returns:
        Probabilities with the same shape as ``logits``.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    positions = torch.arange(sorted_logits.shape[-1], device=sorted_logits.device)
    sorted_indices_to_remove = (cum_probs > top_p) | (positions >= top_k)
    # Always keep the top-1 token. `[..., 0]` targets the first sorted slot of
    # every row; `[0]` would only be correct for 1-D input.
    sorted_indices_to_remove[..., 0] = False
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    # torch.where instead of in-place masked_fill_.
    logits = torch.where(indices_to_remove, float("-Inf"), logits)
    logits = logits / torch.clip(temperature, min=1e-5)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
def sample(
    logits,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample one token from the last position of the first sequence.

    Args:
        logits: indexed as ``logits[0, -1]``, i.e. (batch, seq, vocab).
        temperature, top_p, top_k: forwarded to ``logits_to_probs``.

    Returns:
        ``(idx_next, probs)``: the sampled index and the filtered distribution
        it was drawn from.
    """
    last_step_logits = logits[0, -1]
    probs = logits_to_probs(last_step_logits, temperature, top_p, top_k)
    return multinomial_sample_one_no_sync(probs), probs
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
    semantic_logit_bias: torch.Tensor,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run one DualAR decoding step: a main token plus every codebook token.

    The slow model (``forward_generate``) gives logits for the main token.
    These are restricted by ``semantic_logit_bias`` and sampled.

    When ``previous_tokens`` is given, Repetition Aware Sampling replaces that
    sample with one drawn at ``RAS_HIGH_TEMP`` / ``RAS_HIGH_TOP_P``, but only
    if it is a semantic token already present in ``previous_tokens[0]``.

    The fast model (``forward_generate_fast``) then yields the codebook tokens.
    Codebook 0 is derived from the main token; the remaining
    ``model.config.num_codebooks - 1`` are sampled one at a time with the
    caller's temperature/top_p/top_k.

    Tensor-valued decisions use tensor ops rather than Python branching, so
    this function can be wrapped with ``torch.compile(fullgraph=True)``.

    Returns:
        ``codebooks.T`` with ``1 + model.config.num_codebooks`` rows: the main
        token followed by the codebook tokens.
    """
    forward_result = model.forward_generate(
        x,
        input_pos,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
    )
    logits = forward_result.logits  # (1, 1, vocab_size)
    hidden_states = forward_result.hidden_states
    # Apply constrained decoding: only allow semantic tokens + im_end
    biased_logits = logits + semantic_logit_bias
    # Normal sample
    main_token_normal = sample(
        biased_logits, temperature=temperature, top_p=top_p, top_k=top_k
    )[0]
    # RAS: also sample with high temp to use as fallback if token repeats
    # (always computed, so the traced graph does not depend on the outcome)
    high_temp = torch.tensor(
        RAS_HIGH_TEMP, device=temperature.device, dtype=temperature.dtype
    )
    high_top_p = torch.tensor(RAS_HIGH_TOP_P, device=top_p.device, dtype=top_p.dtype)
    main_token_high = sample(
        biased_logits, temperature=high_temp, top_p=high_top_p, top_k=top_k
    )[0]
    # Use high-temp sample if: token is semantic AND token is in previous window
    if previous_tokens is not None:
        in_window = (previous_tokens[0] == main_token_normal).any()
        # Use tensor ops (&, torch.where) instead of Python (and, if) — torch.compile requires no data-dependent branching
        is_semantic = (main_token_normal >= model.config.semantic_begin_id) & (
            main_token_normal <= model.config.semantic_end_id
        )
        should_use_high = in_window & is_semantic
        main_token_normal = torch.where(
            should_use_high, main_token_high, main_token_normal
        )
    codebooks = [main_token_normal]
    # Fast transformer: position 0 receives the slow model's hidden state.
    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    # Return value is discarded. NOTE(review): presumably this fills the fast
    # model's KV cache for position 0 — confirm against forward_generate_fast.
    model.forward_generate_fast(hidden_states, input_pos)
    # Codebook 0 is the main token shifted into codebook index space; the clamp
    # keeps non-semantic tokens (e.g. IM_END) inside the valid range.
    a = codebooks[0] - model.config.semantic_begin_id
    a = torch.clamp(a, min=0, max=model.config.codebook_size - 1)
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)
    # Each remaining codebook is sampled from logits computed on the previous
    # codebook's embedding.
    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        short_logits = logits  # DualAR predicts config.codebook_size number of tokens
        # Convert logits to probs (no constrain for fast codebooks)
        a = sample(
            short_logits,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)
    codebooks = torch.stack(codebooks, dim=1)
    # Only delete references, let Python GC handle cleanup
    del logits, hidden_states, forward_result
    return codebooks.T
def decode_n_tokens(
    model: DualARTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
    semantic_logit_bias: torch.Tensor,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    decode_one_token=decode_one_token_ar,
):
    """Autoregressively decode up to ``num_new_tokens`` steps after ``cur_token``.

    Decoding stops early once the main (row 0) token equals IM_END; that step
    is kept in the output. ``input_pos`` is advanced in place by one per step.
    A rolling window of recent tokens is passed to ``decode_one_token`` for
    Repetition Aware Sampling.

    Returns:
        Tensor of shape ``(num_codebooks + 1, n)`` with the ``n`` generated
        steps. ``n`` is 0 when ``num_new_tokens <= 0``; previously that case
        crashed on ``torch.cat`` of an empty list.
    """
    codebook_rows = model.config.num_codebooks + 1
    device = cur_token.device
    # Rolling window for RAS (Repetition Aware Sampling)
    previous_tokens = torch.zeros(
        (codebook_rows, RAS_WIN_SIZE),
        dtype=torch.int,
        device=device,
    )
    # Accumulate all generated tokens (the actual output)
    new_tokens = []
    im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)

    for _ in tqdm(range(num_new_tokens)):
        with sdpa_kernel(SDPBackend.MATH):
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=previous_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                semantic_logit_bias=semantic_logit_bias,
                audio_masks=audio_masks,
                audio_parts=audio_parts,
            ).clone()

        input_pos += 1
        cur_token = next_token.view(1, codebook_rows, -1)
        # Roll RAS window left and insert new token at end
        previous_tokens = previous_tokens.roll(-1, dims=1)
        previous_tokens[:, -1] = next_token.view(codebook_rows, -1)[:, 0]
        new_tokens.append(next_token)

        if cur_token[0, 0, -1] == im_end_id:
            break

    if not new_tokens:
        return torch.empty((codebook_rows, 0), dtype=torch.int, device=device)
    return torch.cat(new_tokens, dim=1)
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: DualARTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    decode_one_token=decode_one_token_ar,
    num_samples: int = 1,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: DualAR model. Its KV caches are set up on first use; this is
            tracked through ``model._cache_setup_done``.
        prompt: ``(num_codebooks + 1, T)`` prompt tokens.
        max_new_tokens: generation budget. 0 means "up to ``max_seq_len``".
            The budget is clipped so the sequence never exceeds
            ``model.config.max_seq_len``.
        audio_masks, audio_parts: forwarded to every decoding step.
        decode_one_token: step function used after the prefill step (possibly
            a ``torch.compile``-d ``decode_one_token_ar``).
        num_samples: only 1 is supported; caches are set up with
            ``max_batch_size=1``.
        **sampling_kwargs: ``temperature`` (default 1.0), ``top_p`` (0.9),
            ``top_k`` (30).

    Returns:
        ``(num_codebooks + 1, T + n)`` tensor: the prompt followed by the ``n``
        generated steps. Generation stops at the first IM_END main token (kept
        as the last step) or when the budget is exhausted.

    Raises:
        ValueError: if the prompt already has ``max_seq_len`` tokens or more.
    """
    T = prompt.size(1)
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    remaining = model.config.max_seq_len - T
    if max_new_tokens:
        max_new_tokens = min(max_new_tokens, remaining)
    else:
        max_new_tokens = remaining

    device = prompt.device
    # Model weight dtype (e.g. bfloat16), NOT the integer prompt dtype.
    dtype = next(model.parameters()).dtype

    # Only set up the caches once per model.
    if not getattr(model, "_cache_setup_done", False):
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,  # Fixed to 1, avoid dynamic changes
                max_seq_len=model.config.max_seq_len,
                dtype=dtype,
            )
        model._cache_setup_done = True

    codebook_dim = 1 + model.config.num_codebooks
    input_pos = torch.arange(0, T, device=device, dtype=torch.long)
    seq = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=prompt.dtype, device=device
    )
    seq[:, :T] = prompt

    temperature = torch.tensor(
        sampling_kwargs.get("temperature", 1.0), device=device, dtype=dtype
    )
    top_p = torch.tensor(sampling_kwargs.get("top_p", 0.9), device=device, dtype=dtype)
    top_k_val = sampling_kwargs.get("top_k", 30)

    # Semantic logit bias: 0 for semantic tokens + IM_END, -inf for all others.
    im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
    semantic_logit_bias = torch.full(
        (1, 1, model.config.vocab_size), float("-inf"), device=device, dtype=dtype
    )
    semantic_logit_bias[
        0, 0, model.config.semantic_begin_id : model.config.semantic_end_id + 1
    ] = 0.0
    semantic_logit_bias[0, 0, im_end_id] = 0.0

    # The prefill always uses the eager step function. NOTE(review): presumably
    # because the prompt length varies between calls — confirm.
    first_token = decode_one_token_ar(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        temperature,
        top_p,
        top_k_val,
        semantic_logit_bias,
        audio_masks,
        audio_parts,
    )
    seq[:, T : T + 1] = first_token

    # Keep decoding only if budget is left and the model has not already ended
    # its turn. Otherwise decode_n_tokens would decode past an immediate
    # IM_END, or be asked for zero steps.
    n_generated = 0
    if max_new_tokens > 1 and int(first_token[0, 0]) != im_end_id:
        input_pos = torch.tensor([T], device=device, dtype=torch.int)
        x = decode_n_tokens(
            model,
            first_token.view(1, codebook_dim, -1),
            input_pos,
            max_new_tokens - 1,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k_val,
            semantic_logit_bias=semantic_logit_bias,
            audio_masks=audio_masks,
            audio_parts=audio_parts,
            decode_one_token=decode_one_token,
        )
        n_generated = x.size(1)
        seq[:, T + 1 : T + 1 + n_generated] = x

    return seq[:, : T + 1 + n_generated]
def init_model(checkpoint_path, device, precision, compile=False):
    """Load a DualAR checkpoint and choose its per-token decode function.

    Args:
        checkpoint_path: passed to ``DualARTransformer.from_pretrained``.
        device: target device for the weights.
        precision: target parameter dtype.
        compile: wrap the decode function with ``torch.compile``. Uses the
            inductor backend when CUDA is available, aot_eager otherwise.

    Returns:
        ``(model, decode_one_token)`` with the model in eval mode.

    Raises:
        ValueError: if the loaded model is not a ``DualARTransformer``.
    """
    loaded = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
    model = loaded.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    if not isinstance(model, DualARTransformer):
        raise ValueError("Unsupported model type")
    decode_one_token = decode_one_token_ar
    logger.info("Using DualARTransformer")

    # Sampling constants pre-created on the target device; none of them is
    # read by the functions in this module.
    model.fixed_temperature = torch.tensor(0.7, device=device, dtype=torch.float)
    model.fixed_top_p = torch.tensor(0.7, device=device, dtype=torch.float)
    model.fixed_repetition_penalty = torch.tensor(1.5, device=device, dtype=torch.float)

    # generate() sets up the KV caches lazily while this flag is False.
    model._cache_setup_done = False

    if compile:
        logger.info("Compiling function...")
        has_cuda = torch.cuda.is_available()
        decode_one_token = torch.compile(
            decode_one_token,
            backend="inductor" if has_cuda else "aot_eager",
            mode="default" if has_cuda else None,
            fullgraph=True,
        )

    return model.eval(), decode_one_token
@torch.inference_mode()
def load_codec_model(codec_checkpoint_path, device, precision=torch.bfloat16):
    """Load the DAC codec model for audio encoding/decoding.

    The architecture is built with Hydra from ``configs/modded_dac_vq.yaml``
    (resolved relative to this file). Weights come from
    ``codec_checkpoint_path``, processed as follows:

    - A nested ``"state_dict"`` entry is unwrapped.
    - If any key contains ``"generator"``, only keys containing
      ``"generator."`` are kept, with that text stripped.

    Loading uses ``strict=False``. Missing or unexpected keys are logged as
    warnings, so a mismatched checkpoint no longer goes unnoticed.

    Returns:
        The codec in eval mode on ``device`` with dtype ``precision``.
    """
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    config_path = Path(__file__).parent.parent.parent / "configs" / "modded_dac_vq.yaml"
    cfg = OmegaConf.load(str(config_path))
    codec = instantiate(cfg)

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(codec_checkpoint_path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    if any("generator" in k for k in state_dict):
        state_dict = {
            k.replace("generator.", ""): v
            for k, v in state_dict.items()
            if "generator." in k
        }

    incompatible = codec.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        logger.warning(
            f"Codec checkpoint is missing {len(incompatible.missing_keys)} key(s), "
            f"e.g. {incompatible.missing_keys[:5]}"
        )
    if incompatible.unexpected_keys:
        logger.warning(
            f"Codec checkpoint has {len(incompatible.unexpected_keys)} unexpected "
            f"key(s), e.g. {incompatible.unexpected_keys[:5]}"
        )

    codec.eval()
    codec.to(device=device, dtype=precision)
    return codec
@torch.inference_mode()
def encode_audio(audio_path, codec, device):
    """Encode an audio file to VQ codes.

    Multi-channel audio is averaged to mono and resampled to
    ``codec.sample_rate``; the waveform is cast to the codec's parameter dtype.

    Returns:
        ``indices[0, :, :feature_lengths[0]]`` from ``codec.encode``, i.e. the
        codes of the single batch item trimmed to its valid length.
    """
    import torchaudio

    waveform, source_rate = torchaudio.load(str(audio_path))
    # Down-mix to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    mono = torchaudio.functional.resample(
        waveform.to(device), source_rate, codec.sample_rate
    )[0]

    codec_dtype = next(codec.parameters()).dtype
    batch = mono.unsqueeze(0).unsqueeze(0).to(dtype=codec_dtype)  # (1, 1, T)
    lengths = torch.tensor([mono.shape[0]], device=device, dtype=torch.long)

    codes, code_lengths = codec.encode(batch, lengths)
    return codes[0, :, : code_lengths[0]]
@torch.inference_mode()
def decode_to_audio(codes, codec):
    """Decode VQ codes to audio waveform.

    ``codes`` gets a leading batch dimension before ``codec.from_indices``;
    the first channel of the first batch item is returned.
    """
    batched_codes = codes.unsqueeze(0)
    waveform = codec.from_indices(batched_codes)
    return waveform[0][0]
@dataclass
class GenerateResponse:
    """One item yielded by ``generate_long``.

    ``action="sample"`` carries the codes generated for one text batch in
    ``codes`` and that batch's text in ``text``. ``action="next"`` marks the
    end of one sample and leaves both fields as ``None``.
    """
    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None
    text: Optional[str] = None
def split_text_by_speaker(text: str) -> list[str]:
    """
    Split text into turns based on <|speaker:X|> tags.

    Args:
        text: The full text with speaker tags

    Returns:
        List of speaker turns, each starting with <|speaker:X|> and stripped of
        surrounding whitespace. Any non-empty text before the first tag is kept
        as its own leading turn instead of being silently dropped. An empty
        list is returned when ``text`` contains no speaker tag.
    """
    pattern = r"(<\|speaker:\d+\|>)"
    # With a capturing group, re.split yields [before, tag1, body1, tag2, body2, ...].
    parts = re.split(pattern, text)
    if len(parts) == 1:
        return []

    turns = []
    leading = parts[0].strip()
    if leading:
        turns.append(leading)
    for i in range(1, len(parts), 2):
        body = parts[i + 1] if i + 1 < len(parts) else ""
        turns.append((parts[i] + body).strip())
    return turns
def group_turns_into_batches(
    turns: list[str], max_speakers: int = 3, max_bytes: int = 300
) -> list[str]:
    """
    Group consecutive turns into newline-joined batches.

    A new batch is started when the current one already holds ``max_speakers``
    turns, or when adding the next turn would push the joined batch past
    ``max_bytes``. The joined size counts the ``"\\n"`` separators. A single
    turn larger than ``max_bytes`` still forms its own batch, and no batch is
    ever empty.

    Args:
        turns: List of speaker turns
        max_speakers: Maximum number of turns per batch (default 3)
        max_bytes: Maximum UTF-8 bytes of each joined batch (default 300)

    Returns:
        List of batched text strings
    """
    separator = "\n"
    separator_bytes = len(separator.encode("utf-8"))
    batches = []
    current_batch = []
    current_bytes = 0

    for turn in turns:
        turn_bytes = len(turn.encode("utf-8"))
        # Size this turn adds to the joined batch, separator included.
        added_bytes = turn_bytes + (separator_bytes if current_batch else 0)
        would_exceed_speakers = bool(current_batch) and len(current_batch) >= max_speakers
        would_exceed_bytes = bool(current_batch) and current_bytes + added_bytes > max_bytes

        if would_exceed_speakers or would_exceed_bytes:
            batches.append(separator.join(current_batch))
            current_batch = [turn]
            current_bytes = turn_bytes
        else:
            current_batch.append(turn)
            current_bytes += added_bytes

    if current_batch:
        batches.append(separator.join(current_batch))

    return batches
def generate_long(
    *,
    model,
    device: Union[str, torch.device],
    decode_one_token: Callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.9,
    top_k: int = 30,
    repetition_penalty: float = 1.1,
    temperature: float = 1.0,
    compile: bool = False,
    iterative_prompt: bool = True,
    chunk_length: int = 512,
    prompt_text: Optional[Union[str, list[str]]] = None,
    prompt_tokens: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
):
    """Generate semantic codes for ``text``, one batch of turns at a time.

    ``text`` is split on ``<|speaker:N|>`` tags and grouped into batches of at
    most 5 turns and ``chunk_length`` UTF-8 bytes. Untagged text becomes a
    single batch. Each batch is generated as an assistant turn of a running
    conversation, so later batches see the codes produced for earlier ones.

    When both ``prompt_text`` and ``prompt_tokens`` are given (a single value
    or matching lists), they are placed in the system message as a voice
    reference. Texts without a speaker tag get ``<|speaker:i|>`` prepended.

    Yields:
        ``GenerateResponse(action="sample", codes=..., text=...)`` for every
        batch, then ``GenerateResponse(action="next")`` after each of the
        ``num_samples`` samples.

    ``repetition_penalty`` and ``iterative_prompt`` are accepted for API
    compatibility but are not used here.
    """
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    # Normalise a single prompt to one-element lists. Calling bool() on a
    # multi-element tensor raises, so a tensor is wrapped before any truth test.
    if isinstance(prompt_text, str):
        prompt_text = [prompt_text] if prompt_text else None
    if isinstance(prompt_tokens, torch.Tensor):
        prompt_tokens = [prompt_tokens]

    use_prompt = bool(prompt_text) and bool(prompt_tokens)
    if use_prompt:
        assert len(prompt_text) == len(
            prompt_tokens
        ), "Prompt text and tokens must have the same length"

    if prompt_tokens:
        prompt_tokens = [i.cpu() for i in prompt_tokens]

    # Bytes of weights touched per decoding step, used for the bandwidth log.
    model_size = sum(p.numel() * p.element_size() for p in model.parameters())
    tokenizer = model.tokenizer
    max_length = model.config.max_seq_len
    im_end_id = tokenizer.get_token_id(IM_END_TOKEN)

    # Build base conversation with system message
    base_conversation = Conversation()

    if use_prompt:
        # Auto-add speaker tags to prompt texts that don't have them
        tagged_prompt_text = []
        for i, t in enumerate(prompt_text):
            if not re.search(r"<\|speaker:\d+\|>", t):
                tagged_prompt_text.append(f"<|speaker:{i}|>{t}")
            else:
                tagged_prompt_text.append(t)

        system_parts = [
            TextPart(
                text="convert the provided text to speech reference to the following:\n\nText:\n",
                cal_loss=False,
            ),
        ]
        reference_text = "\n".join(tagged_prompt_text)
        system_parts.append(TextPart(text=reference_text, cal_loss=False))
        system_parts.append(TextPart(text="\n\nSpeech:\n", cal_loss=False))
        all_codes = torch.cat([c for c in prompt_tokens], dim=1)
        system_parts.append(VQPart(codes=all_codes, cal_loss=False))
    else:
        system_parts = [
            TextPart(text="convert the provided text to speech", cal_loss=False)
        ]

    base_conversation.append(
        Message(
            role="system",
            parts=system_parts,
            cal_loss=False,
            add_im_start=True,
            add_im_end=True,
        )
    )

    # Split text by speaker and group into batches
    turns = split_text_by_speaker(text)
    if turns:
        batches = group_turns_into_batches(
            turns, max_speakers=5, max_bytes=chunk_length
        )
    else:
        batches = [text]

    logger.info(f"Split into {len(turns)} turns, grouped into {len(batches)} batches")

    for sample_idx in range(num_samples):
        # Deep copy base conversation for this sample
        conversation = deepcopy(base_conversation)

        for batch_idx, batch_text in enumerate(batches):
            # Time every batch separately so the tokens/sec figure is per batch.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t0 = time.perf_counter()

            logger.info(
                f"--- Sample {sample_idx}, Batch {batch_idx} "
                f"({len(batch_text.encode('utf-8'))} bytes) ---"
            )
            logger.info(f"Batch text: {batch_text}")

            # Add user message
            conversation.append(
                Message(
                    role="user",
                    parts=[TextPart(text=batch_text, cal_loss=False)],
                    cal_loss=False,
                    add_im_start=True,
                    add_im_end=True,
                )
            )

            # Deep copy for generation (don't pollute original conversation)
            conversation_gen = deepcopy(conversation)
            conversation_gen.append(
                Message(
                    role="assistant",
                    parts=[],
                    cal_loss=False,
                    modality="voice",
                    add_im_start=True,
                    add_im_end=False,
                )
            )

            logger.info("Visualizing prompt structure:")
            conversation_gen.visualize(
                tokenizer,
                merge_audio_tokens=True,
                merge_semantic_tokens=True,
            )

            encoded, audio_masks, audio_parts = conversation_gen.encode_for_inference(
                tokenizer, num_codebooks=model.config.num_codebooks
            )

            logger.info(f"Encoded prompt shape: {encoded.shape}")
            if audio_parts is not None:
                logger.info(f"Audio parts shape: {audio_parts.shape}")
            if audio_masks is not None:
                logger.info(
                    f"Audio masks non-zero count: {torch.count_nonzero(audio_masks)}"
                )

            if encoded.size(1) > max_length - 2048:
                raise ValueError(
                    f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}"
                )

            encoded = encoded.to(device=device)
            prompt_length = encoded.size(1)

            y = generate(
                model=model,
                prompt=encoded,
                max_new_tokens=max_new_tokens,
                audio_masks=audio_masks,
                audio_parts=audio_parts,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )

            if sample_idx == 0 and batch_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t_batch = time.perf_counter() - t0
            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t_batch if t_batch > 0 else 0
            logger.info(
                f"Batch {batch_idx}: Generated {tokens_generated} tokens in "
                f"{t_batch:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            # Extract the generated codebook rows. Drop the last step only when
            # it is the IM_END token; if generation stopped on the token budget,
            # the last step is a real frame and is kept.
            end = y.size(1)
            if end > prompt_length and int(y[0, -1]) == im_end_id:
                end -= 1
            codes = y[1:, prompt_length:end].clone()
            assert (codes >= 0).all(), f"Negative code found: {codes}"

            # Add assistant message with generated codes back to conversation
            conversation.append(
                Message(
                    role="assistant",
                    parts=[VQPart(codes=codes.cpu(), cal_loss=False)],
                    cal_loss=False,
                    modality="voice",
                    add_im_start=True,
                    add_im_end=True,
                )
            )

            yield GenerateResponse(action="sample", codes=codes, text=batch_text)

            # Cleanup
            del y, encoded

        if torch.cuda.is_available():
            logger.info(
                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
            )

        yield GenerateResponse(action="next")
@dataclass
class WrappedGenerateResponse:
    """Item placed on a request's response queue by the worker thread.

    ``status="success"`` wraps one ``GenerateResponse`` chunk from
    ``generate_long``. ``status="error"`` carries the exception that aborted
    the request.
    """
    status: Literal["success", "error"]
    response: Optional[Union[GenerateResponse, Exception]] = None
@dataclass
class GenerateRequest:
    """Work item for the queue returned by ``launch_thread_safe_queue``."""
    # Keyword arguments forwarded to generate_long (model/decode_one_token are supplied by the worker).
    request: dict
    # Receives WrappedGenerateResponse items for this request.
    response_queue: queue.Queue
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Start a daemon worker thread that owns the model and serves requests.

    Returns a queue that accepts ``GenerateRequest`` items (or ``None`` to stop
    the worker). For each request, ``request`` is passed as keyword arguments
    to ``generate_long``. Every yielded chunk is put on ``response_queue`` as
    ``WrappedGenerateResponse(status="success")``. If the request fails, a
    ``status="error"`` item carrying the exception is put instead.

    Blocks until the model is loaded. If loading fails in the worker, that
    exception is re-raised here; previously the caller waited forever.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()
    init_errors: list[Exception] = []

    def worker():
        try:
            model, decode_one_token = init_model(
                checkpoint_path, device, precision, compile=compile
            )
            with torch.device(device):
                model.setup_caches(
                    max_batch_size=1,
                    max_seq_len=model.config.max_seq_len,
                    dtype=next(model.parameters()).dtype,
                )
            # Caches exist now; tell generate() not to set them up again.
            model._cache_setup_done = True
        except Exception as e:
            logger.error(traceback.format_exc())
            init_errors.append(e)
            return
        finally:
            # Always release the waiting caller, whether loading succeeded or not.
            init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )

                # Only clear cache after complete request batch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                logger.error(traceback.format_exc())
                response_queue.put(WrappedGenerateResponse(status="error", response=e))
                # Clear cache on error
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()
    if init_errors:
        raise init_errors[0]

    return input_queue
@click.command()
@click.option(
    "--text",
    type=str,
    default="<|speaker:0|>你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option(
    "--prompt-audio",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--output", type=click.Path(path_type=Path), default=None)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.9)
@click.option("--top-k", type=int, default=30)
@click.option("--temperature", type=float, default=1.0)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/s2-pro",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=300)
@click.option("--output-dir", type=Path, default="output")
def main(
    text: str,
    prompt_text: Optional[tuple[str, ...]],
    prompt_tokens: Optional[tuple[Path, ...]],
    prompt_audio: Optional[tuple[Path, ...]],
    output: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
    output_dir: Path,
) -> None:
    """Generate semantic codes for --text, saving codes_<i>.npy per sample.

    A voice reference can be given as --prompt-text plus --prompt-audio or
    --prompt-tokens. With --output, the codes are also decoded to audio.
    """
    os.makedirs(output_dir, exist_ok=True)
    precision = torch.half if half else torch.bfloat16

    if prompt_text and not prompt_audio and not prompt_tokens:
        raise ValueError(
            "--prompt-text requires either --prompt-audio or --prompt-tokens"
        )
    if prompt_text and prompt_tokens and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )
    if prompt_text and prompt_audio and len(prompt_text) != len(prompt_audio):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt audio ({len(prompt_audio)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = init_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    # Caches exist now; tell generate() not to set them up again.
    model._cache_setup_done = True
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    codec = None
    codec_checkpoint = checkpoint_path / "codec.pth"

    # Handle prompt: --prompt-audio takes priority over --prompt-tokens.
    # click passes an empty tuple (not None) for unused `multiple` options.
    prompt_tokens_list = None
    if prompt_audio:
        logger.info("Loading codec model for audio encoding...")
        codec = load_codec_model(codec_checkpoint, device, precision)
        prompt_tokens_list = [
            encode_audio(p, codec, device).cpu() for p in prompt_audio
        ]
        logger.info(f"Encoded {len(prompt_audio)} audio file(s) to VQ codes")
    elif prompt_tokens:
        prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=list(prompt_text) if prompt_text else None,
        prompt_tokens=prompt_tokens_list,
    )

    idx = 0
    codes = []

    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                merged_codes = torch.cat(codes, dim=1)
                codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
                np.save(codes_npy_path, merged_codes.cpu().numpy())
                logger.info(f"Saved codes to {codes_npy_path}")

                # Decode to wav if --output is specified
                if output:
                    if codec is None:
                        logger.info("Loading codec model for audio decoding...")
                        codec = load_codec_model(codec_checkpoint, device, precision)
                    audio = decode_to_audio(merged_codes.to(device), codec)
                    import soundfile as sf

                    out_path = (
                        str(output)
                        if num_samples == 1
                        else str(output.with_stem(f"{output.stem}_{idx}"))
                    )
                    sf.write(out_path, audio.cpu().float().numpy(), codec.sample_rate)
                    logger.info(f"Saved audio to {out_path}")

            logger.info("Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")
if __name__ == "__main__":
    # Only run main() when executed as a script, not when imported.
    main()
================================================
FILE: fish_speech/models/text2semantic/lit_module.py
================================================
from typing import Any, Optional
import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler
import fish_speech.utils as utils
CODEBOOK_PAD_TOKEN_ID = 0
from fish_speech.models.text2semantic.llama import NaiveTransformer
log = utils.RankedLogger(__name__, rank_zero_only=True)
class TextToSemantic(L.LightningModule):
    """Lightning module that trains a text-to-semantic transformer.

    ``_step`` calls the wrapped model with ``inp``, ``key_padding_mask`` and
    ``labels`` and reads ``token_logits`` / ``codebook_logits`` from the
    result (the interface of ``DualARTransformer.forward``).
    """

    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
    ):
        """
        Args:
            model: The transformer to train.
            optimizer: Callable that receives a list of parameter-group dicts
                and returns an optimizer.
            lr_scheduler: Callable that receives the optimizer and returns a
                learning-rate scheduler (stepped every training step).
        """
        super().__init__()
        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

    def forward(self, x):
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        """Keep only LoRA weights in the checkpoint when LoRA is in use.

        If no parameter name contains ``"lora"`` the checkpoint is left
        untouched; otherwise every non-LoRA entry is removed.
        """
        state_dict = checkpoint["state_dict"]
        if not any("lora" in name for name in state_dict):
            return

        for name in list(state_dict):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build the optimizer and a per-step LR scheduler.

        Parameters whose name contains ``.bias``, ``norm.weight`` or
        ``.embeddings.`` are put in a group with ``weight_decay=0.0``; all
        others use the optimizer's default weight decay.
        """
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Print the parameters and their weight decay
        for i in optimizer.param_groups:
            log.info(
                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            # NOTE(review): yields NaN for rows where every label is -100.
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
        """Shared train/val step: compute token + codebook losses and log them.

        Args:
            batch: Dict with ``"inputs"``, ``"attention_masks"`` and ``"labels"``.
            batch_idx: Unused.
            stage: ``"train"`` or another stage name (e.g. ``"val"``).

        Returns:
            The total loss (``base_loss + semantic_loss``).
        """
        is_train = stage == "train"

        if is_train:
            # Key part to make lora work
            # Otherwise the parameters are merged, which lead to incorrect gradients
            self.model.train()

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
            labels=labels,
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Text-token loss on row 0 of the labels.
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        # This mask must select exactly the rows the model fed to its fast
        # transformer; DualARTransformer.forward builds it from the config's
        # semantic id range, so read the same values here.
        config = self.model.config
        token_ids = labels[:, 0]
        semantic_mask = (token_ids >= config.semantic_begin_id) & (
            token_ids <= config.semantic_end_id
        )
        all_codebook_labels = labels[:, 1 : 1 + config.num_codebooks]
        filtered_codebook_labels = all_codebook_labels.permute(0, 2, 1)[semantic_mask]

        if filtered_codebook_labels.numel() == 0:
            # No semantic positions in this batch: the model ran its fast
            # transformer on dummy rows, so there is nothing to score.
            # cross_entropy would fail on the size mismatch; use a zero loss
            # that stays connected to the graph instead.
            semantic_loss = codebook_logits.sum() * 0.0
        else:
            semantic_loss = F.cross_entropy(
                codebook_logits.reshape(-1, codebook_logits.size(-1)),
                filtered_codebook_labels.reshape(-1),
                ignore_index=-100,
            )

        loss = base_loss + semantic_loss

        # Per-step logging while training, epoch-level (synced) otherwise.
        log_kwargs = dict(
            on_step=is_train,
            on_epoch=not is_train,
            logger=True,
            sync_dist=not is_train,
        )
        self.log(f"{stage}/loss", loss, prog_bar=True, **log_kwargs)
        self.log(f"{stage}/base_loss", base_loss, prog_bar=False, **log_kwargs)
        self.log(
            f"{stage}/semantic_loss", semantic_loss, prog_bar=False, **log_kwargs
        )

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, filtered_codebook_labels)
        self.log(f"{stage}/top_5_accuracy", accuracy, prog_bar=True, **log_kwargs)

        return loss

    def get_accuracy(self, logits, labels):
        """Top-5 accuracy over labels that are neither -100 nor the pad id.

        Returns a zero tensor when no label qualifies.
        """
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        if mask.sum() == 0:
            return torch.tensor(0.0, device=logits.device)

        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[~mask] = 0
        correct = correct.sum()
        accuracy = correct / mask.sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
================================================
FILE: fish_speech/models/text2semantic/llama.py
================================================
import dataclasses
import json
import math
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from loguru import logger
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.checkpoint import checkpoint
from fish_speech.models.text2semantic.lora import LoraConfig, setup_lora
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k`` (``n`` itself if already one)."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
@dataclass
class BaseModelArgs:
model_type: str = "base"
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
intermediate_size: int = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
max_seq_len: int = 2048
dropout: float = 0.0
tie_word_embeddings: bool = True
attention_qkv_bias: bool = False
attention_o_bias: bool = False
attention_qk_norm: bool = False
# Codebook configs
codebook_size: int = 160
num_codebooks: int = 4
semantic_begin_id: int = 0
semantic_end_id: int = 0
# Gradient checkpointing
use_gradient_checkpointing: bool = True
# Initialize the model
initializer_range: float = 0.02
# Dummy vars
is_reward_model: bool = False
scale_codebook_embeddings: bool = False
audio_embed_dim: Optional[int] = None
def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
if self.head_dim is None:
self.head_dim = self.dim // self.n_head
@staticmethod
def from_pretrained(path: str):
path = Path(path)
if path.is_dir():
path = path / "config.json"
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
match data["model_type"]:
case "naive":
cls = NaiveModelArgs
case "dual_ar":
cls = DualARModelArgs
case "fish_qwen3_omni":
return BaseModelArgs._from_fish_qwen3_omni(data)
case _:
raise ValueError(f"Unknown model type: {data['model_type']}")
# Filter out unexpected keyword arguments
valid_keys = {f.name for f in dataclasses.fields(cls)}
data = {k: v for k, v in data.items() if k in valid_keys}
return cls(**data)
@staticmethod
def _from_fish_qwen3_omni(data: dict) -> "DualARModelArgs":
tc = data["text_config"]
adc = data["audio_decoder_config"]
flat = dict(
model_type="dual_ar",
vocab_size=tc["vocab_size"],
n_layer=tc["n_layer"],
n_head=tc["n_head"],
n_local_heads=tc.get("n_local_heads", -1),
head_dim=tc.get("head_dim"),
dim=tc["dim"],
intermediate_size=tc.get("intermediate_size"),
rope_base=tc.get("rope_base", 10000),
norm_eps=tc.get("norm_eps", 1e-5),
max_seq_len=tc.get("max_seq_len", 2048),
dropout=tc.get("dropout", 0.0),
tie_word_embeddings=tc.get("tie_word_embeddings", True),
attention_qkv_bias=tc.get("attention_qkv_bias", False),
attention_o_bias=tc.get("attention_o_bias", False),
attention_qk_norm=tc.get("attention_qk_norm", False),
use_gradient_checkpointing=tc.get("use_gradient_checkpointing", True),
initializer_range=tc.get("initializer_range", 0.02),
semantic_begin_id=data.get("semantic_start_token_id", 0),
semantic_end_id=data.get("semantic_end_token_id", 0),
scale_codebook_embeddings=True,
norm_fastlayer_input=True,
audio_embed_dim=adc.get("text_dim", tc["dim"]),
codebook_size=adc["vocab_size"],
num_codebooks=adc["num_codebooks"],
n_fast_layer=adc["n_layer"],
fast_dim=adc.get("dim"),
fast_n_head=adc.get("n_head"),
fast_n_local_heads=adc.get("n_local_heads"),
fast_head_dim=adc.get("head_dim"),
fast_intermediate_size=adc.get("intermediate_size"),
fast_attention_qkv_bias=adc.get("attention_qkv_bias"),
fast_attention_qk_norm=adc.get("attention_qk_norm"),
fast_attention_o_bias=adc.get("attention_o_bias"),
)
valid_keys = {f.name for f in dataclasses.fields(DualARModelArgs)}
flat = {k: v for k, v in flat.items() if k in valid_keys and v is not None}
return DualARModelArgs(**flat)
def save(self, path: str):
with open(path, "w") as f:
json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
@dataclass
class NaiveModelArgs(BaseModelArgs):
    """Config for ``model_type == "naive"``; adds no fields to ``BaseModelArgs``."""

    model_type: str = "naive"
@dataclass
class DualARModelArgs(BaseModelArgs):
    """Config for the dual autoregressive (slow + fast transformer) model.

    Every ``fast_*`` field left unset inherits the matching slow-transformer
    value in ``__post_init__``.
    """

    model_type: str = "dual_ar"
    n_fast_layer: int = 4
    fast_dim: int | None = None
    fast_n_head: int | None = None
    fast_n_local_heads: int | None = None
    fast_head_dim: int | None = None
    fast_intermediate_size: int | None = None
    fast_attention_qkv_bias: bool | None = None
    fast_attention_qk_norm: bool | None = None
    fast_attention_o_bias: bool | None = None
    norm_fastlayer_input: bool = False

    def __post_init__(self):
        super().__post_init__()

        # Sizes: any falsy value (None or 0) means "inherit from slow".
        size_fallbacks = (
            ("fast_dim", "dim"),
            ("fast_n_head", "n_head"),
            ("fast_n_local_heads", "n_local_heads"),
            ("fast_head_dim", "head_dim"),
            ("fast_intermediate_size", "intermediate_size"),
        )
        for fast_field, slow_field in size_fallbacks:
            own = getattr(self, fast_field)
            setattr(self, fast_field, own or getattr(self, slow_field))

        # Flags: only None means "inherit", so an explicit False is preserved.
        flag_fallbacks = (
            ("fast_attention_qkv_bias", "attention_qkv_bias"),
            ("fast_attention_qk_norm", "attention_qk_norm"),
            ("fast_attention_o_bias", "attention_o_bias"),
        )
        for fast_field, slow_field in flag_fallbacks:
            if getattr(self, fast_field) is None:
                setattr(self, fast_field, getattr(self, slow_field))
class KVCache(nn.Module):
    """Preallocated key/value buffers for incremental decoding.

    Both buffers have shape ``(max_batch_size, n_heads, max_seq_len, head_dim)``
    and are overwritten in place along the sequence axis by ``update``.
    """

    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        for buffer_name in ("k_cache", "v_cache"):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Store new keys/values at ``input_pos`` and return the full caches.

        Args:
            input_pos: ``[S]`` positions to write.
            k_val: ``[B, H, S, D]`` keys for those positions.
            v_val: ``[B, H, S, D]`` values for those positions.

        Returns:
            ``(k_cache, v_cache)`` — the whole buffers, not just the new slice.
        """
        assert input_pos.shape[0] == k_val.shape[2]

        keys, values = self.k_cache, self.v_cache
        keys[:, :, input_pos] = k_val
        values[:, :, input_pos] = v_val
        return keys, values
@dataclass
class TransformerForwardResult:
    """Model output: text-token logits plus codebook logits."""

    token_logits: Tensor
    codebook_logits: Tensor
@dataclass
class BaseTransformerForwardResult:
    """Slow-transformer output: token logits and the hidden states that the
    codebook head / fast transformer consumes."""

    logits: Tensor
    hidden_states: Tensor
def _remap_fish_qwen3_omni_keys(weights: OrderedDict) -> OrderedDict:
if not any(k.startswith(("text_model.", "audio_decoder.")) for k in weights):
return weights
new_weights = OrderedDict()
for k, v in weights.items():
if k.startswith("text_model.model."):
new_key = k[len("text_model.model.") :]
elif k.startswith("audio_decoder."):
suffix = k[len("audio_decoder.") :]
new_key = (
suffix
if suffix.startswith("codebook_embeddings.")
else "fast_" + suffix
)
else:
new_key = k
new_weights[new_key] = v
return new_weights
class BaseTransformer(nn.Module):
def __init__(
self,
config: BaseModelArgs,
init_weights: bool = True,
) -> None:
super().__init__()
self.config = config
# Slow transformer
self.embeddings = nn.Embedding(
config.vocab_size,
config.dim,
)
self.codebook_embeddings = nn.Embedding(
config.codebook_size * config.num_codebooks,
config.dim,
)
self.layers = nn.ModuleList(
TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
)
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
if self.config.tie_word_embeddings is False:
self.output = nn.Linear(
config.dim,
config.vocab_size,
bias=False,
)
self.register_buffer(
"freqs_cis",
precompute_freqs_cis(
config.max_seq_len,
config.head_dim,
config.rope_base,
),
persistent=False,
)
self.register_buffer(
"causal_mask",
torch.tril(
torch.ones(
config.max_seq_len,
config.max_seq_len,
dtype=torch.bool,
)
),
persistent=False,
)
# For kv cache
self.max_batch_size = -1
self.max_seq_len = -1
if init_weights:
self.apply(self._init_weights)
def setup_caches(
self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
):
if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
return
max_seq_len = find_multiple(max_seq_len, 8)
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
for b in self.layers:
b.attention.kv_cache = KVCache(
max_batch_size,
max_seq_len,
self.config.n_local_heads,
self.config.head_dim,
dtype=dtype,
)
def embed(self, inp: Tensor) -> Tensor:
embeds = []
for i in range(self.config.num_codebooks):
emb = self.codebook_embeddings(
inp[:, i + 1] + i * self.config.codebook_size
)
embeds.append(emb)
vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
is_semantic = (inp[:, 0] >= self.config.semantic_begin_id) & (
inp[:, 0] <= self.config.semantic_end_id
)
vq_embeds_sum[~is_semantic] = 0
x = self.embeddings(inp[:, 0]) + vq_embeds_sum
return x
def forward(
self,
inp: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> BaseTransformerForwardResult:
seq_len = inp.size(2)
# Here we want to merge the embeddings of the codebooks
x = self.embed(inp)
freqs_cis = self.freqs_cis[:seq_len]
mask = None
if key_padding_mask is not None:
causal = self.causal_mask[:seq_len, :seq_len]
causal = rearrange(causal, "q k -> 1 1 q k")
atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
atten_mask = atten_mask.logical_not()
mask = causal & atten_mask
for layer in self.layers:
if self.config.use_gradient_checkpointing and self.training:
x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
else:
x = layer(x, freqs_cis, mask)
slow_out = self.norm(x)
if self.config.tie_word_embeddings:
token_logits = F.linear(slow_out, self.embeddings.weight)
else:
token_logits = self.output(slow_out)
hidden_out = (
slow_out if getattr(self.config, "norm_fastlayer_input", False) else x
)
return BaseTransformerForwardResult(
logits=token_logits,
hidden_states=hidden_out,
)
def forward_generate(
self,
inp: Tensor,
input_pos: Optional[Tensor] = None,
audio_masks: Optional[Tensor] = None,
audio_parts: Optional[Tensor] = None,
return_all: bool = False,
) -> BaseTransformerForwardResult:
# Embedding logic replicated from embed() for compilation compatibility
embeds = []
for i in range(self.config.num_codebooks):
emb = self.codebook_embeddings(
inp[:, i + 1] + i * self.config.codebook_size
)
embeds.append(emb)
vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
vq_masks = (inp[:, 0] >= self.config.semantic_begin_id) & (
inp[:, 0] <= self.config.semantic_end_id
)
vq_embeds_sum[~vq_masks] = 0
x = self.embeddings(inp[:, 0]) + vq_embeds_sum
if self.config.scale_codebook_embeddings:
vq_masks_expanded = vq_masks.unsqueeze(-1).expand_as(x)
x = torch.where(
vq_masks_expanded, x / math.sqrt(self.config.num_codebooks + 1), x
)
# Audio embeddings
if audio_parts is not None:
# Note: This assumes self.audio_projector exists if audio_parts is used
# It seems missing in init, but we keep existing logic
if hasattr(self, "audio_projector"):
audio_embeds = self.audio_projector(audio_parts)
if self.config.scale_codebook_embeddings:
x[audio_masks] = audio_embeds / math.sqrt(2)
else:
x[audio_masks] = audio_embeds
else:
logger.warning("audio_parts provided but model has no audio_projector")
if input_pos is None:
input_pos = torch.arange(inp.shape[-1], device=x.device)
max_seq_len = inp.shape[-1]
else:
max_seq_len = self.max_seq_len
mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
freqs_cis = self.freqs_cis[input_pos]
for layer in self.layers:
x = layer(x, freqs_cis, mask, input_pos=input_pos)
if x.size(1) > 1 and not return_all:
x = x[:, -1:]
slow_out = self.norm(x)
if self.config.is_reward_model:
token_logits = self.score_output(slow_out)
elif self.config.tie_word_embeddings:
token_logits = F.linear(slow_out, self.embeddings.weight)
else:
token_logits = self.output(slow_out)
hidden_out = (
slow_out if getattr(self.config, "norm_fastlayer_input", False) else x
)
return BaseTransformerForwardResult(
logits=token_logits,
hidden_states=hidden_out,
)
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@staticmethod
def from_pretrained(
path: str,
load_weights: bool = False,
max_length: int | None = None,
lora_config: LoraConfig | None = None,
rope_base: int | None = None,
) -> "BaseTransformer":
# Import wrapper locally to avoid circular dependency or global import issues
from fish_speech.tokenizer import FishTokenizer
config = BaseModelArgs.from_pretrained(str(path))
if max_length is not None:
config.max_seq_len = max_length
logger.info(f"Override max_seq_len to {max_length}")
if rope_base is not None:
config.rope_base = rope_base
logger.info(f"Override rope_base to {rope_base}")
try:
tokenizer = FishTokenizer.from_pretrained(path)
config.semantic_begin_id = tokenizer.semantic_begin_id
config.semantic_end_id = tokenizer.semantic_end_id
logger.info(
f"Injected Semantic IDs into Config: {config.semantic_begin_id}-{config.semantic_end_id}"
)
except Exception as e:
logger.warning(
f"Failed to load tokenizer for config injection: {e}. Semantic IDs might be 0."
)
match config.model_type:
case "naive":
model_cls = NaiveTransformer
case "dual_ar":
model_cls = DualARTransformer
case _:
raise ValueError(f"Unknown model type: {config.model_type}")
logger.info(f"Loading model from {path}, config: {config}")
# Initialize model without passing tokenizer explicitly to __init__
model = model_cls(config)
# Attach tokenizer to model instance for inference convenience (optional, but good for user scripts)
model.tokenizer = tokenizer
if load_weights is False:
logger.info("Randomly initialized model")
else:
if "int8" in str(Path(path)):
logger.info("Using int8 weight-only quantization!")
from tools.llama.quantize import WeightOnlyInt8QuantHandler
simple_quantizer = WeightOnlyInt8QuantHandler(model)
model = simple_quantizer.convert_for_runtime()
if "int4" in str(Path(path)):
logger.info("Using int4 quantization!")
path_comps = path.name.split("-")
assert path_comps[-2].startswith("g")
groupsize = int(path_comps[-2][1:])
from tools.llama.quantize import WeightOnlyInt4QuantHandler
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
model = simple_quantizer.convert_for_runtime()
path_obj = Path(path)
index_json = path_obj / "model.safetensors.index.json"
single_st = path_obj / "model.safetensors"
pth_file = path_obj / "model.pth"
if index_json.exists():
logger.info("Loading sharded safetensors weights")
from safetensors.torch import load_file as st_load_file
with open(index_json) as f:
st_index = json.load(f)
shard_files = sorted(set(st_index["weight_map"].values()))
weights = OrderedDict()
for shard in shard_files:
weights.update(st_load_file(str(path_obj / shard), device="cpu"))
weights = _remap_fish_qwen3_omni_keys(weights)
elif single_st.exists():
logger.info("Loading single safetensors weights")
from safetensors.torch import load_file as st_load_file
weights = OrderedDict(st_load_file(str(single_st), device="cpu"))
weights = _remap_fish_qwen3_omni_keys(weights)
elif pth_file.exists():
weights = torch.load(
pth_file,
map_location="cpu",
mmap=True,
weights_only=True,
)
if "state_dict" in weights:
weights = weights["state_dict"]
if weights and next(iter(weights.keys())).startswith("model."):
weights = OrderedDict(
(k.replace("model.", ""), v) for k, v in weights.items()
)
for k in list(weights.keys()):
if "audio_" in k:
weights.pop(k)
else:
raise FileNotFoundError(f"No model weights found in {path_obj}")
err = model.load_state_dict(weights, strict=False, assign=True)
logger.info(f"Model weights loaded - Status: {err}")
if lora_config is not None:
setup_lora(model, lora_config)
logger.info(f"LoRA setup: {lora_config}")
return model
def save_pretrained(self, path: str, drop_lora: bool = False):
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
self.config.save(path / "config.json")
state_dict = self.state_dict()
if drop_lora:
for key in list(state_dict.keys()):
if "lora" not in key:
continue
state_dict.pop(key)
torch.save(state_dict, path / "model.pth")
if hasattr(self, "tokenizer"):
self.tokenizer.save_pretrained(path)
class NaiveTransformer(BaseTransformer):
    """Slow transformer with a single linear head that predicts all codebooks
    at every position."""

    def __init__(self, config: NaiveModelArgs) -> None:
        super().__init__(config, init_weights=False)

        n_codebook_logits = config.codebook_size * config.num_codebooks
        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(config.dim, n_codebook_logits, bias=False)

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        """Map slow-transformer hidden states to ``(b, n, codebooks, codebook_size)`` logits."""
        normed_hidden = self.codebook_norm(result.hidden_states)
        flat_logits = self.codebook_output(normed_hidden)
        per_codebook = rearrange(
            flat_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )
        return TransformerForwardResult(
            token_logits=result.logits,
            codebook_logits=per_codebook,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        slow = super().forward(inp=inp, key_padding_mask=key_padding_mask)
        return self.decode(slow)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        slow = super().forward_generate(x, input_pos)
        return self.decode(slow)
class DualARTransformer(BaseTransformer):
    """Dual autoregressive model.

    The slow transformer predicts the next text/semantic token. A small fast
    transformer then predicts the ``num_codebooks`` codebook indices of each
    semantic position one after another.
    """

    def __init__(self, config: DualARModelArgs) -> None:
        super().__init__(config, init_weights=False)

        # Project to fast dim if needed
        if config.fast_dim is not None and config.fast_dim != config.dim:
            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
        else:
            self.fast_project_in = nn.Identity()

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)

        # The equivalent bs is so large that sdpa doesn't work
        override_config = dataclasses.replace(
            config,
            dim=config.fast_dim,
            n_head=config.fast_n_head,
            n_local_heads=config.fast_n_local_heads,
            head_dim=config.fast_head_dim,
            intermediate_size=config.fast_intermediate_size,
            attention_qkv_bias=config.fast_attention_qkv_bias,
            attention_qk_norm=config.fast_attention_qk_norm,
            attention_o_bias=config.fast_attention_o_bias,
        )

        self.fast_layers = nn.ModuleList(
            TransformerBlock(override_config, use_sdpa=False)
            for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.fast_dim,
            config.codebook_size,
            bias=False,
        )

        # Rotary table for the fast transformer: one position per codebook.
        self.register_buffer(
            "fast_freqs_cis",
            precompute_freqs_cis(
                config.num_codebooks,
                config.fast_head_dim,
                config.rope_base,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate slow-layer caches and ``num_codebooks``-long fast-layer caches."""
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.fast_n_local_heads,
                self.config.fast_head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        labels: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        vq_parts: Optional[Tensor] = None,
        vq_masks: Optional[Tensor] = None,
        vq_require_losses: Optional[Tensor] = None,
        mel_parts: Optional[Tensor] = None,
        mel_masks: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Teacher-forced training forward.

        The slow transformer runs over the whole sequence. The fast
        transformer then runs on every position whose label token (row 0 of
        ``labels``) is semantic. Its input at each such position is the slow
        hidden state followed by the embeddings of codebooks ``0..C-2``.

        Args:
            inp: ``(B, 1 + num_codebooks, S)`` input ids.
            labels: ``(B, 1 + num_codebooks, S)`` targets; required.
            key_padding_mask: Optional ``(B, S)`` padding mask (True = pad).
            vq_parts, vq_masks, vq_require_losses, mel_parts, mel_masks:
                Accepted for interface compatibility; unused here.

        Returns:
            ``token_logits`` for every position and ``codebook_logits`` of
            shape ``(N, num_codebooks, codebook_size)`` for the ``N`` selected
            positions (``N == 4`` dummy rows when none are semantic).

        Raises:
            ValueError: If ``labels`` is ``None``.
        """
        if labels is None:
            raise ValueError(
                "DualARTransformer.forward requires `labels` to select the "
                "semantic positions fed to the fast transformer"
            )

        parent_result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        token_logits = parent_result.logits
        x = parent_result.hidden_states

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[:fast_seq_len]

        # Extract corresponding parts with labels
        token_labels = labels[:, 0]
        codebook_mask = (token_labels >= self.config.semantic_begin_id) & (
            token_labels <= self.config.semantic_end_id
        )

        # This gives where input token is <|semantic|>
        x = x[codebook_mask]

        if x.shape[0] == 0:
            # Use dummy input when no vq is required, so the fast transformer
            # still runs (presumably to keep its parameters in the graph —
            # confirm against the distributed-training setup).
            x = torch.zeros(
                (4, self.config.dim),
                device=x.device,
                dtype=x.dtype,
            )
            codebooks = torch.zeros(
                (x.shape[0], self.config.num_codebooks - 1),
                device=x.device,
                dtype=torch.int,
            )
        else:
            # Exactly num_codebooks rows, matching how the loss slices labels.
            all_codebooks = labels[:, 1 : 1 + self.config.num_codebooks]
            semantic_codebooks = all_codebooks.permute(0, 2, 1)[codebook_mask]
            # Teacher forcing: the last codebook is never an input.
            codebooks = semantic_codebooks[:, :-1]

        x = self.fast_project_in(x)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        assert codebook_logits.shape[1] == self.config.num_codebooks

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """One cached fast-transformer step.

        ``x`` is reshaped to ``(B, 1, -1)``; ``input_pos`` selects the codebook
        position for the mask and rotary table. Returns codebook logits.
        """
        # Fast transformer
        x = x.view(x.shape[0], 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        audio_masks: Optional[Tensor] = None,
        audio_parts: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Slow generation step; hidden states are projected to ``fast_dim``."""
        x = super().forward_generate(x, input_pos, audio_masks, audio_parts)
        x.hidden_states = self.fast_project_in(x.hidden_states)
        return x
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention, then gated feed-forward, each
    added back through a residual connection."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        attn_out = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out
class Attention(nn.Module):
    """Multi-head attention with grouped KV heads, rotary embeddings, optional
    QK-norm and an optional KV cache.

    ``n_local_heads`` key/value heads are shared by ``n_head`` query heads:
    each KV head is repeated ``n_head // n_local_heads`` times.
    """

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        # Fused projection width: n_head query heads + n_local_heads key and
        # n_local_heads value heads.
        qkv_width = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        self.wqkv = nn.Linear(config.dim, qkv_width, bias=config.attention_qkv_bias)
        self.wo = nn.Linear(
            config.n_head * config.head_dim, config.dim, bias=config.attention_o_bias
        )
        self.kv_cache = None

        if config.attention_qk_norm:
            self.q_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
            self.k_norm = nn.RMSNorm(config.head_dim, config.norm_eps)

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self.attention_qk_norm = config.attention_qk_norm
        self.config = config

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Fuse legacy separate ``wq``/``wk``/``wv`` weights into ``wqkv``."""
        legacy_keys = [f"{prefix}{name}.weight" for name in ("wq", "wk", "wv")]
        if legacy_keys[0] not in state_dict:
            return
        state_dict[prefix + "wqkv.weight"] = torch.cat(
            [state_dict.pop(key) for key in legacy_keys]
        )

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attend over ``x`` of shape ``(B, S, dim)``; returns ``(B, S, dim)``."""
        batch, steps, _ = x.shape
        q_width = self.n_head * self.head_dim
        kv_width = self.n_local_heads * self.head_dim

        query, key, value = self.wqkv(x).split([q_width, kv_width, kv_width], dim=-1)
        query = query.view(batch, steps, self.n_head, self.head_dim)
        key = key.view(batch, steps, self.n_local_heads, self.head_dim)
        value = value.view(batch, steps, self.n_local_heads, self.head_dim)

        if self.attention_qk_norm:
            query, key = self.q_norm(query), self.k_norm(key)

        query = apply_rotary_emb(query, freqs_cis)
        key = apply_rotary_emb(key, freqs_cis)

        # (B, S, H, D) -> (B, H, S, D)
        query, key, value = [t.transpose(1, 2) for t in (query, key, value)]

        if self.kv_cache is not None:
            key, value = self.kv_cache.update(input_pos, key, value)

        heads_per_kv = self.n_head // self.n_local_heads
        key = key.repeat_interleave(heads_per_kv, dim=1)
        value = value.repeat_interleave(heads_per_kv, dim=1)

        drop_p = self.dropout if self.training else 0.0
        if not self.use_sdpa:
            out = self.eq_scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=mask,
                dropout_p=drop_p,
            )
        elif mask is None:
            # No explicit mask: force the flash kernel and let it apply the
            # causal mask itself.
            with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                out = F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    dropout_p=drop_p,
                    is_causal=True,
                )
        else:
            out = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=mask,
                dropout_p=drop_p,
            )

        out = out.transpose(1, 2).contiguous().view(batch, steps, q_width)
        return self.wo(out)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        """Unfused scaled dot-product attention.

        Slower than ``F.scaled_dot_product_attention`` but avoids the CUDA
        errors the fused kernels raise for the fast transformer's very large
        effective batch. A bool ``attn_mask`` marks allowed positions with
        ``True``; a float mask is added to the scores.
        """
        q_len, kv_len = query.size(-2), key.size(-2)
        scale = 1 / math.sqrt(query.size(-1))

        bias = torch.zeros(1, 1, q_len, kv_len, dtype=query.dtype, device=query.device)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                bias = torch.where(attn_mask.logical_not(), float("-inf"), bias)
            else:
                bias = bias + attn_mask

        scores = query @ key.transpose(-2, -1) * scale
        scores += bias
        probs = torch.softmax(scores, dim=-1)
        probs = torch.dropout(probs, dropout_p, train=True)
        return probs @ value
class FeedForward(nn.Module):
    """Gated feed-forward network computing ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        model_dim, hidden_dim = config.dim, config.intermediate_size
        self.w1 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, model_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), reduced over the last dimension.
        mean_sq = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        # Normalise in float32, cast back to the input dtype, then scale.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """Precompute rotary-embedding rotations as (cos, sin) pairs.

    Args:
        seq_len: Number of positions to precompute.
        n_elem: Rotary dimension; ``n_elem // 2`` frequencies are produced.
        base: Base of the geometric frequency progression (default 10000).

    Returns:
        Tensor of shape ``(seq_len, n_elem // 2, 2)`` in bfloat16, where the
        last axis holds ``(cos(angle), sin(angle))`` for ``angle = pos * freq``.
    """
    half = n_elem // 2
    exponents = torch.arange(0, n_elem, 2)[:half].float() / n_elem
    inv_freq = 1.0 / (base ** exponents)

    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    rotations = torch.polar(torch.ones_like(angles), angles)

    pairs = torch.stack([rotations.real, rotations.imag], dim=-1)
    return pairs.to(dtype=torch.bfloat16)
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of ``x`` by the angles in ``freqs_cis``.

    ``x`` is viewed as ``(..., D/2, 2)`` pairs. ``freqs_cis`` must be
    viewable as ``(1, x.size(1), 1, D/2, 2)`` holding ``(cos, sin)``, so dim 1
    of ``x`` is the position axis and dim 2 is broadcast over. Math is done
    in float32 and the result is cast back to ``x``'s dtype.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rot = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    real, imag = pairs[..., 0], pairs[..., 1]
    cos, sin = rot[..., 0], rot[..., 1]
    # Complex multiply (real + i*imag) * (cos + i*sin).
    rotated = torch.stack((real * cos - imag * sin, imag * cos + real * sin), -1)
    return rotated.flatten(3).type_as(x)
================================================
FILE: fish_speech/models/text2semantic/lora.py
================================================
from dataclasses import dataclass
import loralib as lora
@dataclass
class LoraConfig:
    """Hyper-parameters for the loralib adapters installed by ``setup_lora``.

    Attributes:
        r: Rank of the low-rank adapter matrices.
        lora_alpha: loralib scaling numerator (updates are scaled by alpha / r).
        lora_dropout: Dropout on the adapter branch of linear layers; the
            embedding replacements in this module do not receive it.
    """

    r: int
    lora_alpha: float
    lora_dropout: float = 0.0
def _replace_embedding(old_embed, lora_config):
    """Build a ``lora.Embedding`` that mirrors ``old_embed`` and reuses its weights.

    Table size, width and ``padding_idx`` come from ``old_embed``. Only ``r``
    and ``lora_alpha`` come from ``lora_config``; ``lora_dropout`` is not
    passed to embeddings.
    """
    layer_kwargs = dict(
        num_embeddings=old_embed.num_embeddings,
        embedding_dim=old_embed.embedding_dim,
        padding_idx=old_embed.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )
    replacement = lora.Embedding(**layer_kwargs)
    # Start from the pretrained embedding table.
    replacement.weight.data.copy_(old_embed.weight.data)
    return replacement
def _block_linears(layers):
    """List (parent, attribute) pairs for each attention/FFN projection in ``layers``."""
    targets = []
    for block in layers:
        attn, ffn = block.attention, block.feed_forward
        targets += [(attn, "wqkv"), (attn, "wo")]
        targets += [(ffn, "w1"), (ffn, "w2"), (ffn, "w3")]
    return targets


def _swap_in_lora_linear(parent, attr, lora_config):
    """Replace ``parent.<attr>`` with a ``lora.Linear`` that carries the same weights."""
    original = getattr(parent, attr)
    has_bias = original.bias is not None
    replacement = lora.Linear(
        in_features=original.in_features,
        out_features=original.out_features,
        bias=has_bias,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
        lora_dropout=lora_config.lora_dropout,
    )
    replacement.weight.data.copy_(original.weight.data)
    if has_bias:
        replacement.bias.data.copy_(original.bias.data)
    setattr(parent, attr, replacement)


def setup_lora(model, lora_config):
    """Install loralib adapters into ``model`` in place and freeze the rest.

    Replaced with LoRA layers (pretrained weights preserved):
      * ``embeddings`` and ``codebook_embeddings``;
      * ``output`` and, for every block in ``model.layers``, the attention
        ``wqkv``/``wo`` and feed-forward ``w1``/``w2``/``w3`` projections;
      * for dual-AR models (those with ``fast_layers``): ``fast_embeddings``,
        ``fast_output`` and the same projections in every fast layer.

    Afterwards only LoRA parameters remain trainable (``bias="none"`` keeps
    biases frozen).
    """
    model.embeddings = _replace_embedding(model.embeddings, lora_config)
    model.codebook_embeddings = _replace_embedding(
        model.codebook_embeddings, lora_config
    )

    targets = [(model, "output")] + _block_linears(model.layers)

    if hasattr(model, "fast_layers"):
        # Dual-AR model: the fast transformer gets adapters too.
        model.fast_embeddings = _replace_embedding(model.fast_embeddings, lora_config)
        targets.append((model, "fast_output"))
        targets.extend(_block_linears(model.fast_layers))

    for parent, attr in targets:
        _swap_in_lora_linear(parent, attr, lora_config)

    lora.mark_only_lora_as_trainable(model, bias="none")
def get_merged_state_dict(model):
    """Return ``model``'s state dict with the LoRA adapter tensors removed.

    ``model.eval()`` is called first. With loralib's weight merging enabled,
    presumably the default, this folds the adapter update into the base
    weights (TODO confirm for the configured layers). Every key containing
    ``"lora"`` is then dropped, so the result matches a plain, non-LoRA model.
    ``model`` is left in eval mode.
    """
    model.eval()
    state_dict = model.state_dict()
    adapter_keys = [key for key in state_dict if "lora" in key]
    for key in adapter_keys:
        del state_dict[key]
    return state_dict
================================================
FILE: fish_speech/scheduler.py
================================================
import math
def get_cosine_schedule_with_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int | float,
num_training_steps: int,
num_cycles: float = 0.5,
final_lr_ratio: float = 0.0,
):
if 0 < num_warmup_steps < 1: # float mode
num_warmup_steps = int(num_warmup_steps * num_training_steps)
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(
final_lr_ratio,
0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
def get_constant_schedule_with_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int | float,
num_training_steps: int | None = None,
):
if 0 < num_warmup_steps < 1: # float mode
num_warmup_steps = int(num_warmup_steps * num_training_steps)
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return 1.0
================================================
FILE: fish_speech/text/__init__.py
================================================
from .clean import clean_text
# Public API of this package: only the text normaliser is exported.
__all__ = ["clean_text"]
================================================
FILE: fish_speech/text/clean.py
================================================
import re
# Typographic characters replaced by plain ASCII during cleaning.
SYMBOLS_MAPPING = {
    "‘": "'",
    "’": "'",
}

# Alternation matching any key of SYMBOLS_MAPPING (each key regex-escaped).
REPLACE_SYMBOL_REGEX = re.compile(
    "|".join(map(re.escape, SYMBOLS_MAPPING.keys()))
)

# Runs of characters from several emoji/pictograph code-point blocks.
EMOJI_REGEX = re.compile(
    "["
    "\U0001f600-\U0001f64f"  # emoticons
    "\U0001f300-\U0001f5ff"  # symbols & pictographs
    "\U0001f680-\U0001f6ff"  # transport & map symbols
    "\U0001f1e0-\U0001f1ff"  # flags (iOS)
    "]+",
    flags=re.UNICODE,
)


def clean_text(text):
    """Normalise raw input text before tokenisation.

    In order: strip surrounding whitespace, map curly single quotes
    (``SYMBOLS_MAPPING``) to ASCII apostrophes, delete emoji matched by
    ``EMOJI_REGEX``, and collapse runs of two or more commas into one comma.
    Runs of periods (e.g. an ellipsis) are left unchanged; only commas are
    collapsed.
    """
    stripped = text.strip()
    normalised = REPLACE_SYMBOL_REGEX.sub(
        lambda match: SYMBOLS_MAPPING[match.group()], stripped
    )
    without_emoji = EMOJI_REGEX.sub(r"", normalised)
    return re.sub(r"[,]{2,}", lambda match: match.group()[0], without_emoji)
================================================
FILE: fish_speech/tokenizer.py
================================================
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, List, Union
import torch
from transformers import AutoTokenizer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizerFast
logger = logging.getLogger(__name__)

# Control tokens.
EOS_TOKEN = "<|endoftext|>"
PAD_TOKEN = "<|pad|>"
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PHONEME_START_TOKEN = "<|phoneme_start|>"
PHONEME_END_TOKEN = "<|phoneme_end|>"

# Modality markers, also reachable by name through MODALITY_TOKENS.
MODALITY_TEXT_TOKEN = "<|text|>"
MODALITY_VOICE_TOKEN = "<|voice|>"
MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"

AUDIO_START_TOKEN = "<|audio_start|>"
AUDIO_END_TOKEN = "<|audio_end|>"
AUDIO_EMBED_TOKEN = "<|audio_pad|>"

MODALITY_TOKENS = {
    "text": MODALITY_TEXT_TOKEN,
    "voice": MODALITY_VOICE_TOKEN,
    "interleave": MODALITY_INTERLEAVE_TOKEN,
}

# One token per codebook index: <|semantic:0|> ... <|semantic:4095|>.
SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=code) for code in range(4096)]

# The control/modality/audio tokens above, followed by every semantic token.
ALL_SPECIAL_TOKENS = [
    EOS_TOKEN,
    PAD_TOKEN,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PHONEME_START_TOKEN,
    PHONEME_END_TOKEN,
    MODALITY_TEXT_TOKEN,
    MODALITY_VOICE_TOKEN,
    MODALITY_INTERLEAVE_TOKEN,
    AUDIO_START_TOKEN,
    AUDIO_END_TOKEN,
    AUDIO_EMBED_TOKEN,
] + SEMANTIC_TOKENS
class FishTokenizer:
    """Wrapper around a Hugging Face ``AutoTokenizer`` with semantic-token lookup.

    Unknown attributes are forwarded to the wrapped tokenizer via
    ``__getattr__``. At construction the vocabulary ids of the
    ``<|semantic:i|>`` tokens (``i`` in ``range(4096)``) are resolved into
    ``semantic_id_to_token_id``, ``semantic_map_tensor`` and the
    ``semantic_begin_id``/``semantic_end_id`` bounds.
    """

    def __init__(self, model_path: str):
        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.semantic_id_to_token_id = {}

        vocab = self._tokenizer.get_vocab()
        valid_ids = []
        for code_idx in range(4096):
            token = SEMANTIC_TOKEN_TEMPLATE.format(i=code_idx)
            if token in vocab:
                token_id = vocab[token]
                self.semantic_id_to_token_id[code_idx] = token_id
                valid_ids.append(token_id)

        # Lookup tensor codebook-index -> vocab id. It tolerates gaps in the
        # token ids; codes missing from the vocab map to 0. When no semantic
        # token exists at all it is all zeros: a placeholder that avoids a
        # crash, though generation will fail.
        self.semantic_map_tensor = torch.zeros(4096, dtype=torch.long)
        for code_idx, token_id in self.semantic_id_to_token_id.items():
            self.semantic_map_tensor[code_idx] = token_id

        if not valid_ids:
            logger.error(
                "CRITICAL ERROR: No semantic tokens found in vocab! Audio cannot be synthesized."
            )
            self.semantic_begin_id = 0
            self.semantic_end_id = 0
        else:
            self.semantic_begin_id = min(valid_ids)
            self.semantic_end_id = max(valid_ids)
            logger.info(
                f"Loaded Tokenizer. Semantic Range: {self.semantic_begin_id} -> {self.semantic_end_id}"
            )

    @property
    def vocab_size(self):
        # Whatever the wrapped tokenizer reports as ``vocab_size``.
        return self._tokenizer.vocab_size

    @property
    def pad_token_id(self):
        return self._tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        return self._tokenizer.eos_token_id

    def get_token_id(self, token: str) -> int:
        """Return the vocabulary id of a single token string."""
        return self._tokenizer.convert_tokens_to_ids(token)

    def encode(
        self, text: str, add_special_tokens: bool = False, **kwargs
    ) -> List[int]:
        """Encode ``text`` into token ids.

        If the backend's ``encode`` accepts ``allowed_special``, as
        tiktoken-style (e.g. Qwen) backends do, ``allowed_special="all"`` is
        injected. Special tokens written inline in ``text`` are then parsed as
        special tokens. An explicit ``allowed_special`` in ``kwargs`` wins.
        """
        import inspect

        params = inspect.signature(self._tokenizer.encode).parameters
        if "allowed_special" in params and "allowed_special" not in kwargs:
            kwargs["allowed_special"] = "all"
        return self._tokenizer.encode(
            text, add_special_tokens=add_special_tokens, **kwargs
        )

    def decode(self, tokens: Union[List[int], int], **kwargs) -> str:
        return self._tokenizer.decode(tokens, **kwargs)

    def save_pretrained(self, path: str):
        self._tokenizer.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, path: str):
        return cls(path)

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped tokenizer.

        The wrapped tokenizer is read from the instance ``__dict__`` directly.
        Reading ``self._tokenizer`` here would re-enter ``__getattr__`` on a
        half-initialised instance, for example while ``copy``/``pickle``
        reconstruct it and probe ``__setstate__``. That recursed forever.
        Now such lookups raise ``AttributeError`` as the protocols expect.
        """
        try:
            inner = self.__dict__["_tokenizer"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(inner, name)
================================================
FILE: fish_speech/train.py
================================================
import os

# NOTE(review): presumably turns off the libuv TCPStore backend of
# torch.distributed (some builds, e.g. on Windows, lack libuv). It is set
# before `torch` is imported below. Confirm it is still needed.
os.environ["USE_LIBUV"] = "0"
import sys
from typing import Optional

import hydra
import lightning as L
import pyrootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies import DDPStrategy
from omegaconf import DictConfig, OmegaConf

# Drop SLURM variables. This is presumably so Lightning does not auto-detect
# a SLURM cluster environment and manage processes itself. TODO confirm.
os.environ.pop("SLURM_NTASKS", None)
os.environ.pop("SLURM_JOB_NAME", None)
os.environ.pop("SLURM_NTASKS_PER_NODE", None)

# Locate the project root via the ".project-root" marker file. With
# pythonpath=True it is also added to sys.path.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# Register an OmegaConf resolver named "eval" (usable as ${eval:...}).
# NOTE: it evaluates arbitrary Python from config files; use trusted configs only.
OmegaConf.register_new_resolver("eval", eval)
# Imported after setup_root, presumably so the project root is already on sys.path.
import fish_speech.utils as utils

# Module logger that only emits on the rank-zero process.
log = utils.RankedLogger(__name__, rank_zero_only=True)
@utils.task_wrapper
def train(cfg: DictConfig) -> tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    Checkpoint selection for training: the newest ``*.ckpt`` in
    ``cfg.paths.ckpt_dir`` (auto-resume) takes precedence over ``cfg.ckpt_path``.
    With ``resume_weights_only`` (ignored on auto-resume) only the model weights
    are loaded (non-strictly) and the Trainer starts without restoring its state.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """  # noqa: E501
    # Seed python/numpy/torch RNGs. Compare with None so that `seed: 0` is honoured.
    if cfg.get("seed") is not None:
        L.seed_everything(cfg.seed, workers=False)

    if cfg.get("deterministic"):
        torch.use_deterministic_algorithms(True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")

        ckpt_path = cfg.get("ckpt_path")
        auto_resume = False

        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
        if resume_ckpt_path is not None:
            ckpt_path = resume_ckpt_path
            auto_resume = True

        if ckpt_path is not None:
            log.info(f"Resuming from checkpoint: {ckpt_path}")

        # resume weights only is disabled for auto-resume
        if cfg.get("resume_weights_only") and auto_resume is False:
            log.info("Resuming weights only!")
            ckpt = torch.load(ckpt_path, map_location=model.device)
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            err = model.load_state_dict(ckpt, strict=False)
            log.info(f"Error loading state dict: {err}")
            # Weights are loaded; don't let fit() restore full state from this file.
            ckpt_path = None

        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    # Copy, in case the Trainer updates its callback_metrics dict in place later.
    train_metrics = dict(trainer.callback_metrics)

    if cfg.get("test"):
        log.info("Starting testing!")
        # checkpoint_callback is None when no ModelCheckpoint is configured.
        checkpoint_callback = trainer.checkpoint_callback
        ckpt_path = (
            checkpoint_callback.best_model_path
            if checkpoint_callback is not None
            else ""
        )
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = cfg.get("ckpt_path")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = dict(trainer.callback_metrics)

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
@hydra.main(
    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run ``train`` on the composed config.

    Whatever ``train(cfg)`` returns is discarded; this always returns ``None``.
    """
    train(cfg)
    return None


if __name__ == "__main__":
    main()
================================================
FILE: fish_speech/utils/__init__.py
================================================
from .braceexpand import braceexpand
from .context import autocast_exclude_mps
from .file import get_latest_checkpoint
from .instantiators import instantiate_callbacks, instantiate_loggers
from .logger import RankedLogger
from .logging_utils import log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .utils import extras, get_metric_value, set_seed, task_wrapper
# Names re-exported by this package (all imported above from its submodules).
__all__ = [
    "enforce_tags",
    "extras",
    "get_metric_value",
    "RankedLogger",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "print_config_tree",
    "task_wrapper",
    "braceexpand",
    "get_latest_checkpoint",
    "autocast_exclude_mps",
    "set_seed",
]
================================================
FILE: fish_speech/utils/braceexpand.py
================================================
"""
Bash-style brace expansion
Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
License: MIT
"""
import re
import string
from itertools import chain, product
from typing import Iterable, Iterator, Optional
__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]


class UnbalancedBracesError(ValueError):
    """Raised when ``{`` and ``}`` in a pattern do not pair up."""

    pass


# Letters allowed in character ranges, in range order: 'A'..'Z' then 'a'..'z'.
alphabet = string.ascii_uppercase + string.ascii_lowercase

# "start..end" or "start..end..step" with optionally negative integers. A minus
# sign on the step is accepted but not captured.
int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
# "a..z" or "a..z..step" over single ASCII letters.
char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
# A backslash plus the character it escapes (used to unescape output strings).
escape_re = re.compile(r"\\(.)")
def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
    """braceexpand(pattern) -> iterator over generated strings

    Returns an iterator over the strings resulting from brace expansion
    of pattern. This function implements Brace Expansion as described in
    bash(1), with the following limitations:

    * A pattern containing unbalanced braces will raise an
      UnbalancedBracesError exception. In bash, unbalanced braces will either
      be partly expanded or ignored.

    * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
      include the characters '[]^_`' between 'Z' and 'a'.

    When escape is True (the default), characters in pattern can be
    prefixed with a backslash to cause them not to be interpreted as
    special characters for brace expansion (such as '{', '}', ',').
    To pass through a literal backslash, double it ('\\\\').

    When escape is False, backslashes in pattern have no special
    meaning and will be preserved in the output.

    Examples:

    >>> from braceexpand import braceexpand

    # Integer range
    >>> list(braceexpand('item{1..3}'))
    ['item1', 'item2', 'item3']

    # Character range
    >>> list(braceexpand('{a..c}'))
    ['a', 'b', 'c']

    # Sequence
    >>> list(braceexpand('index.html{,.backup}'))
    ['index.html', 'index.html.backup']

    # Nested patterns
    >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
    ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']

    # Prefixing an integer with zero causes all numbers to be padded to
    # the same width.
    >>> list(braceexpand('{07..10}'))
    ['07', '08', '09', '10']

    # An optional increment can be specified for ranges.
    >>> list(braceexpand('{a..g..2}'))
    ['a', 'c', 'e', 'g']

    # Ranges can go in both directions.
    >>> list(braceexpand('{4..1}'))
    ['4', '3', '2', '1']

    # Numbers can be negative
    >>> list(braceexpand('{2..-1}'))
    ['2', '1', '0', '-1']

    # Unbalanced braces raise an exception.
    >>> list(braceexpand('{1{2,3}'))
    Traceback (most recent call last):
    ...
    UnbalancedBracesError: Unbalanced braces: '{1{2,3}'

    # By default, the backslash is the escape character.
    >>> list(braceexpand(r'{1\\{2,3}'))
    ['1{2', '3']

    # Setting 'escape' to False disables backslash escaping.
    >>> list(braceexpand(r'\\{1,2}', escape=False))
    ['\\\\1', '\\\\2']

    """
    # The outermost iterable of a generator expression is evaluated
    # immediately, so parse_pattern runs (and can raise UnbalancedBracesError)
    # at call time. Escape sequences are then stripped lazily per result.
    return (
        escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
    )
def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
    """Split ``pattern`` into literal text and top-level ``{...}`` groups.

    Each piece becomes an iterable of alternatives. The result is a
    generator over the Cartesian product of the pieces, each joined into one
    string. A group that ``parse_expression`` rejects (returns ``None``) is
    kept with literal braces around its recursively expanded interior.

    Raises:
        UnbalancedBracesError: If the braces in ``pattern`` do not balance.
    """
    start = 0  # beginning of the current literal run or group
    pos = 0
    bracketdepth = 0

    items: list[Iterable[str]] = []

    while pos < len(pattern):
        if escape and pattern[pos] == "\\":
            # Skip the backslash and the character it escapes.
            pos += 2
            continue
        elif pattern[pos] == "{":
            if bracketdepth == 0 and pos > start:
                # Flush the literal text before a top-level group.
                items.append([pattern[start:pos]])
                start = pos
            bracketdepth += 1
        elif pattern[pos] == "}":
            bracketdepth -= 1
            if bracketdepth == 0:
                # A top-level group just closed: expand its interior.
                expr = pattern[start + 1 : pos]
                item = parse_expression(expr, escape)
                if item is None:  # not a range or sequence
                    items.extend([["{"], parse_pattern(expr, escape), ["}"]])
                else:
                    items.append(item)
                start = pos + 1  # skip the closing brace
        pos += 1

    if bracketdepth != 0:  # unbalanced braces
        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)

    if start < pos:
        # Trailing literal text after the last group.
        items.append([pattern[start:]])

    return ("".join(item) for item in product(*items))
def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
    """Expand the interior of a single ``{...}`` group.

    Tries an integer range, then a character range, then a comma sequence.
    Returns ``None`` when ``expr`` is none of these (for example ``{abc}``),
    and the caller then keeps the braces literally.
    """
    int_match = int_range_re.match(expr)
    if int_match is not None:
        return make_int_range(*int_match.groups())

    char_match = char_range_re.match(expr)
    if char_match is not None:
        return make_char_range(*char_match.groups())

    return parse_sequence(expr, escape)
def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
    """Expand a comma-separated list of alternatives such as ``a,b{1,2},c``.

    Only commas at nesting depth zero separate alternatives. Each
    alternative is expanded with ``parse_pattern`` and the results are
    chained in order. Returns ``None`` when ``seq`` has no top-level comma,
    i.e. it is not a sequence.

    Raises:
        UnbalancedBracesError: If the braces inside ``seq`` do not balance.
    """
    # sequence -> chain(*sequence_items)
    start = 0  # beginning of the current alternative
    pos = 0
    bracketdepth = 0
    items: list[Iterable[str]] = []

    while pos < len(seq):
        if escape and seq[pos] == "\\":
            # Skip the backslash and the character it escapes.
            pos += 2
            continue
        elif seq[pos] == "{":
            bracketdepth += 1
        elif seq[pos] == "}":
            bracketdepth -= 1
        elif seq[pos] == "," and bracketdepth == 0:
            items.append(parse_pattern(seq[start:pos], escape))
            start = pos + 1  # skip the comma
        pos += 1

    if bracketdepth != 0:
        raise UnbalancedBracesError

    if not items:
        return None

    # part after the last comma (may be the empty string)
    items.append(parse_pattern(seq[start:], escape))
    return chain(*items)
def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
    """Yield the decimal strings of the inclusive range ``left``..``right``.

    If either endpoint (other than a bare ``0``/``-0``) has a leading zero,
    every value is zero-padded to the wider endpoint's width. ``incr`` is an
    unsigned step (``None``, empty or ``"0"`` means 1). The direction follows
    the endpoints.
    """
    zero_padded = any(
        s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")
    )
    width = max(len(left), len(right)) if zero_padded else 0
    step = int(incr or 0) or 1

    first, last = int(left), int(right)
    if first < last:
        values = range(first, last + 1, step)
    else:
        values = range(first, last - 1, -step)

    fmt = "%0{}d".format(width)
    return (fmt % value for value in values)
def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
    """Return the letters from ``left`` to ``right`` inclusive as one string.

    Letters come from ``alphabet`` ('A'..'Z' then 'a'..'z'), so mixed-case
    ranges skip the punctuation between 'Z' and 'a'. ``incr`` is an unsigned
    step (``None``, empty or ``"0"`` means 1). Descending ranges are sliced
    backwards.
    """
    step = int(incr or 0) or 1
    first = alphabet.index(left)
    last = alphabet.index(right)
    if first < last:
        return alphabet[first : last + 1 : step]
    # Descending or single letter. If the end letter is 'A' (index 0), then
    # 0 - 1 == -1 would mean "last character", so use an index before the
    # start of the string instead.
    stop = (last or -len(alphabet)) - 1
    return alphabet[first:stop:-step]
if __name__ == "__main__":
    # Run this module's doctests and exit non-zero if any of them fail.
    import doctest
    import sys

    results = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
    if results.failed:
        sys.exit(1)
================================================
FILE: fish_speech/utils/context.py
================================================
from contextlib import nullcontext
import torch
def autocast_exclude_mps(
device_type: str, dtype: torch.dtype
) -> nullcontext | torch.autocast:
return (
nullcontext()
if torch.backends.mps.is_available()
else torch.autocast(device_type, dtype)
)
================================================
FILE: fish_speech/utils/file.py
================================================
import os
from pathlib import Path
from typing import Union
from loguru import logger
from natsort import natsorted
# Lower-case, dot-prefixed filename suffixes treated as audio files.
AUDIO_EXTENSIONS = {
    ".mp3",
    ".wav",
    ".flac",
    ".ogg",
    ".m4a",
    ".wma",
    ".aac",
    ".aiff",
    ".aif",
    ".aifc",
}
# Lower-case, dot-prefixed filename suffixes treated as video files.
VIDEO_EXTENSIONS = {
    ".mp4",
    ".avi",
}
def get_latest_checkpoint(path: Path | str) -> Path | None:
# Find the latest checkpoint
ckpt_dir = Path(path)
if ckpt_dir.exists() is False:
return None
ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
if len(ckpts) == 0:
return None
return ckpts[-1]
def audio_to_bytes(file_path):
    """Return the raw bytes of ``file_path``, or ``None`` if it is empty or missing."""
    if file_path and Path(file_path).exists():
        return Path(file_path).read_bytes()
    return None
def read_ref_text(ref_text):
    """Return the contents of ``ref_text`` if it names a file, else ``ref_text``.

    ``ref_text`` may be either a path to a UTF-8 transcript file or the
    transcript itself. Text that cannot be a valid path is returned
    unchanged. That covers strings longer than the OS name limit, which make
    ``Path.exists()`` raise ``OSError`` (ENAMETOOLONG), and strings with NUL
    bytes, which raise ``ValueError``.
    """
    path = Path(ref_text)
    try:
        is_file = path.is_file()
    except (OSError, ValueError):
        return ref_text
    if is_file:
        with path.open("r", encoding="utf-8") as file:
            return file.read()
    return ref_text
def list_files(
path: Union[Path, str],
extensions: set[str] = set(),
recursive: bool = False,
sort: bool = True,
) -> list[Path]:
"""List files in a directory.
Args:
path (Path): Path to the directory.
extensions (set, optional): Extensions to filter. Defaults to None.
recursive (bool, optional): Whether to search recursively. Defaults to False.
sort (bool, optional): Whether to sort the files. Defaults to True.
Returns:
list: List of files.
"""
if isinstance(path, str):
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"Directory {path} does not exist.")
files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
if sort:
files = natsorted(files)
return files
def load_filelist(path: Path | str) -> list[tuple[Path, str, list[str], str]]:
    """
    Load a Bert-VITS2 style filelist.

    Each line is ``filename|speaker|language|text``. The text may itself
    contain ``|``, since only the first three separators are split. The
    language ``ja`` is normalised to ``jp`` and then expanded through
    ``LANGUAGE_TO_LANGUAGES`` (e.g. ``zh`` -> ``["zh", "en"]``). Malformed
    lines, repeated filenames and missing files are skipped with a warning.

    Returns:
        ``(path, speaker, languages, text)`` tuples in file order.

    Raises:
        AssertionError: For a language other than zh/jp/ja/en.
    """
    seen = set()
    results = []
    count_duplicated, count_not_found = 0, 0

    LANGUAGE_TO_LANGUAGES = {
        "zh": ["zh", "en"],
        "jp": ["jp", "en"],
        "en": ["en"],
    }

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            splits = line.strip().split("|", maxsplit=3)
            if len(splits) != 4:
                logger.warning(f"Invalid line: {line}")
                continue

            filename, speaker, language, text = splits
            file = Path(filename)
            language = language.strip().lower()

            if language == "ja":
                language = "jp"

            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
            languages = LANGUAGE_TO_LANGUAGES[language]

            if file in seen:
                logger.warning(f"Duplicated file: {file}")
                count_duplicated += 1
                continue

            if not file.exists():
                logger.warning(f"File not found: {file}")
                count_not_found += 1
                continue

            # Record accepted files so later repeats are reported as duplicates
            # (the set was previously never filled).
            seen.add(file)
            results.append((file, speaker, languages, text))

    if count_duplicated > 0:
        logger.warning(f"Total duplicated files: {count_duplicated}")

    if count_not_found > 0:
        logger.warning(f"Total files not found: {count_not_found}")

    return results
================================================
FILE: fish_speech/utils/instantiators.py
================================================
from typing import List
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger
from .logger import RankedLogger
# Module logger; with rank_zero_only=True it emits only on the rank-zero process.
log = RankedLogger(__name__, rank_zero_only=True)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiate every callback described in ``callbacks_cfg``.

    Entries that are not ``DictConfig`` nodes with a ``_target_`` key are
    skipped silently. An empty or missing config yields ``[]`` and a warning.

    Raises:
        TypeError: If ``callbacks_cfg`` is non-empty but not a ``DictConfig``.
    """
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    built: List[Callback] = []
    for _, cb_conf in callbacks_cfg.items():
        if not (isinstance(cb_conf, DictConfig) and "_target_" in cb_conf):
            continue
        log.info(f"Instantiating callback <{cb_conf._target_}>")
        built.append(hydra.utils.instantiate(cb_conf))
    return built
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiate every experiment logger described in ``logger_cfg``.

    Entries that are not ``DictConfig`` nodes with a ``_target_`` key are
    skipped silently. An empty or missing config yields ``[]`` and a warning.

    Raises:
        TypeError: If ``logger_cfg`` is non-empty but not a ``DictConfig``.
    """
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    built: List[Logger] = []
    for _, lg_conf in logger_cfg.items():
        if not (isinstance(lg_conf, DictConfig) and "_target_" in lg_conf):
            continue
        log.info(f"Instantiating logger <{lg_conf._target_}>")
        built.append(hydra.utils.instantiate(lg_conf))
    return built
================================================
FILE: fish_speech/utils/logger.py
================================================
import logging
from typing import Mapping, Optional
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = True,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `True`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        logger = logging.getLogger(name)
        super().__init__(logger=logger, extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(
        self, level: int, msg: str, *args, rank: Optional[int] = None, **kwargs
    ) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        ``rank`` is keyword-only. ``LoggerAdapter.info(msg, *args)`` forwards
        its format arguments positionally, and they used to be captured as
        ``rank``: ``log.info("x=%s", 1)`` lost the argument and was routed to
        rank 1.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param args: Additional args to pass to the underlying logging function.
        :param rank: The rank to log at (keyword-only; only used when ``rank_zero_only`` is False).
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        :raises RuntimeError: If ``rank_zero_only.rank`` has not been set.
        """
        if not self.isEnabledFor(level):
            return

        msg, kwargs = self.process(msg, kwargs)
        current_rank = getattr(rank_zero_only, "rank", None)
        if current_rank is None:
            raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
        msg = rank_prefixed_message(msg, current_rank)

        if self.rank_zero_only:
            should_log = current_rank == 0
        else:
            should_log = rank is None or current_rank == rank

        if should_log:
            self.logger.log(level, msg, *args, **kwargs)
================================================
FILE: fish_speech/utils/logging_utils.py
================================================
from lightning.pytorch.utilities import rank_zero_only
from fish_speech.utils import logger as log
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters (total / trainable / non-trainable)

    ``object_dict`` must hold ``"cfg"``, ``"model"`` and ``"trainer"``. The
    hyper-parameters are sent to every logger in ``trainer.loggers``. When the
    trainer has no logger, a warning is emitted and nothing is logged.
    """
    import logging

    # The module-level ``log`` is bound by ``from fish_speech.utils import
    # logger as log`` to the ``fish_speech.utils.logger`` *module*. That
    # module has no ``warning`` attribute, so a standard logger is used here.
    std_log = logging.getLogger(__name__)

    hparams = {}

    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        std_log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters (single pass over the parameters)
    total = trainable = 0
    for p in model.parameters():
        count = p.numel()
        total += count
        if p.requires_grad:
            trainable += count
    hparams["model/params/total"] = total
    hparams["model/params/trainable"] = trainable
    hparams["model/params/non_trainable"] = total - trainable

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
================================================
FILE: fish_speech/utils/rich_utils.py
================================================
from pathlib import Path
from typing import Sequence
import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt
from fish_speech.utils import logger as log
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    Fields listed in ``print_order`` come first, in that order; missing ones
    are reported with a warning. All remaining top-level fields follow.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Determines in what order config components are printed.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
        save_to_file (bool, optional): Whether to export config to the hydra output folder.
    """  # noqa: E501
    import logging

    # The module-level ``log`` is the ``fish_speech.utils.logger`` *module*
    # (see its import) and has no ``warning``, so use a standard logger.
    std_log = logging.getLogger(__name__)

    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        if field in cfg:
            queue.append(field)
        else:
            std_log.warning(
                f"Field '{field}' not found in config. "
                + f"Skipping '{field}' config printing..."
            )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file (UTF-8: the tree contains non-ASCII guide glyphs)
    if save_to_file:
        with open(
            Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8"
        ) as file:
            rich.print(tree, file=file)
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    Raises ``ValueError`` instead of prompting when running as a Hydra
    multirun job (``id`` present in the job config). Optionally writes the
    tags to ``<output_dir>/tags.log``.
    """  # noqa: E501
    import logging

    # The module-level ``log`` is the ``fish_speech.utils.logger`` *module*
    # (see its import) and has no ``warning``/``info``; use a standard logger.
    std_log = logging.getLogger(__name__)

    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        std_log.warning("No tags provided in config. Prompting user to input tags...")
        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip before filtering so inputs like "a, ,b" do not produce "" tags.
        tags = [t.strip() for t in tags.split(",") if t.strip()]

        with open_dict(cfg):
            cfg.tags = tags

        std_log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
            rich.print(cfg.tags, file=file)
================================================
FILE: fish_speech/utils/schema.py
================================================
import base64
import os
import queue
from dataclasses import dataclass
from typing import Literal
import torch
from pydantic import BaseModel, Field, conint, model_validator
from pydantic.functional_validators import SkipValidation
from typing_extensions import Annotated
from fish_speech.content_sequence import TextPart, VQPart
class ServeVQPart(BaseModel):
type: Literal["vq"] = "vq"
codes: SkipValidation[list[list[int]]]
class ServeTextPart(BaseModel):
type: Literal["text"] = "text"
text: str
class ServeAudioPart(BaseModel):
    # Message part carrying raw audio bytes; `type` is the fixed tag "audio".
    type: Literal["audio"] = "audio"
    audio: bytes
class ServeRequest(BaseModel):
    # Low-level generation request built around a raw content-sequence dict,
    # with sampling settings. NOTE(review): exact consumer not visible here.
    # Raw content sequence dict that we can use with ContentSequence(**content)
    content: dict
    # Upper bound on generated tokens.
    max_new_tokens: int = 600
    # Sampling parameters.
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    streaming: bool = False
    num_samples: int = 1
    # NOTE(review): semantics of this threshold are not visible in this file.
    early_stop_threshold: float = 1.0
class ServeVQGANEncodeRequest(BaseModel):
    # Request to encode one or more encoded audio files into codec tokens.
    # The audio here should be in wav, mp3, etc
    audios: list[bytes]
class ServeVQGANEncodeResponse(BaseModel):
    # One nested token list per input audio (contents are not validated).
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeRequest(BaseModel):
    # Codec tokens to decode back to audio; same nesting as the encode response.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeResponse(BaseModel):
    # Decoded audio, one byte string per request item.
    # The audio here should be in PCM float16 format
    audios: list[bytes]
class ServeReferenceAudio(BaseModel):
    # Reference clip (audio + its transcript) used for in-context learning.
    # `audio` may arrive as raw bytes or, from text-based payloads, as a
    # base64 string; see `decode_audio`.
    audio: bytes
    text: str

    @model_validator(mode="before")
    def decode_audio(cls, values):
        """Best-effort base64 decoding of a string ``audio`` value.

        Runs before field validation. In ``mode="before"`` pydantic may pass
        something other than a dict (e.g. an already constructed
        ``ServeReferenceAudio`` placed inside ``ServeTTSRequest.references``),
        which has no ``.get``; such input is returned unchanged.

        Strings of 255 characters or fewer, and strings that are not valid
        base64, are left as-is so pydantic's normal ``bytes`` coercion (or the
        server) handles them.
        """
        if not isinstance(values, dict):
            return values

        audio = values.get("audio")
        if isinstance(audio, str) and len(audio) > 255:  # likely a base64 payload
            try:
                decoded = base64.b64decode(audio)
            except ValueError:
                # binascii.Error (invalid base64) and non-ASCII input both
                # subclass ValueError. Not base64: leave it for the server.
                return values
            # Build a new dict rather than mutating the caller's input.
            values = {**values, "audio": decoded}
        return values

    def __repr__(self) -> str:
        # Avoid dumping the (potentially large) audio payload.
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
class ServeTTSRequest(BaseModel):
    # Request body for text-to-speech synthesis.
    # Text to synthesize.
    text: str
    # NOTE(review): pydantic v2 does not appear to apply constraints from a
    # `conint(...)` type used as `Annotated` *metadata* (it is neither a
    # FieldInfo nor an annotated-types constraint), so the 100..1000 range and
    # strict mode may be silently ignored. The sampling fields below use
    # `Field(ge=..., le=..., strict=True)` instead — confirm before relying on it.
    chunk_length: Annotated[int, conint(ge=100, le=1000, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3", "opus"] = "wav"
    # Latency mode (used by api.fish.audio; "normal" or "balanced")
    latency: Literal["normal", "balanced"] = "normal"
    # References audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    # Optional sampling seed (the bundled client documents None as randomized).
    seed: int | None = None
    # "on" to cache encoded reference codes in memory (per the client's help text).
    use_memory_cache: Literal["on", "off"] = "off"
    # Normalize text for en & zh, this increase stability for numbers
    normalize: bool = True
    # not usually used below
    streaming: bool = False
    max_new_tokens: int = 1024
    # Sampling parameters, range-checked and strict (no type coercion).
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.1
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8

    class Config:
        # Allow arbitrary types for pytorch related types
        # NOTE(review): class-based `Config` is pydantic v1 style; v2 (pinned
        # 2.9.2) deprecates it in favour of `model_config = ConfigDict(...)`.
        arbitrary_types_allowed = True
class AddReferenceRequest(BaseModel):
    # Request to register a named reference voice.
    # 1-255 chars of letters, digits, '-', '_' or space.
    id: str = Field(..., min_length=1, max_length=255, pattern=r"^[a-zA-Z0-9\-_ ]+$")
    audio: bytes
    # Transcript of `audio`; must be non-empty.
    text: str = Field(..., min_length=1)
class AddReferenceResponse(BaseModel):
    # Result of adding a reference voice.
    success: bool
    message: str
    reference_id: str
class ListReferencesResponse(BaseModel):
    # IDs of the registered reference voices.
    success: bool
    reference_ids: list[str]
    message: str = "Success"
class DeleteReferenceResponse(BaseModel):
    # Result of deleting a reference voice.
    success: bool
    message: str
    reference_id: str
class UpdateReferenceResponse(BaseModel):
    # Result of renaming a reference voice (old ID -> new ID).
    success: bool
    message: str
    old_reference_id: str
    new_reference_id: str
================================================
FILE: fish_speech/utils/spectrogram.py
================================================
import torch
import torchaudio.functional as F
from torch import Tensor, nn
from torchaudio.transforms import MelScale
class LinearSpectrogram(nn.Module):
    """STFT front-end with manual reflect padding.

    Before ``torch.stft`` the signal is reflect-padded by
    ``win_length - hop_length`` samples in total, split between the two ends
    (the right side gets the extra sample when the total is odd).

    Args:
        n_fft: FFT size.
        win_length: Hann window length.
        hop_length: Hop between successive frames.
        center: Forwarded to ``torch.stft``.
        mode: ``"pow2_sqrt"`` returns ``sqrt(re**2 + im**2 + 1e-6)``; any other
            value returns the ``torch.view_as_real`` tensor, whose last
            dimension holds (real, imag).
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()
        self.n_fft, self.win_length, self.hop_length = n_fft, win_length, hop_length
        self.center = center
        self.mode = mode
        self.return_complex = True
        # Non-persistent: recreated at construction, not saved in state_dict.
        self.register_buffer("window", torch.hann_window(win_length), persistent=False)

    def forward(self, y: Tensor) -> Tensor:
        # For 3-D input squeeze dim 1 (a no-op unless that dim has size 1).
        if y.ndim == 3:
            y = y.squeeze(1)

        overlap = self.win_length - self.hop_length
        pad_left, pad_right = overlap // 2, (overlap + 1) // 2
        # Reflect-pad over a temporary singleton dim at position 1.
        padded = torch.nn.functional.pad(
            y.unsqueeze(1), (pad_left, pad_right), mode="reflect"
        )
        y = padded.squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=self.return_complex,
        )
        if self.return_complex:
            spec = torch.view_as_real(spec)

        if self.mode != "pow2_sqrt":
            return spec
        # Magnitude with a small epsilon to keep sqrt differentiable at 0.
        return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram on top of ``LinearSpectrogram``.

    Uses a Slaney-normalised, Slaney-scale mel filterbank from
    ``torchaudio.functional.melscale_fbanks``. Input audio is resampled to
    ``sample_rate`` when a different ``sample_rate`` is passed to ``forward``.

    Args:
        sample_rate: Target sample rate of the filterbank.
        n_fft, win_length, hop_length, center: Passed to LinearSpectrogram.
        n_mels: Number of mel bins.
        f_min: Lowest filterbank frequency.
        f_max: Highest filterbank frequency; a falsy value (None or 0) means
            ``float(sample_rate // 2)``.
    """

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max if f_max else float(sample_rate // 2)

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)

        mel_fb = F.melscale_fbanks(
            n_freqs=self.n_fft // 2 + 1,
            f_min=self.f_min,
            f_max=self.f_max,
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.register_buffer("fb", mel_fb, persistent=False)

    def compress(self, x: Tensor) -> Tensor:
        # Natural log with a floor of 1e-5 to avoid log(0).
        return x.clamp(min=1e-5).log()

    def decompress(self, x: Tensor) -> Tensor:
        # Inverse of `compress` (ignoring the clamp).
        return x.exp()

    def apply_mel_scale(self, x: Tensor) -> Tensor:
        # Project the frequency axis (dim -2) onto the mel filterbank.
        return (x.transpose(-1, -2) @ self.fb).transpose(-1, -2)

    def forward(
        self, x: Tensor, return_linear: bool = False, sample_rate: int = None
    ) -> Tensor:
        # Returns the log-mel tensor, or (log-mel, log-linear) when return_linear.
        must_resample = sample_rate is not None and sample_rate != self.sample_rate
        if must_resample:
            x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)

        linear = self.spectrogram(x)
        mel = self.compress(self.apply_mel_scale(linear))
        return (mel, self.compress(linear)) if return_linear else mel
================================================
FILE: fish_speech/utils/utils.py
================================================
import random
import warnings
from importlib.util import find_spec
from typing import Callable
import numpy as np
import torch
from omegaconf import DictConfig
from .logger import RankedLogger
from .rich_utils import enforce_tags, print_config_tree
log = RankedLogger(__name__, rank_zero_only=True)
def extras(cfg: DictConfig) -> None:
    """Run optional pre-task utilities selected by flags in ``cfg.extras``.

    Supported flags:
    - ``ignore_warnings``: silence every Python warning;
    - ``enforce_tags``: prompt for run tags if none are configured (saved to file);
    - ``print_config``: pretty-print the resolved config tree (saved to file).

    Logs a warning and returns immediately when ``cfg.extras`` is missing or empty.
    """
    if not cfg.get("extras"):
        log.warning("Extras config not found! ")
        return

    flags = cfg.extras

    if flags.get("ignore_warnings"):
        log.info("Disabling python warnings! ")
        warnings.filterwarnings("ignore")

    if flags.get("enforce_tags"):
        log.info("Enforcing tags! ")
        enforce_tags(cfg, save_to_file=True)

    if flags.get("print_config"):
        log.info("Printing config tree with Rich! ")
        print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
    - save the exception to a `.log` file
    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
    - etc. (adjust depending on your needs)

    The wrapped function keeps ``task_func``'s name, docstring and
    ``__wrapped__`` (via ``functools.wraps``). It calls ``task_func(cfg=cfg)``,
    logs and re-raises any exception, and in all cases logs
    ``cfg.paths.run_dir`` and finishes an active wandb run if wandb is installed.

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[dict, dict]:
        ...
        return metric_dict, object_dict
    ```
    """  # noqa: E501
    from functools import wraps

    @wraps(task_func)
    def wrap(cfg: DictConfig):
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or
            # cause out-of-memory errors so when using hparam search
            # plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise  # bare raise re-raises the active exception unchanged

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.run_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
def get_metric_value(metric_dict: dict, metric_name: str) -> float:
"""Safely retrieves value of the metric logged in LightningModule."""
if not metric_name:
log.info("Metric name is None! Skipping metric value retrieval...")
return None
if metric_name not in metric_dict:
raise Exception(
f"Metric value not found! \n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
)
metric_value = metric_dict[metric_name].item()
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
return metric_value
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Negative seeds are mirrored to their absolute value and anything above
    ``2**31`` is clamped to ``2**31``. CUDA generators are seeded only when CUDA
    is available; cuDNN flags are set only when cuDNN is available.
    """
    seed = min(abs(seed), 1 << 31)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
================================================
FILE: inference.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fish Speech"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### For Windows User / win用户"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "bat"
}
},
"outputs": [],
"source": [
"!chcp 65001"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### For Linux User / Linux 用户"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import locale\n",
"locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For Chinese users, you probably want to use mirror to accelerate downloading\n",
"# !set HF_ENDPOINT=https://hf-mirror.com\n",
"# !export HF_ENDPOINT=https://hf-mirror.com \n",
"\n",
"!hf download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## WebUI Inference\n",
"\n",
"> You can use --compile to fuse CUDA kernels for faster inference (10x)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"!python tools/run_webui.py \\\n",
" --llama-checkpoint-path checkpoints/openaudio-s1-mini \\\n",
" --decoder-checkpoint-path checkpoints/openaudio-s1-mini/codec.pth \\\n",
" # --compile"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Break-down CLI Inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Encode reference audio: / 从语音生成 prompt: \n",
"\n",
"You should get a `fake.npy` file.\n",
"\n",
"你应该能得到一个 `fake.npy` 文件."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"## Enter the path to the audio file here\n",
"src_audio = r\"D:\\PythonProject\\vo_hutao_draw_appear.wav\"\n",
"\n",
"!python fish_speech/models/dac/inference.py \\\n",
" -i {src_audio} \\\n",
" --checkpoint-path \"checkpoints/openaudio-s1-mini/codec.pth\"\n",
"\n",
"from IPython.display import Audio, display\n",
"audio = Audio(filename=\"fake.wav\")\n",
"display(audio)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Generate semantic tokens from text: / 从文本生成语义 token:\n",
"\n",
"> This command will create a codes_N file in the working directory, where N is an integer starting from 0.\n",
"\n",
"> You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~300 tokens/second).\n",
"\n",
"> 该命令会在工作目录下创建 codes_N 文件, 其中 N 是从 0 开始的整数.\n",
"\n",
"> 您可以使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 tokens/秒 -> ~300 tokens/秒)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"!python fish_speech/models/text2semantic/inference.py \\\n",
" --text \"hello world\" \\\n",
" --prompt-text \"The text corresponding to reference audio\" \\\n",
" --prompt-tokens \"fake.npy\" \\\n",
" --checkpoint-path \"checkpoints/openaudio-s1-mini\" \\\n",
" --num-samples 2\n",
" # --compile"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Generate speech from semantic tokens: / 从语义 token 生成人声:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"!python fish_speech/models/dac/inference.py \\\n",
" -i \"codes_0.npy\" \\\n",
" --checkpoint-path \"checkpoints/openaudio-s1-mini/codec.pth\"\n",
"\n",
"from IPython.display import Audio, display\n",
"audio = Audio(filename=\"fake.wav\")\n",
"display(audio)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: mkdocs.yml
================================================
site_name: Fish Audio
site_description: Targeting SOTA TTS solutions.
site_url: https://speech.fish.audio
# Repository
repo_name: fishaudio/fish-speech
repo_url: https://github.com/fishaudio/fish-speech
edit_uri: blob/main/docs
# Copyright
copyright: Copyright © 2023-2025 by Fish Audio
theme:
name: material
favicon: assets/logo.svg
language: en
features:
- content.action.edit
- content.action.view
- navigation.tracking
- navigation.footer
# - navigation.tabs
- search
- search.suggest
- search.highlight
- search.share
- content.code.copy
logo: assets/logo.svg
palette:
# Palette toggle for automatic mode
- media: "(prefers-color-scheme)"
toggle:
icon: material/brightness-auto
name: Switch to light mode
# Palette toggle for light mode
- media: "(prefers-color-scheme: light)"
scheme: default
toggle:
icon: material/brightness-7
name: Switch to dark mode
primary: black
font:
code: Roboto Mono
# Palette toggle for dark mode
- media: "(prefers-color-scheme: dark)"
scheme: slate
toggle:
icon: material/brightness-4
name: Switch to light mode
primary: black
font:
code: Roboto Mono
nav:
- Introduction: en/index.md
- Installation: en/install.md
- Finetune: en/finetune.md
- Inference: en/inference.md
- Server: en/server.md
- Samples: en/samples.md
# Plugins
plugins:
- search:
separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])'
lang:
- en
- zh
- ja
- pt
- ko
- ar
- i18n:
docs_structure: folder
languages:
- locale: en
name: English
default: true
build: true
- locale: zh
name: 简体中文
build: true
nav:
- 介绍: zh/index.md
- 安装: zh/install.md
- 微调: zh/finetune.md
- 推理: zh/inference.md
- 示例: zh/samples.md
- locale: ja
name: 日本語
build: true
nav:
- はじめに: ja/index.md
- インストール: ja/install.md
- ファインチューニング: ja/finetune.md
- 推論: ja/inference.md
- サンプル: ja/samples.md
- locale: pt
name: Português (Brasil)
build: true
nav:
- Introdução: pt/index.md
- Instalação: pt/install.md
- Ajuste Fino: pt/finetune.md
- Inferência: pt/inference.md
- Amostras: pt/samples.md
- locale: ko
name: 한국어
build: true
nav:
- 소개: ko/index.md
- 설치: ko/install.md
- 파인튜닝: ko/finetune.md
- 추론: ko/inference.md
- 샘플: ko/samples.md
- locale: ar
name: العربية
build: true
nav:
- مقدمة: ar/index.md
- التثبيت: ar/install.md
- الضبط الدقيق: ar/finetune.md
- الاستنتاج: ar/inference.md
- العينات: ar/samples.md
markdown_extensions:
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
- admonition
- pymdownx.details
- pymdownx.superfences
- attr_list
- md_in_html
- pymdownx.superfences
extra_css:
- stylesheets/extra.css
extra:
social:
- icon: fontawesome/brands/discord
link: https://discord.gg/Es5qTB9BcN
- icon: fontawesome/brands/docker
link: https://hub.docker.com/r/fishaudio/fish-speech
- icon: fontawesome/brands/qq
link: http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=jCKlUP7QgSm9kh95UlBoYv6s1I-Apl1M&authKey=xI5ttVAp3do68IpEYEalwXSYZFdfxZSkah%2BctF5FIMyN2NqAa003vFtLqJyAVRfF&noverify=0&group_code=593946093
homepage: https://speech.fish.audio
================================================
FILE: pyproject.toml
================================================
[project]
name = "fish-speech"
version = "2.0.0"
authors = [
{name = "Fish Audio", email = "oss@fish.audio"},
]
description = "Fish Speech"
readme = "README.md"
requires-python = ">=3.10"
keywords = ["TTS", "Speech"]
license = {text = "Fish Audio Research License"}
classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = [
"numpy",
"torch==2.8.0",
"torchaudio==2.8.0",
"transformers<=4.57.3",
"datasets==2.18.0",
"lightning>=2.1.0",
"hydra-core>=1.3.2",
"tensorboard>=2.14.1",
"natsort>=8.4.0",
"einops>=0.7.0",
"librosa>=0.10.1",
"rich>=13.5.3",
"gradio>5.0.0",
"wandb>=0.15.11",
"grpcio>=1.58.0",
"kui>=1.6.0",
"uvicorn>=0.30.0",
"loguru>=0.6.0",
"loralib>=0.1.2",
"pyrootutils>=1.0.4",
"resampy>=0.4.3",
"einx[torch]==0.2.2",
"zstandard>=0.22.0",
"pydub",
"pyaudio",
"modelscope==1.17.1",
"opencc-python-reimplemented==0.1.7",
"silero-vad",
"ormsgpack",
"tiktoken>=0.8.0",
"pydantic==2.9.2",
"cachetools",
"descript-audio-codec",
"descript-audiotools",
"safetensors"
]
[project.optional-dependencies]
stable = [
"torch==2.8.0",
"torchaudio",
]
cpu = [
"torch==2.8.0",
"torchaudio",
]
cu126 = [
"torch==2.8.0",
"torchaudio",
]
cu128 = [
"torch==2.8.0",
"torchaudio",
]
cu129 = [
"torch==2.8.0",
"torchaudio",
]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "cu126" },
{ extra = "cu128" },
{ extra = "cu129" },
],
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu126", extra = "cu126" },
{ index = "pytorch-cu128", extra = "cu128" },
{ index = "pytorch-cu129", extra = "cu129" },
]
torchaudio = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu126", extra = "cu126" },
{ index = "pytorch-cu128", extra = "cu128" },
{ index = "pytorch-cu129", extra = "cu129" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu129"
url = "https://download.pytorch.org/whl/cu129"
explicit = true
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
packages = ["fish_speech", "tools"]
[tool.setuptools_scm]
================================================
FILE: pyrightconfig.json
================================================
{
"exclude": [
"data",
"filelists"
]
}
================================================
FILE: tools/api_client.py
================================================
import argparse
import base64
import time
import wave
import ormsgpack
import pyaudio
import requests
from pydub import AudioSegment
from pydub.playback import play
from fish_speech.utils.file import audio_to_bytes, read_ref_text
from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
def parse_args():
    """Build the TTS client's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the server URL, input text, reference options,
        output/playback settings, sampling parameters and API key.
    """
    parser = argparse.ArgumentParser(
        description="Send a WAV file and text to a server and receive synthesized audio.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    add = parser.add_argument

    # Endpoint and text.
    add("--url", "-u", type=str, default="http://127.0.0.1:8080/v1/tts",
        help="URL of the server")
    add("--text", "-t", type=str, required=True, help="Text to be synthesized")

    # Voice references: either a stored ID or explicit audio/text pairs.
    add("--reference_id", "-id", type=str, default=None,
        help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)")
    add("--reference_audio", "-ra", type=str, nargs="+", default=None,
        help="Path to the audio file")
    add("--reference_text", "-rt", type=str, nargs="+", default=None,
        help="Reference text for voice synthesis")

    # Output handling.
    add("--output", "-o", type=str, default="generated_audio",
        help="Output audio file name")
    add("--play", action=argparse.BooleanOptionalAction, default=True,
        help="Whether to play audio after receiving data")
    add("--format", type=str, choices=["wav", "pcm", "mp3", "opus"], default="wav")
    add("--latency", type=str, default="normal", choices=["normal", "balanced"],
        help="Used in api.fish.audio/v1/tts")

    # Generation parameters.
    add("--max_new_tokens", type=int, default=1024,
        help="Maximum new tokens to generate. \n0 means no limit.")
    add("--chunk_length", type=int, default=300, help="Chunk length for synthesis")
    add("--top_p", type=float, default=0.8, help="Top-p sampling for synthesis")
    add("--repetition_penalty", type=float, default=1.1,
        help="Repetition penalty for synthesis")
    add("--temperature", type=float, default=0.8, help="Temperature for sampling")

    # Streaming playback settings.
    add("--streaming", action="store_true", help="Enable streaming response")
    add("--channels", type=int, default=1, help="Number of audio channels")
    add("--rate", type=int, default=44100, help="Sample rate for audio")

    # Miscellaneous.
    add("--use_memory_cache", type=str, default="off", choices=["on", "off"],
        help="Cache encoded references codes in memory.\n")
    add("--seed", type=int, default=None,
        help="`None` means randomized inference, otherwise deterministic.\nIt can't be used for fixing a timbre.")
    add("--api_key", type=str, default="YOUR_API_KEY", help="API key for authentication")

    return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()

    idstr: str | None = args.reference_id
    # priority: ref_id > [{text, audio},...]
    # With a reference_id no local reference audio/text is sent (resolved in api.py).
    if idstr is None:
        ref_audios = args.reference_audio
        ref_texts = args.reference_text

        if ref_audios is None:
            byte_audios = []
        else:
            byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
        if ref_texts is None:
            ref_texts = []
        else:
            ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]

        # Texts and audios are paired positionally below and zip() silently
        # drops unmatched entries (e.g. audio given without --reference_text);
        # make such a mismatch visible.
        if len(byte_audios) != len(ref_texts):
            print(
                f"Warning: {len(byte_audios)} reference audio(s) but "
                f"{len(ref_texts)} reference text(s) were given; only "
                f"{min(len(byte_audios), len(ref_texts))} pair(s) will be sent."
            )
    else:
        byte_audios = []
        ref_texts = []

    data = {
        "text": args.text,
        "references": [
            ServeReferenceAudio(
                audio=ref_audio if ref_audio is not None else b"", text=ref_text
            )
            for ref_text, ref_audio in zip(ref_texts, byte_audios)
        ],
        "reference_id": idstr,
        "format": args.format,
        "latency": args.latency,
        "max_new_tokens": args.max_new_tokens,
        "chunk_length": args.chunk_length,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
        "temperature": args.temperature,
        "streaming": args.streaming,
        "use_memory_cache": args.use_memory_cache,
        "seed": args.seed,
    }

    pydantic_data = ServeTTSRequest(**data)

    print("Sending request")
    start_time = time.time()
    response = requests.post(
        args.url,
        params={"format": "msgpack"},
        data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
        stream=args.streaming,
        headers={
            "authorization": f"Bearer {args.api_key}",
            "content-type": "application/msgpack",
        },
    )
    end_time = time.time()
    print(f"Request took {end_time - start_time} seconds")

    if response.status_code == 200:
        if args.streaming:
            p = pyaudio.PyAudio()
            audio_format = pyaudio.paInt16  # Assuming 16-bit PCM format
            stream = p.open(
                format=audio_format, channels=args.channels, rate=args.rate, output=True
            )

            wf = wave.open(f"{args.output}.wav", "wb")
            wf.setnchannels(args.channels)
            wf.setsampwidth(p.get_sample_size(audio_format))
            wf.setframerate(args.rate)

            stream_stopped_flag = False

            try:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        stream.write(chunk)
                        wf.writeframesraw(chunk)
                    else:
                        if not stream_stopped_flag:
                            stream.stop_stream()
                            stream_stopped_flag = True
            finally:
                stream.close()
                p.terminate()
                wf.close()
        else:
            audio_content = response.content
            audio_path = f"{args.output}.{args.format}"
            with open(audio_path, "wb") as audio_file:
                audio_file.write(audio_content)

            # Decode only when playback is requested.
            if args.play:
                if args.format == "pcm":
                    # Headerless PCM: pydub needs the sample layout explicitly
                    # (it raises KeyError otherwise). Assumes 16-bit samples, as
                    # the streaming branch does — TODO confirm with the server.
                    audio = AudioSegment.from_file(
                        audio_path,
                        format="pcm",
                        sample_width=2,
                        frame_rate=args.rate,
                        channels=args.channels,
                    )
                else:
                    audio = AudioSegment.from_file(audio_path, format=args.format)
                play(audio)
            print(f"Audio has been saved to '{audio_path}'.")
    else:
        print(f"Request failed with status code {response.status_code}")
        try:
            print(response.json())
        except ValueError:
            # Error bodies are not always JSON (e.g. proxy or HTML error pages).
            print(response.text)
================================================
FILE: tools/api_server.py
================================================
import re
from threading import Lock
import pyrootutils
import uvicorn
from kui.asgi import (
Depends,
FactoryClass,
HTTPException,
HttpRoute,
Kui,
OpenAPI,
Routes,
)
from kui.cors import CORSConfig
from kui.openapi.specification import Info
from kui.security import bearer_auth
from loguru import logger
from typing_extensions import Annotated
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import routes
class API(ExceptionHandler):
    """Kui application wrapper for the Fish Speech HTTP API.

    Parses the command-line arguments, wraps every route with an optional
    bearer-token check, mounts the OpenAPI routes, and registers a startup hook
    that constructs the ModelManager.
    """

    def __init__(self):
        self.args = parse_args()

        # Middleware factory applied to every route. With an API key configured,
        # the returned handler compares the bearer token (obtained through kui's
        # `bearer_auth` dependency) to `self.args.api_key` and raises HTTP 401 on
        # mismatch; without a key, requests pass straight through.
        def api_auth(endpoint):
            async def verify(token: Annotated[str, Depends(bearer_auth)]):
                if token != self.args.api_key:
                    raise HTTPException(401, None, "Invalid token")
                return await endpoint()

            async def passthrough():
                return await endpoint()

            if self.args.api_key is not None:
                return verify
            else:
                return passthrough

        self.routes = Routes(
            routes,  # keep existing routes
            http_middlewares=[api_auth],  # apply api_auth middleware
        )

        # OpenAPI configuration
        # NOTE(review): version is hard-coded; pyproject.toml declares 2.0.0.
        self.openapi = OpenAPI(
            Info(
                {
                    "title": "Fish Speech API",
                    "version": "1.5.0",
                }
            ),
        ).routes

        # Initialize the app
        self.app = Kui(
            routes=self.routes + self.openapi[1:],  # Remove the default route
            exception_handlers={
                HTTPException: self.http_exception_handler,
                Exception: self.other_exception_handler,
            },
            factory_class=FactoryClass(http=MsgPackRequest),
            cors_config=CORSConfig(),
        )

        # Add the state variables
        # NOTE(review): a threading.Lock shared through app.state, presumably used
        # by the views to serialize model access — confirm in tools/server/views.
        self.app.state.lock = Lock()
        self.app.state.device = self.args.device
        self.app.state.max_text_length = self.args.max_text_length

        # Associate the app with the model manager
        self.app.on_startup(self.initialize_app)

    async def initialize_app(self, app: Kui):
        """Startup hook: build the ModelManager from the CLI args and store it on ``app.state``."""
        # Make the ModelManager available to the views
        app.state.model_manager = ModelManager(
            mode=self.args.mode,
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,
        )

        logger.info(f"Startup done, listening server at http://{self.args.listen}")
# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.

# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.

if __name__ == "__main__":
    api = API()

    # IPv6 address format is [xxxx:xxxx::xxxx]:port
    match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen)
    if match:
        host, port = match.groups()  # IPv6
    else:
        # IPv4 / hostname: split on the last colon only, and reject values
        # without a numeric port with a readable message instead of an
        # opaque unpacking/int() error.
        host, sep, port = api.args.listen.rpartition(":")
        if not sep or not port.isdigit():
            raise SystemExit(
                f"Invalid listen address {api.args.listen!r}: "
                "expected 'host:port' or '[ipv6]:port'"
            )

    # NOTE(review): uvicorn only supports workers > 1 when the app is given as an
    # import string; with an app object it logs a warning and exits — confirm.
    uvicorn.run(
        api.app,
        host=host,
        port=int(port),
        workers=api.args.workers,
        log_level="info",
    )
================================================
FILE: tools/llama/build_dataset.py
================================================
import itertools
import os
import re
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from pathlib import Path
import click
import numpy as np
from loguru import logger
from tqdm import tqdm
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.utils.file import load_filelist
# To avoid CPU overload
# Cap MKL / OpenMP thread pools at one thread each; the work is parallelised
# with a multiprocessing Pool instead (see `main`).
# NOTE(review): numpy is imported above these lines, so whether its BLAS backend
# in *this* process still honours the variables depends on when that backend
# initialises its thread pool — confirm. Child processes inherit the environment.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
def task_generator_folder(root: Path, text_extension: str):
    """Yield one task per directory containing ``*.npy`` files under ``root``.

    For each ``.npy`` file (found recursively, processed in sorted order) the
    transcript(s) are read from sibling files with the same stem and the
    suffix(es) in ``text_extension``, which is either one suffix string or an
    iterable of suffixes. Files whose transcript cannot be read are logged and
    skipped.

    Yields:
        ``(speaker, [(npy_path, [text, ...]), ...], "folder")``, where
        ``speaker`` is the name of the directory holding the group's files.
    """
    npy_files = sorted(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
    suffixes = [text_extension] if isinstance(text_extension, str) else text_extension

    groups = defaultdict(list)
    for npy in tqdm(npy_files, desc=f"Grouping {root}"):
        try:
            texts = [npy.with_suffix(ext).read_text(encoding="utf-8") for ext in suffixes]
        except Exception as e:
            logger.error(f"Failed to read text {npy}: {e}")
            continue

        groups[str(npy.parent)].append((npy.parent.name, npy, texts))

    logger.info(
        f"Found {len(groups)} groups in {root}, {list(groups.keys())[:5]}..."
    )

    for members in groups.values():
        yield members[0][0], [(path, texts) for _, path, texts in members], "folder"
def task_generator_filelist(filelist):
    """Yield one task per speaker listed in ``filelist``.

    Each entry returned by ``load_filelist`` unpacks as
    ``(filename, speaker, <unused>, text)``.

    Yields:
        ``(speaker, [(Path(filename), [text]), ...], "filelist")``.
    """
    by_speaker = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        by_speaker[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(by_speaker)} groups in {filelist}")
    yield from ((spk, items, "filelist") for spk, items in by_speaker.items())
def run_task(task):
    """Turn one ``(name, [(path, texts), ...], source)`` task into a packed record.

    For every entry, loads the semantic tokens from ``path.with_suffix(".npy")``
    (skipping missing or unreadable files with a log message), cleans each text
    by replacing ``{...}`` and ``<...>`` spans and whitespace runs with a single
    space, and builds a ``Sentence``. All sentences are wrapped in a ``TextData``
    and passed through ``pack_pb_stream``; the caller (``main``) writes the
    result to a shard file.
    """
    name, subset, source = task
    # Applied in order; the last rule collapses whitespace the first two create.
    cleanup_rules = (
        (r"\{.*?\}", " "),
        (r"<.*?>", " "),
        (r"\s+", " "),
    )

    sentences = []
    for path, raw_texts in subset:
        npy_path = path.with_suffix(".npy")
        if not npy_path.exists():
            logger.warning(f"Can't find {npy_path}")
            continue

        cleaned = []
        for text in raw_texts:
            for pattern, replacement in cleanup_rules:
                text = re.sub(pattern, replacement, text)
            cleaned.append(text)

        try:
            semantics = np.load(npy_path)
        except Exception as e:
            logger.error(f"Failed to parse {path}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                texts=cleaned,
                semantics=[Semantics(values=row) for row in semantics],
            )
        )

    return pack_pb_stream(TextData(source=source, name=name, sentences=sentences))
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    """Build ``*.protos`` dataset shards from folders and/or filelists.

    Directory inputs are scanned with ``task_generator_folder``; any other path
    is treated as a filelist. Tasks are processed by ``run_task`` in a process
    pool and the packed results are appended to numbered shard files in
    ``output``. A new shard starts once the current one exceeds ``shard_size`` MB.
    """
    generators = []
    for path in input:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not path.exists():
            raise click.BadParameter(f"{path} not found", param_hint="'--input'")

        if path.is_dir():
            generators.append(task_generator_folder(path, text_extension))
        else:
            generators.append(task_generator_filelist(path))

    tasks = itertools.chain(*generators)

    output.mkdir(parents=True, exist_ok=True)

    shard_limit = shard_size * 1024 * 1024
    dataset_fp = None
    tar_idx = 0  # index of the next shard == number of shards completed so far
    written_size = 0

    try:
        with Pool(num_workers) as p:
            for result in tqdm(p.imap_unordered(run_task, tasks)):
                if dataset_fp is None:
                    dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

                dataset_fp.write(result)
                written_size += len(result)

                if written_size > shard_limit:
                    dataset_fp.close()
                    dataset_fp = None
                    written_size = 0
                    tar_idx += 1
                    logger.info(f"Finished writing {tar_idx} shards to {output}")
    finally:
        # Close (and count) the last, partially filled shard even on error.
        if dataset_fp is not None:
            dataset_fp.close()
            tar_idx += 1

    logger.info(f"Finished writing {tar_idx} shards to {output}")


if __name__ == "__main__":
    main()
================================================
FILE: tools/llama/eval_in_context.py
================================================
import pyrootutils
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from transformers import AutoTokenizer
# register eval resolver and root
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from torch.utils.data import DataLoader
from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
from fish_speech.models.text2semantic.inference import load_model
def smooth(
    scalars: list[float], weight: float
) -> list[float]:
    """Exponentially smooth a sequence of values (EMA), e.g. for a loss curve.

    Each output is ``weight * previous_smoothed + (1 - weight) * point``,
    seeded with the first input value.

    Args:
        scalars: Values to smooth, in plot order.
        weight: Smoothing factor between 0 and 1; larger values smooth more.

    Returns:
        A list with the same length as ``scalars`` (empty for empty input).
    """
    if not scalars:
        # Nothing to smooth; avoids IndexError on scalars[0].
        return []

    last = scalars[0]  # seed with the first value in the plot
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)
    return smoothed
@torch.inference_mode()
def analyze_one_model(loader, config, weight, max_length, max_steps: int = 10):
    """Compute the mean per-frame semantic (codebook) loss of one checkpoint.

    Args:
        loader: Iterable of batches; each batch is a dict of tensors with at
            least ``inputs``, ``attention_masks`` and ``labels``.
        config: Model config name passed to ``load_model``.
        weight: Checkpoint path passed to ``load_model``.
        max_length: Size of the per-frame accumulators (and model max length).
        max_steps: Number of batches to evaluate (default 10).

    Returns:
        ``(xs, ys, smoothed_ys)``: frame indices that had at least one
        non-padded sample, the mean loss at those frames, and the
        EMA-smoothed losses.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config,
        weight,
        device,
        torch.bfloat16,
        max_length,
        compile=False,
    )[0]
    model.eval()
    num_codebooks = model.config.num_codebooks

    semantic_loss_sum = torch.zeros(max_length, dtype=torch.float32, device=device)
    counter = torch.zeros(max_length, dtype=torch.long, device=device)

    for current_step, batch in enumerate(loader, start=1):
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch["labels"]
        outputs = model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        codebook_logits = outputs.codebook_logits

        # labels[:, 1 : 1 + num_codebooks] are the codebook targets; .mT moves
        # the codebook axis last so each row holds one frame's codes.
        codebook_labels = labels[:, 1 : 1 + num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        ).reshape(codebook_labels.shape)
        semantic_loss_frame = semantic_loss.mean(-1)
        # A frame is padding when every codebook label equals the ignore index.
        pad_pos = codebook_labels.sum(-1) == -100 * num_codebooks

        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
            # NOTE(review): slicing to the batch's sequence length keeps the
            # boolean mask aligned when batches are shorter than max_length.
            seq_len = pad.shape[-1]
            keep = ~pad
            semantic_loss_sum[:seq_len][keep] += loss_sample[keep]
            counter[:seq_len][keep] += 1

        if current_step >= max_steps:
            break

    # Move the accumulators to CPU once instead of syncing per element.
    semantic_loss_sum = semantic_loss_sum.cpu()
    counter = counter.cpu()
    valid = counter > 0
    xs = valid.nonzero().flatten().tolist()
    ys = (semantic_loss_sum[valid] / counter[valid]).tolist()  # mean loss per frame
    smoothed_ys = smooth(ys, 0.95)

    # Unload model
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return xs, ys, smoothed_ys
def main():
    """Plot the smoothed per-frame semantic loss of several checkpoints.

    Every (label, config name, checkpoint path) entry in ``tests`` is evaluated
    with ``analyze_one_model`` on the same dataset, and all curves are saved
    to ``semantic_loss.png``.
    """
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    max_length = 4096
    ds = AutoAugTextDataset(
        ["data/protos/sft/云天河"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        max_length=max_length,
    )
    loader = DataLoader(
        ds,
        batch_size=8,
        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
        num_workers=0,
        shuffle=False,
    )

    plt.figure(figsize=(10, 5), dpi=200)
    plt.xlabel("Frame")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("Semantic Loss")
    plt.grid(which="both", axis="both")
    plt.xlim(0, max_length)

    # (legend label, model config name, checkpoint path)
    tests = [
        (
            "pretrain-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
        ),
        (
            "sft-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
        ),
        (
            "sft-large",
            "dual_ar_2_codebook_large",
            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
        ),
    ]

    for name, config, weight in tests:
        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
        plt.plot(xs, smoothed_ys, label=name)

    plt.legend()
    plt.savefig("semantic_loss.png")
    plt.close()  # release the figure's memory once it is written


if __name__ == "__main__":
    main()
================================================
FILE: tools/llama/merge_lora.py
================================================
import shutil
from copy import deepcopy
from pathlib import Path
import click
import hydra
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from fish_speech.models.text2semantic.llama import BaseTransformer
from fish_speech.models.text2semantic.lora import get_merged_state_dict
def _strip_model_prefix(state_dict):
    """Keep only ``model.``-prefixed entries, with that leading prefix removed.

    If no key starts with ``model.``, ``state_dict`` is returned unchanged.
    """
    prefix = "model."
    if not any(k.startswith(prefix) for k in state_dict):
        return state_dict
    # removeprefix strips only the leading prefix; str.replace would also
    # rewrite any "model." substring later in the key.
    return {
        k.removeprefix(prefix): v
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }


@click.command()
@click.option("--lora-config", type=str, default="r_8_alpha_16")
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
@click.option("--lora-weight", type=str, required=True)
@click.option("--output", type=str, required=True)
def merge(lora_config, base_weight, lora_weight, output):
    """Merge LoRA weights into a base model and save the merged checkpoint.

    The base model is built with the LoRA config named ``lora_config`` (from
    ``fish_speech/configs/lora``), its non-LoRA weights are combined with the
    LoRA checkpoint, ``eval()`` is called to merge, and the result is saved to
    ``output`` with LoRA layers dropped. Finally the saved weights are compared
    against the original base weights as a sanity check.
    """
    output = Path(output)
    logger.info(
        f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
    )

    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
        cfg = compose(config_name=lora_config)

    lora_config = instantiate(cfg)
    logger.info(f"Loaded lora model with config {lora_config}")

    llama_model = BaseTransformer.from_pretrained(
        path=base_weight,
        load_weights=True,
        lora_config=lora_config,
    )
    logger.info("Loaded llama model")

    llama_state_dict = llama_model.state_dict()
    llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
    # state_dict() tensors share storage with the parameters, so snapshot the
    # base weights before merging for the validation step below.
    llama_state_dict_copy = deepcopy(llama_state_dict)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only use
    # this with trusted LoRA checkpoints.
    lora_state_dict = torch.load(lora_weight, map_location="cpu", weights_only=False)

    if "state_dict" in llama_state_dict:
        llama_state_dict = llama_state_dict["state_dict"]

    if "state_dict" in lora_state_dict:
        lora_state_dict = lora_state_dict["state_dict"]

    llama_state_dict = _strip_model_prefix(llama_state_dict)
    lora_state_dict = _strip_model_prefix(lora_state_dict)

    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
    logger.info(f"Found {len(lora_state_dict)} keys in lora model")

    merged_state_dict = llama_state_dict | lora_state_dict
    llama_model.load_state_dict(merged_state_dict, strict=True)
    logger.info("Merged model loaded")

    # Trigger eval mode to merge lora
    llama_model.eval()
    llama_model.save_pretrained(output, drop_lora=True)
    logger.info(f"Saved merged model to {output}, validating")

    new_state_dict = torch.load(output / "model.pth", map_location="cpu")
    tolerance = 1e-5
    changed_key = None
    for key, original in llama_state_dict_copy.items():
        # Compare in float32 so low-precision (e.g. bf16) sums do not hide changes.
        diff_l1 = (new_state_dict[key].float() - original.float()).abs().sum().item()
        if diff_l1 > tolerance:
            changed_key = key
            break

    if changed_key is None:
        logger.warning(
            "Merged model seems identical to the original model. Further validation might be needed."
        )
    else:
        logger.info(f"Significant difference found in key: {changed_key}")
        logger.info("Merged model is different from the original model, check passed")


if __name__ == "__main__":
    merge()
================================================
FILE: tools/llama/quantize.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import datetime
import shutil
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import time
from pathlib import Path
import click
import torch
import torch.nn as nn
import torch.nn.functional as F
from fish_speech.models.text2semantic.inference import load_model
from fish_speech.models.text2semantic.llama import find_multiple
##### Quantization Primitives ######
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    """Symmetric per-row quantization of a 2-D tensor (reduction along dim 1).

    Each row gets one scale chosen so that the row's largest magnitude
    (with 0 always inside the range) maps onto half of
    ``quant_max - quant_min``. Zero points are always 0.

    Args:
        x: Dense tensor to quantize.
        quant_min: Smallest representable integer (e.g. -128).
        quant_max: Largest representable integer (e.g. 127).
        target_dtype: dtype of the quantized output (e.g. torch.int8).

    Returns:
        ``(quant, scales, zero_points)``: the clamped quantized tensor, one
        scale per row in ``x.dtype``, and an int64 tensor of zeros.
    """
    row_min, row_max = torch.aminmax(x, dim=1)

    # Widen each row's range so it contains 0, then make it symmetric.
    neg_part = torch.min(row_min, torch.zeros_like(row_min))
    pos_part = torch.max(row_max, torch.zeros_like(row_max))
    abs_max = torch.max(-neg_part, pos_part)

    half_range = float(quant_max - quant_min) / 2
    smallest_scale = torch.finfo(torch.float32).eps
    # Clamp away zero scales (e.g. all-zero rows), then match the input dtype.
    scales = torch.clamp(abs_max / half_range, min=smallest_scale).to(x.dtype)
    zero_points = torch.zeros(
        neg_part.size(), dtype=torch.int64, device=neg_part.device
    )

    rounded = torch.round(x / scales.unsqueeze(-1))
    shifted = rounded + zero_points.unsqueeze(-1)
    quant = torch.clamp(shifted, quant_min, quant_max).to(target_dtype)
    return quant, scales, zero_points
def get_group_qparams(w, n_bit=4, groupsize=128):
    """Compute per-group affine scales and zero values for a 2-D weight.

    The last dim of ``w`` is split into groups of ``groupsize`` columns. Each
    group gets ``scale = (max - min) / (2**n_bit - 1)`` (floored at 1e-6 before
    division) and ``zero = min + scale * 2**(n_bit - 1)``.

    Returns:
        ``(scales, zeros)`` as bfloat16 tensors of shape
        ``(w.shape[0], w.shape[-1] // groupsize)``.
    """
    # A group wider than the row (needed for GPTQ with padding) collapses to
    # one group per row.
    groupsize = min(groupsize, w.shape[-1])
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    groups = w.reshape(-1, groupsize)
    assert torch.isnan(groups).sum() == 0

    hi = groups.amax(dim=1, keepdim=True)
    lo = groups.amin(dim=1, keepdim=True)
    levels = 2**n_bit - 1
    scales = (hi - lo).clamp(min=1e-6) / levels
    zeros = lo + scales * (2 ** (n_bit - 1))

    rows = w.shape[0]
    return (
        scales.to(torch.bfloat16).reshape(rows, -1),
        zeros.to(torch.bfloat16).reshape(rows, -1),
    )
def pack_scales_and_zeros(scales, zeros):
    """Interleave bfloat16 ``scales`` and ``zeros`` into one packed tensor.

    Inputs of shape ``(rows, n_groups)`` produce a contiguous tensor of shape
    ``(n_groups, rows, 2)`` where ``[..., 0]`` holds the scale and
    ``[..., 1]`` the zero.
    """
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16

    columns = [t.reshape(t.size(0), t.size(1), 1) for t in (scales, zeros)]
    packed = torch.cat(columns, 2)
    return packed.transpose(0, 1).contiguous()
def unpack_scales_and_zeros(scales_and_zeros):
    """Split a packed ``(n_groups, rows, 2)`` tensor into scales and zeros.

    Inverse of ``pack_scales_and_zeros`` up to a trailing singleton dimension:
    returns ``(scales, zeros)``, each of shape ``(rows, n_groups, 1)``.

    Both float32 and bfloat16 inputs are accepted. ``pack_scales_and_zeros``
    produces bfloat16, so accepting only float32 made the pack/unpack round
    trip (used by ``group_dequantize_tensor``) always fail.
    """
    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    assert scales_and_zeros.dtype in (torch.float, torch.bfloat16)
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize a 2-D ``w`` to unsigned ``n_bit`` codes with given qparams.

    Each group of ``groupsize`` columns uses one scale and zero. With
    ``lower = zero - scale * 2**(n_bit - 1)``, values map to
    ``round((w - lower) / scale)`` clamped to ``[0, 2**n_bit - 1]``.

    Returns:
        An int32 tensor with the same shape as ``w``.
    """
    assert groupsize > 1
    # GPTQ single-column case: a row narrower than the group with a single
    # qparam per row is treated as one group.
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    col_scales = scales.reshape(-1, 1)
    lower = zeros.reshape(-1, 1) - col_scales * (2 ** (n_bit - 1))
    top_code = 2**n_bit - 1

    codes = grouped.sub(lower).div(col_scales).round().clamp_(0, top_code)
    return codes.to(torch.int32).reshape_as(w)
def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Group-quantize ``w``.

    Returns ``(int32 codes shaped like w, packed scales/zeros tensor)``.
    """
    qparams = get_group_qparams(w, n_bit, groupsize)
    codes = group_quantize_tensor_from_qparams(w, *qparams, n_bit, groupsize)
    return codes, pack_scales_and_zeros(*qparams)
def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    """Map ``n_bit`` integer codes back to real values.

    Per group of ``groupsize`` columns:
    ``value = (code - 2**(n_bit - 1)) * scale + zero``. Returns a tensor with
    the same shape as ``w_int32``.
    """
    assert groupsize > 1
    # GPTQ single-column case: a row narrower than the group with a single
    # qparam per row is treated as one group.
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    midpoint = 2 ** (n_bit - 1)
    codes = w_int32.reshape(-1, groupsize)
    col_scales = scales.reshape(-1, 1)
    col_zeros = zeros.reshape(-1, 1)
    values = codes.sub(midpoint).mul(col_scales).add(col_zeros)
    return values.reshape_as(w_int32)
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    """Dequantize ``w_int32`` using a packed scales/zeros tensor."""
    unpacked = unpack_scales_and_zeros(scales_and_zeros)  # (scales, zeros)
    return group_dequantize_tensor_from_qparams(w_int32, *unpacked, n_bit, groupsize)
class QuantHandler:
    """Placeholder interface for a model quantization handler.

    It describes two steps: ``create_quantized_state_dict`` (produce quantized
    weights) and ``convert_for_runtime`` (swap modules). Both base
    implementations do nothing and return ``None``.
    """

    def __init__(self, mod):
        self.mod = mod  # module to be quantized

    def create_quantized_state_dict(self) -> "StateDict":
        return None

    def convert_for_runtime(self) -> "nn.Module":
        return None
##### Weight-only int8 per-channel quantized code ######
def replace_linear_weight_only_int8_per_channel(module):
    """Recursively replace every ``nn.Linear`` in ``module`` with ``WeightOnlyInt8Linear``.

    Only ``in_features`` / ``out_features`` are carried over. The new layers
    hold placeholder buffers, and any bias of the original Linear is not
    transferred.
    """
    for child_name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            replace_linear_weight_only_int8_per_channel(child)
            continue
        replacement = WeightOnlyInt8Linear(child.in_features, child.out_features)
        setattr(module, child_name, replacement)
class WeightOnlyInt8QuantHandler:
    """Int8 weight-only symmetric per-channel quantization for a module's Linear layers."""

    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return ``mod``'s state dict with every Linear weight quantized to int8.

        For each ``nn.Linear`` at path ``fqn``, ``{fqn}.weight`` is replaced by
        the int8 tensor and ``{fqn}.scales`` (one per weight row, cast to the
        weight's original dtype) is added. Other entries are untouched.
        """
        state = self.mod.state_dict()
        linears = [
            (fqn, m)
            for fqn, m in self.mod.named_modules()
            if isinstance(m, torch.nn.Linear)
        ]
        for fqn, linear in linears:
            q_weight, row_scales, _ = dynamically_quantize_per_channel(
                linear.weight.float(), -128, 127, torch.int8
            )
            state[f"{fqn}.weight"] = q_weight
            state[f"{fqn}.scales"] = row_scales.to(linear.weight.dtype)
        return state

    def convert_for_runtime(self):
        """Swap Linear layers for ``WeightOnlyInt8Linear`` in place and return ``mod``."""
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
class WeightOnlyInt8Linear(torch.nn.Module):
    """Bias-free linear layer with int8 weights and per-output-row scales.

    Computes ``F.linear(x, weight.to(x.dtype)) * scales``. The ``weight`` and
    ``scales`` buffers start as placeholders. Their names match the
    ``{fqn}.weight`` / ``{fqn}.scales`` entries written by
    ``WeightOnlyInt8QuantHandler.create_quantized_state_dict``.

    ``bias`` and ``dtype`` are accepted for ``nn.Linear`` signature
    compatibility but are ignored: no bias is applied, and the buffer dtypes
    are fixed (int8 weight, bfloat16 scales). ``device`` places the buffers.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight",
            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
        )
        self.register_buffer(
            "scales", torch.ones(out_features, dtype=torch.bfloat16, device=device)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly: cast int8 weights to the activation dtype and
        # apply the per-row scale to the output.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
##### weight only int4 per channel groupwise quantized code ######
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    """Group-quantize a weight to 4 bits and pack it for ``_weight_int4pack_mm``.

    Returns ``(packed int4 weight, packed scales/zeros)``.
    """
    codes, packed_qparams = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    packed_weight = torch.ops.aten._convert_weight_to_int4pack(codes, inner_k_tiles)
    return packed_weight, packed_qparams
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Apply an int4-packed linear layer to ``x`` with arbitrary leading dims.

    ``x`` is flattened to 2-D for ``torch.ops.aten._weight_int4pack_mm`` and
    the output is reshaped to ``(*x.shape[:-1], out_features)``.
    """
    leading = x.shape[:-1]
    flat = x.reshape(-1, x.size(-1))
    out = torch.ops.aten._weight_int4pack_mm(
        flat, weight_int4pack, groupsize, scales_and_zeros
    )
    return out.reshape(*leading, out_features)
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    """Recursively replace ``nn.Linear`` children with ``WeightOnlyInt4Linear``.

    A Linear whose ``in_features`` already satisfies the int4 divisibility
    check is replaced without padding. Otherwise it is replaced by a padded
    layer when ``padding`` is True, and left unchanged when it is False.
    """
    for child_name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
            continue

        fits = _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
        if not (fits or padding):
            continue
        setattr(
            module,
            child_name,
            WeightOnlyInt4Linear(
                child.in_features,
                child.out_features,
                bias=False,
                groupsize=groupsize,
                inner_k_tiles=inner_k_tiles,
                padding=not fits,
            ),
        )
class WeightOnlyInt4QuantHandler:
    """Int4 weight-only groupwise affine quantization of a module's Linear layers.

    Args:
        mod: Module whose ``nn.Linear`` layers are quantized; they must have
            no bias and ``out_features % 8 == 0``.
        groupsize: Columns per quantization group (32, 64, 128 or 256).
        inner_k_tiles: Tiling parameter of the int4 pack kernel (2, 4 or 8).
        padding: For layers failing the int4 divisibility check, zero-pad
            ``in_features`` up to a multiple of 1024; when False, such layers
            are skipped.
    """

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return ``mod``'s state dict with Linear weights packed as int4.

        For each quantized Linear at path ``fqn``, ``{fqn}.weight`` becomes the
        packed int4 weight and ``{fqn}.scales_and_zeros`` is added (both on
        CPU). Skipped layers keep their original entries.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if not isinstance(mod, torch.nn.Linear):
                continue

            # `not mod.bias` is ambiguous for multi-element bias tensors (and
            # passes for a zero single-element bias); test for None instead.
            assert mod.bias is None, f"{fqn}: int4 quantization requires bias=False"
            out_features = mod.out_features
            in_features = mod.in_features
            assert out_features % 8 == 0, "require out_features % 8 == 0"
            print(f"linear: {fqn}, in={in_features}, out={out_features}")

            weight = mod.weight.data
            if not _check_linear_int4_k(
                in_features, self.groupsize, self.inner_k_tiles
            ):
                if not self.padding:
                    print(
                        f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                        + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                    )
                    continue
                print(
                    f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                )
                padded_in_features = find_multiple(in_features, 1024)
                weight = F.pad(weight, pad=(0, padded_in_features - in_features))

            # Packing is done on CUDA; results are stored on CPU.
            (
                weight_int4pack,
                scales_and_zeros,
            ) = prepare_int4_weight_and_scales_and_zeros(
                weight.to(torch.bfloat16).to("cuda"),
                self.groupsize,
                self.inner_k_tiles,
            )
            cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
            cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        """Swap eligible Linear layers for ``WeightOnlyInt4Linear`` in place; return ``mod``."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
class WeightOnlyInt4Linear(torch.nn.Module):
    """Bias-free linear layer backed by int4-packed weights.

    Buffers have the layout used by ``_weight_int4pack_mm``:
    ``weight`` is int32 ``(out/8, in/(inner_k_tiles*16), 32, inner_k_tiles/2)``
    and ``scales_and_zeros`` is bfloat16 ``(in/groupsize, out, 2)``. With
    ``padding=True``, ``in_features`` is rounded up to a multiple of 1024 and
    inputs are zero-padded to match in ``forward``.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        padding: bool = True,
    ) -> None:
        super().__init__()
        self.padding = padding
        if padding:
            # Keep the real width; forward() pads inputs up to the padded one.
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        k_tile = inner_k_tiles * 16
        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert in_features % k_tile == 0, "require in_features % (innerKTiles * 16) == 0"

        packed_shape = (
            out_features // 8,
            in_features // k_tile,
            32,
            inner_k_tiles // 2,
        )
        self.register_buffer("weight", torch.empty(packed_shape, dtype=torch.int32))
        num_groups = in_features // groupsize
        self.register_buffer(
            "scales_and_zeros",
            torch.empty((num_groups, out_features, 2), dtype=torch.bfloat16),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = input.to(torch.bfloat16)
        if self.padding:
            x = F.pad(x, pad=(0, self.in_features - self.origin_in_features))
        return linear_forward_int4(
            x, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )
def generate_folder_name():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    return f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
@click.command()
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.4",
)
@click.option(
    "--mode", type=str, default="int8", help="type of quantization to perform"
)
@click.option(
    "--groupsize", type=int, default=128, help="Group size for int4 quantization."
)
@click.option("--timestamp", type=str, default="None", help="When to do quantization")
def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
    """Quantize a text2semantic checkpoint into a new checkpoint folder.

    The checkpoint folder is copied to ``checkpoints/fs-1.2-int8-<ts>`` or
    ``checkpoints/fs-1.2-int4-g<groupsize>-<ts>``, ``codec.pth`` is removed
    from the copy, and ``model.pth`` is overwritten with the quantized state
    dict.

    Args:
        checkpoint_path: Folder of the checkpoint to quantize.
        mode: ``int8`` (weight-only per-channel) or ``int4`` (weight-only
            groupwise).
        groupsize: Group size for int4 quantization.
        timestamp: Output-folder suffix; the string ``"None"`` means "use the
            current local time".

    Raises:
        ValueError: If ``mode`` is neither ``int8`` nor ``int4``.
    """
    # Validate before the (slow) model load.
    if mode not in ("int8", "int4"):
        raise ValueError(
            f"Invalid quantization mode {mode} needs to be one of [int8, int4]"
        )

    device = "cpu"
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()

    model, _ = load_model(
        checkpoint_path=checkpoint_path,
        device=device,
        precision=precision,
        compile=False,
    )
    vq_model = "codec.pth"
    now = timestamp if timestamp != "None" else generate_folder_name()

    if mode == "int8":
        print(
            "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
        )
        quant_handler = WeightOnlyInt8QuantHandler(model)
        dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
    else:
        print(
            "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
        )
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")

    quantized_state_dict = quant_handler.create_quantized_state_dict()

    # Copy the whole checkpoint folder, then drop codec.pth from the copy.
    shutil.copytree(str(checkpoint_path.resolve()), str(dst_name.resolve()))
    if (dst_name / vq_model).exists():
        (dst_name / vq_model).unlink()
    quantize_path = dst_name / "model.pth"

    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True)  # remove existing file if one already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete took {time.time() - t0:.02f} seconds")


if __name__ == "__main__":
    quantize()
================================================
FILE: tools/run_webui.py
================================================
import os
from argparse import ArgumentParser
from pathlib import Path
import pyrootutils
import torch
from loguru import logger
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.utils.schema import ServeTTSRequest
from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper
# Make einx happy: set EINX_FILTER_TRACEBACK to "false" (presumably turns off
# einx's traceback filtering so errors show full stack traces).
# NOTE(review): this runs after the fish_speech imports above; if einx reads
# the variable at import time, it may need to be set before those imports.
os.environ["EINX_FILTER_TRACEBACK"] = "false"
def parse_args():
    """Parse the web UI's command-line options (checkpoints, device, UI settings)."""
    parser = ArgumentParser()

    # Model checkpoints / configs
    parser.add_argument(
        "--llama-checkpoint-path", type=Path, default="checkpoints/s2-pro"
    )
    parser.add_argument(
        "--decoder-checkpoint-path", type=Path, default="checkpoints/s2-pro/codec.pth"
    )
    parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")

    # Runtime
    parser.add_argument("--device", type=str, default="cuda")
    for flag in ("--half", "--compile"):
        parser.add_argument(flag, action="store_true")

    # UI
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="light")

    return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    # Device override: MPS first, then XPU; if CUDA is missing too, use CPU.
    if torch.backends.mps.is_available():
        args.device = "mps"
        logger.info("mps is available, running on mps.")
    elif torch.xpu.is_available():
        args.device = "xpu"
        logger.info("XPU is available, running on XPU.")
    elif not torch.cuda.is_available():
        logger.info("CUDA is not available, running on CPU.")
        args.device = "cpu"

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )

    logger.info("Loading VQ-GAN model...")
    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )
    logger.info("Decoder model loaded, warming up...")

    inference_engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        compile=args.compile,
        precision=args.precision,
    )

    # Dry run: checks the models load correctly and pays the first-call
    # latency up front. The generated results are discarded.
    warmup_request = ServeTTSRequest(
        text="Hello world.",
        references=[],
        reference_id=None,
        max_new_tokens=1024,
        chunk_length=200,
        top_p=0.7,
        repetition_penalty=1.5,
        temperature=0.7,
        format="wav",
    )
    for _ in inference_engine.inference(warmup_request):
        pass
    logger.info("Warming up done, launching the web UI...")

    # Bind the engine into the inference callable the UI uses.
    inference_fct = get_inference_wrapper(inference_engine)
    app = build_app(inference_fct, args.theme)
    app.launch()
================================================
FILE: tools/server/api_utils.py
================================================
from argparse import ArgumentParser
from http import HTTPStatus
from typing import Annotated, Any
import ormsgpack
from baize.datastructures import ContentType
from kui.asgi import (
HTTPException,
HttpRequest,
JSONResponse,
request,
)
from loguru import logger
from pydantic import BaseModel
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
def parse_args():
    """Parse the API server's command-line options."""
    parser = ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["tts"], default="tts")

    # Model checkpoints / configs
    parser.add_argument(
        "--llama-checkpoint-path", type=str, default="checkpoints/s2-pro"
    )
    parser.add_argument(
        "--decoder-checkpoint-path", type=str, default="checkpoints/s2-pro/codec.pth"
    )
    parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")

    # Runtime
    parser.add_argument("--device", type=str, default="cuda")
    for flag in ("--half", "--compile"):
        parser.add_argument(flag, action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)

    # Serving
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--api-key", type=str, default=None)

    return parser.parse_args()
class MsgPackRequest(HttpRequest):
    """HTTP request whose ``data()`` decodes msgpack, JSON or multipart bodies."""

    async def data(
        self,
    ) -> Annotated[
        Any,
        ContentType("application/msgpack"),
        ContentType("application/json"),
        ContentType("multipart/form-data"),
    ]:
        """Decode the request body according to its Content-Type.

        Raises:
            HTTPException: 415 Unsupported Media Type for any other content
                type, with an ``Accept`` header listing the supported types.
        """
        content_type = self.content_type
        if content_type == "application/msgpack":
            return ormsgpack.unpackb(await self.body)
        if content_type == "application/json":
            return await self.json
        if content_type == "multipart/form-data":
            return await self.form

        supported = "application/msgpack, application/json, multipart/form-data"
        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={"Accept": supported},
        )
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Async-iterate the ``bytes`` chunks produced by ``inference(req, engine)``.

    Chunks that are not ``bytes`` are skipped. (The per-chunk debug ``print``
    has been removed; it spammed stdout for every streamed chunk.)

    NOTE(review): ``inference`` is a synchronous generator, so each step
    blocks the event loop while it runs.
    """
    for chunk in inference(req, engine):
        if isinstance(chunk, bytes):
            yield chunk
async def buffer_to_async_generator(buffer):
    """Expose a single in-memory ``buffer`` as a one-item async generator."""
    yield buffer
def get_content_type(audio_format):
if audio_format == "wav":
return "audio/wav"
elif audio_format == "flac":
return "audio/flac"
elif audio_format == "mp3":
return "audio/mpeg"
elif audio_format == "opus":
return "audio/ogg"
else:
return "application/octet-stream"
def wants_json(req):
    """Helper method to determine if the client wants a JSON response

    An explicit ``format`` query parameter wins: ``json`` or
    ``application/json`` select JSON; ``msgpack`` or ``application/msgpack``
    select msgpack. Otherwise the ``Accept`` header decides: JSON only when it
    mentions ``application/json`` and not ``application/msgpack``.

    Parameters
    ----------
    req : Request
        The request object (uses its ``query_params`` and ``headers``)

    Returns
    -------
    bool
        True if the client wants a JSON response, False otherwise
    """
    fmt = req.query_params.get("format", "").strip().lower()
    # Previously `?format=application/json` was accepted but returned False.
    if fmt in {"json", "application/json"}:
        return True
    if fmt in {"msgpack", "application/msgpack"}:
        return False

    accept = req.headers.get("Accept", "").strip().lower()
    return "application/json" in accept and "application/msgpack" not in accept
def format_response(response: BaseModel, status_code=200):
    """
    Helper function to format responses consistently based on client preference.

    Parameters
    ----------
    response : BaseModel
        The response object to format
    status_code : int
        HTTP status code (default: 200)

    Returns
    -------
    Response
        A ``JSONResponse`` when the current request wants JSON (see
        ``wants_json``); otherwise a ``(msgpack bytes, status_code, headers)``
        tuple. On a formatting error, a 500 JSON error response.
    """
    try:
        if wants_json(request):
            return JSONResponse(
                response.model_dump(mode="json"), status_code=status_code
            )

        return (
            ormsgpack.packb(
                response,
                option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
            ),
            status_code,
            {"Content-Type": "application/msgpack"},
        )
    except Exception as e:
        # loguru has no `exc_info` option: passing it as a kwarg makes loguru
        # run str.format on the (already f-formatted) message and logs no
        # traceback. logger.exception records the traceback properly.
        logger.exception("Error formatting response: {}", e)
        # Fallback to JSON response if formatting fails
        return JSONResponse(
            {"error": "Response formatting failed", "details": str(e)}, status_code=500
        )
================================================
FILE: tools/server/exception_handler.py
================================================
import traceback
from http import HTTPStatus
from kui.asgi import HTTPException, JSONResponse
class ExceptionHandler:
    """Turn exceptions raised by API views into JSON error responses."""

    async def http_exception_handler(self, exc: HTTPException):
        """Render an ``HTTPException`` as JSON, keeping its status code and headers."""
        payload = {
            "statusCode": exc.status_code,
            "message": exc.content,
            "error": HTTPStatus(exc.status_code).phrase,
        }
        return JSONResponse(payload, exc.status_code, exc.headers)

    async def other_exception_handler(self, exc: Exception):
        """Print the current traceback and return a 500 JSON error for ``exc``."""
        traceback.print_exc()
        status = HTTPStatus.INTERNAL_SERVER_ERROR
        payload = {"statusCode": status, "message": str(exc), "error": status.phrase}
        return JSONResponse(payload, status)
================================================
FILE: tools/server/inference.py
================================================
from http import HTTPStatus
import numpy as np
from kui.asgi import HTTPException
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.utils.schema import ServeTTSRequest
AMPLITUDE = 32768 # Needs an explaination
def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
"""
Wrapper for the inference function.
Used in the API server.
"""
count = 0
for result in engine.inference(req):
match result.code:
case "header":
if isinstance(result.audio, tuple):
yield result.audio[1]
case "error":
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
content=str(result.error),
)
case "segment":
count += 1
if isinstance(result.audio, tuple):
yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
case "final":
count += 1
if isinstance(result.audio, tuple):
yield result.audio[1]
return None # Stop the generator
if count == 0:
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
content="No audio generated, please check the input text.",
)
================================================
FILE: tools/server/model_manager.py
================================================
import torch
from loguru import logger
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
class ModelManager:
    """Load the API server's models and build the TTS inference engine.

    After construction, ``self.tts_inference_engine`` wraps the LLaMA queue
    and decoder model. In ``"tts"`` mode the engine is warmed up with one
    request.

    Args:
        mode: Serving mode; only ``"tts"`` is supported.
        device: Requested device; overridden to MPS/XPU when available, or
            to CPU when CUDA is unavailable.
        half: Use float16 instead of bfloat16.
        compile: Forwarded to model loading and the engine.
        llama_checkpoint_path: Text2semantic checkpoint path.
        decoder_checkpoint_path: Decoder checkpoint path.
        decoder_config_name: Decoder config name.
    """

    def __init__(
        self,
        mode: str,
        device: str,
        half: bool,
        compile: bool,
        llama_checkpoint_path: str,
        decoder_checkpoint_path: str,
        decoder_config_name: str,
    ) -> None:
        self.mode = mode
        self.device = device
        self.half = half
        self.compile = compile
        self.precision = torch.half if half else torch.bfloat16

        # Same device policy as tools/run_webui.py: prefer MPS, then XPU, and
        # fall back to CPU when CUDA is not available either. hasattr guards
        # torch builds without the xpu module.
        if torch.backends.mps.is_available():
            self.device = "mps"
            logger.info("mps is available, running on mps.")
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            self.device = "xpu"
            logger.info("XPU is available, running on XPU.")
        elif not torch.cuda.is_available():
            self.device = "cpu"
            logger.info("CUDA is not available, running on CPU.")

        # Load the TTS models
        self.load_llama_model(
            llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
        )
        self.load_decoder_model(
            decoder_config_name, decoder_checkpoint_path, self.device
        )
        self.tts_inference_engine = TTSInferenceEngine(
            llama_queue=self.llama_queue,
            decoder_model=self.decoder_model,
            precision=self.precision,
            compile=self.compile,
        )

        # Warm up the models
        if self.mode == "tts":
            self.warm_up(self.tts_inference_engine)

    def load_llama_model(
        self, checkpoint_path, device, precision, compile, mode
    ) -> None:
        """Start the thread-safe LLaMA queue; raise ValueError for modes other than "tts"."""
        if mode == "tts":
            self.llama_queue = launch_thread_safe_queue(
                checkpoint_path=checkpoint_path,
                device=device,
                precision=precision,
                compile=compile,
            )
        else:
            raise ValueError(f"Invalid mode: {mode}")

        logger.info("LLAMA model loaded.")

    def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
        """Load the decoder model into ``self.decoder_model``."""
        self.decoder_model = load_decoder_model(
            config_name=config_name,
            checkpoint_path=checkpoint_path,
            device=device,
        )
        logger.info("Decoder model loaded.")

    def warm_up(self, tts_inference_engine) -> None:
        """Run one short request end to end and discard its output."""
        request = ServeTTSRequest(
            text="Hello world.",
            references=[],
            reference_id=None,
            max_new_tokens=1024,
            chunk_length=200,
            top_p=0.7,
            repetition_penalty=1.2,
            temperature=0.7,
            format="wav",
        )
        list(inference(request, tts_inference_engine))
        logger.info("Models warmed up.")
================================================
FILE: tools/server/model_utils.py
================================================
import io
import re
import librosa
import torch
import torchaudio
from cachetools import LRUCache, cached
CACHE_MAXSIZE = 10_000  # max entries in the cached_vqgan_batch_encode LRU cache
MICRO_BATCH_SIZE = 8  # max batch size per model.decode call in batch_vqgan_decode
ASR_SAMPLE_RATE = 16_000  # presumably the ASR input rate in Hz; unused in this section
HUGE_GAP_THRESHOLD = 4_000  # NOTE(review): meaning/units not visible here — confirm
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_encode(model, audios_list: list[bytes]):
    """Encode a batch of audios with ``model`` and return unpadded features.

    Items may be encoded audio ``bytes`` (decoded and resampled with librosa
    to the model's sample rate) or already-loaded tensors. All audios are
    right-padded to the longest one, encoded together, and each feature is
    trimmed to the length reported by ``model.encode``.
    """
    # The sample rate lives on spec_transform when the model has one.
    sample_rate = (
        model.spec_transform.sample_rate
        if hasattr(model, "spec_transform")
        else model.sample_rate
    )

    def _as_tensor(audio):
        if not isinstance(audio, bytes):
            return audio
        samples = librosa.load(io.BytesIO(audio), sr=sample_rate)[0]
        return torch.from_numpy(samples)[None]

    audios: list[torch.Tensor] = [_as_tensor(a) for a in audios_list]

    lengths = torch.tensor([a.shape[-1] for a in audios], device=model.device)
    max_length = lengths.max().item()
    print(f"Encode max length: {max_length / sample_rate:.2f}s")

    padded = torch.stack(
        [torch.nn.functional.pad(a, (0, int(max_length - a.shape[-1]))) for a in audios]
    ).to(model.device)

    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
    features = features.cpu()
    feature_lengths = feature_lengths.cpu()
    return [feat[..., :n] for feat, n in zip(features, feature_lengths)]
def _vqgan_cache_key(model, audios):
    """Build the cache key: the model's device plus the audio payloads as a hashable tuple."""
    return (model.device, tuple(audios))


# NOTE(review): the key includes only model.device, not the model itself, so
# two different models on the same device would share cache entries.
@cached(cache=LRUCache(maxsize=CACHE_MAXSIZE), key=_vqgan_cache_key)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
    """Return ``batch_encode(model, audios)``, memoized in a bounded LRU cache."""
    return batch_encode(model, audios)
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def batch_vqgan_decode(model, features):
    """Decode a batch of codec feature/token tensors back to waveforms.

    Args:
        model: Codec model exposing ``decode(x, feature_lengths=...)`` and
            ``device``.
        features: Sequence of tensors whose last dimension is time. They are
            right-padded to a common length and stacked.

    Returns:
        One numpy array per input, trimmed to the audio length reported by
        ``model.decode``. An empty input returns an empty list.
    """
    if len(features) == 0:
        # lengths.max() below raises on an empty tensor.
        return []

    lengths = torch.tensor(
        [feature.shape[-1] for feature in features], device=model.device
    )
    max_length = int(lengths.max().item())
    padded = torch.stack(
        [
            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
            for feature in features
        ]
    ).to(model.device)

    # Decode in micro-batches so a large request does not go through the
    # decoder in a single call.
    audio_chunks, length_chunks = [], []
    for start in range(0, padded.shape[0], MICRO_BATCH_SIZE):
        stop = start + MICRO_BATCH_SIZE
        audio, audio_length = model.decode(
            padded[start:stop],
            feature_lengths=lengths[start:stop],
        )
        audio_chunks.append(audio)
        length_chunks.append(audio_length)

    audios = torch.cat(audio_chunks, dim=0).cpu()
    audio_lengths = torch.cat(length_chunks, dim=0).cpu()
    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
================================================
FILE: tools/server/views.py
================================================
import io
import os
import re
import shutil
import tempfile
import time
from http import HTTPStatus
from pathlib import Path
import numpy as np
import ormsgpack
import soundfile as sf
import torch
from kui.asgi import (
Body,
HTTPException,
HttpView,
JSONResponse,
Routes,
StreamResponse,
UploadFile,
request,
)
from loguru import logger
from typing_extensions import Annotated
from fish_speech.utils.schema import (
AddReferenceRequest,
AddReferenceResponse,
DeleteReferenceResponse,
ListReferencesResponse,
ServeTTSRequest,
ServeVQGANDecodeRequest,
ServeVQGANDecodeResponse,
ServeVQGANEncodeRequest,
ServeVQGANEncodeResponse,
UpdateReferenceResponse,
)
from tools.server.api_utils import (
buffer_to_async_generator,
format_response,
get_content_type,
inference_async,
)
from tools.server.inference import inference_wrapper as inference
from tools.server.model_manager import ModelManager
from tools.server.model_utils import (
batch_vqgan_decode,
cached_vqgan_batch_encode,
)
# Read from the NUM_SAMPLES environment variable (default 1).
# NOTE(review): not referenced anywhere else in this module — confirm whether
# it is still needed.
MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
# Built single-page WebUI served by the /ui route; the path is resolved three
# directories up from this file (tools/server/views.py -> repository root).
_WEBUI_HTML = (
    Path(__file__).parent.parent.parent / "awesome_webui" / "dist" / "index.html"
)
# Route table that the handlers in this module register themselves on.
routes = Routes()
@routes.http("/ui")
class WebUI(HttpView):
    @classmethod
    async def get(cls):
        # Serve the prebuilt WebUI page if it exists; otherwise return a 404
        # JSON hint explaining how to build it.
        from kui.asgi import HTMLResponse

        if not _WEBUI_HTML.exists():
            return JSONResponse(
                {"error": "WebUI not built. Run: cd awesome_webui && npm run build"},
                status_code=404,
            )
        page = _WEBUI_HTML.read_text(encoding="utf-8")
        return HTMLResponse(page)
@routes.http("/v1/health")
class Health(HttpView):
    # Health-check endpoint: GET and POST both answer {"status": "ok"}.

    @classmethod
    async def get(cls):
        payload = {"status": "ok"}
        return JSONResponse(payload)

    @classmethod
    async def post(cls):
        payload = {"status": "ok"}
        return JSONResponse(payload)
@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    """
    Encode audio using VQGAN model.

    Returns the codes for each input clip as a msgpack-encoded
    ServeVQGANEncodeResponse. Encoding goes through the LRU-cached
    cached_vqgan_batch_encode.
    """
    try:
        # Get the model from the app
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        # Encode the audio.
        # NOTE(review): this runs synchronously inside the async handler and
        # blocks the event loop while it runs.
        start_time = time.time()
        tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
        logger.info(
            f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
        )

        # Return the response
        return ormsgpack.packb(
            ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )
    except Exception as e:
        # loguru has no `exc_info` kwarg. Passing one only makes loguru call
        # str.format() on the message, which breaks on braces in the error
        # text, and the traceback is never logged. logger.exception logs at
        # ERROR level with the traceback attached.
        logger.exception(f"Error in VQGAN encode: {e}")
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to encode audio"
        ) from e
@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    """
    Decode tokens to audio using VQGAN model.

    Each entry of req.tokens is converted to an int tensor and decoded. The
    waveforms are returned as raw float16 bytes inside a msgpack-encoded
    ServeVQGANDecodeResponse.
    """
    try:
        # Get the model from the app
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        # Decode the audio.
        # NOTE(review): this runs synchronously inside the async handler and
        # blocks the event loop while it runs.
        tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
        start_time = time.time()
        audios = batch_vqgan_decode(decoder_model, tokens)
        logger.info(
            f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
        )
        audios = [audio.astype(np.float16).tobytes() for audio in audios]

        # Return the response
        return ormsgpack.packb(
            ServeVQGANDecodeResponse(audios=audios),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )
    except Exception as e:
        # logger.exception attaches the traceback. loguru ignores `exc_info`
        # and would str.format() the already-interpolated message.
        logger.exception(f"Error in VQGAN decode: {e}")
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to decode tokens to audio"
        ) from e
@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    """
    Generate speech from text using TTS model.
    """
    try:
        # Get the model from the app
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine
        sample_rate = engine.decoder_model.sample_rate

        # Reject over-long text; a non-positive max_text_length disables the check.
        if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content=f"Text is too long, max length is {app_state.max_text_length}",
            )

        # Streaming output is only supported as WAV.
        if req.streaming and req.format != "wav":
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content="Streaming only supports WAV format",
            )

        # Perform TTS
        if req.streaming:
            return StreamResponse(
                iterable=inference_async(req, engine),
                headers={
                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
                },
                content_type=get_content_type(req.format),
            )
        else:
            # Non-streaming: the first item yielded by inference() is taken as
            # the complete audio and encoded in the requested format.
            fake_audios = next(inference(req, engine))
            buffer = io.BytesIO()
            sf.write(
                buffer,
                fake_audios,
                sample_rate,
                format=req.format,
            )

            return StreamResponse(
                iterable=buffer_to_async_generator(buffer.getvalue()),
                headers={
                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
                },
                content_type=get_content_type(req.format),
            )
    except HTTPException:
        # Re-raise HTTP exceptions as they are already properly formatted
        raise
    except Exception as e:
        # logger.exception attaches the traceback. loguru ignores `exc_info`
        # and would str.format() the already-interpolated message.
        logger.exception(f"Error in TTS generation: {e}")
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to generate speech"
        ) from e
@routes.http.post("/v1/references/add")
async def add_reference(
    id: str = Body(...), audio: UploadFile = Body(...), text: str = Body(...)
):
    """
    Add a new reference voice with audio file and text.
    """
    temp_file_path = None
    try:
        # Validate input parameters
        if not id or not id.strip():
            raise ValueError("Reference ID cannot be empty")

        if not text or not text.strip():
            raise ValueError("Reference text cannot be empty")

        # Get the model manager to access the reference loader
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # Read the uploaded audio file
        audio_content = audio.read()
        if not audio_content:
            raise ValueError("Audio file is empty or could not be read")

        # Stage the upload in a temp file for the reference loader. The file
        # is removed in the `finally` block below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_content)
            temp_file_path = temp_file.name

        # Add the reference using the engine's reference loader
        engine.add_reference(id, temp_file_path, text)

        response = AddReferenceResponse(
            success=True,
            message=f"Reference voice '{id}' added successfully",
            reference_id=id,
        )
        return format_response(response)

    # FileExistsError is an OSError subclass, so it must be caught first.
    except FileExistsError as e:
        logger.warning(f"Reference ID '{id}' already exists: {e}")
        response = AddReferenceResponse(
            success=False,
            message=f"Reference ID '{id}' already exists",
            reference_id=id,
        )
        return format_response(response, status_code=409)  # Conflict

    except ValueError as e:
        logger.warning(f"Invalid input for reference '{id}': {e}")
        response = AddReferenceResponse(success=False, message=str(e), reference_id=id)
        return format_response(response, status_code=400)

    except (FileNotFoundError, OSError) as e:
        logger.error(f"File system error for reference '{id}': {e}")
        response = AddReferenceResponse(
            success=False, message="File system error occurred", reference_id=id
        )
        return format_response(response, status_code=500)

    except Exception as e:
        # logger.exception attaches the traceback. loguru ignores `exc_info`
        # and would str.format() the already-interpolated message.
        logger.exception(f"Unexpected error adding reference '{id}': {e}")
        response = AddReferenceResponse(
            success=False, message="Internal server error occurred", reference_id=id
        )
        return format_response(response, status_code=500)

    finally:
        # Clean up temporary file
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except OSError as e:
                logger.warning(
                    f"Failed to clean up temporary file {temp_file_path}: {e}"
                )
@routes.http.get("/v1/references/list")
async def list_references():
    """
    Get a list of all available reference voice IDs.
    """
    try:
        # Get the model manager to access the reference loader
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # Get the list of reference IDs
        reference_ids = engine.list_reference_ids()

        response = ListReferencesResponse(
            success=True,
            reference_ids=reference_ids,
            message=f"Found {len(reference_ids)} reference voices",
        )
        return format_response(response)

    except Exception as e:
        # logger.exception attaches the traceback. loguru ignores `exc_info`
        # and would str.format() the already-interpolated message.
        logger.exception(f"Unexpected error listing references: {e}")
        response = ListReferencesResponse(
            success=False, reference_ids=[], message="Internal server error occurred"
        )
        return format_response(response, status_code=500)
@routes.http.delete("/v1/references/delete")
async def delete_reference(reference_id: str = Body(...)):
    """
    Delete a reference voice by ID.
    """
    try:
        # Validate input parameters
        if not reference_id or not reference_id.strip():
            raise ValueError("Reference ID cannot be empty")

        # Get the model manager to access the reference loader
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # Delete the reference using the engine's reference loader
        engine.delete_reference(reference_id)

        response = DeleteReferenceResponse(
            success=True,
            message=f"Reference voice '{reference_id}' deleted successfully",
            reference_id=reference_id,
        )
        return format_response(response)

    # FileNotFoundError is an OSError subclass, so it must be caught first.
    except FileNotFoundError as e:
        logger.warning(f"Reference ID '{reference_id}' not found: {e}")
        response = DeleteReferenceResponse(
            success=False,
            message=f"Reference ID '{reference_id}' not found",
            reference_id=reference_id,
        )
        return format_response(response, status_code=404)  # Not Found

    except ValueError as e:
        logger.warning(f"Invalid input for reference '{reference_id}': {e}")
        response = DeleteReferenceResponse(
            success=False, message=str(e), reference_id=reference_id
        )
        return format_response(response, status_code=400)

    except OSError as e:
        logger.error(f"File system error deleting reference '{reference_id}': {e}")
        response = DeleteReferenceResponse(
            success=False,
            message="File system error occurred",
            reference_id=reference_id,
        )
        return format_response(response, status_code=500)

    except Exception as e:
        # logger.exception attaches the traceback. loguru ignores `exc_info`
        # and would str.format() the already-interpolated message.
        logger.exception(f"Unexpected error deleting reference '{reference_id}': {e}")
        response = DeleteReferenceResponse(
            success=False,
            message="Internal server error occurred",
            reference_id=reference_id,
        )
        return format_response(response, status_code=500)
@routes.http.post("/v1/references/update")
async def update_reference(
    old_reference_id: str = Body(...), new_reference_id: str = Body(...)
):
    """
    Rename a reference voice directory from old_reference_id to new_reference_id.
    """
    try:
        # Validate input parameters
        if not old_reference_id or not old_reference_id.strip():
            raise ValueError("Old reference ID cannot be empty")
        if not new_reference_id or not new_reference_id.strip():
            raise ValueError("New reference ID cannot be empty")
        if old_reference_id == new_reference_id:
            raise ValueError("New reference ID must be different from old reference ID")

        # Validate ID format per ReferenceLoader rules. fullmatch is used because
        # re.match with a trailing "$" still accepts an ID ending in "\n".
        id_pattern = r"[a-zA-Z0-9\-_ ]+"
        if not re.fullmatch(id_pattern, new_reference_id) or len(new_reference_id) > 255:
            raise ValueError(
                "New reference ID contains invalid characters or is too long"
            )

        # old_reference_id comes from the client. It is not pattern-checked, so
        # existing directories with other names can still be renamed. It must,
        # however, be a single path component. Otherwise "../x", "a/../../x" or
        # an absolute path (which replaces the base when joined) would let a
        # caller move arbitrary directories.
        if Path(old_reference_id).name != old_reference_id or old_reference_id in {
            ".",
            "..",
        }:
            raise ValueError("Old reference ID is not a valid reference name")

        # Access engine to update caches after renaming
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        refs_base = Path("references")
        old_dir = refs_base / old_reference_id
        new_dir = refs_base / new_reference_id

        # Existence checks (is_dir() is False for missing paths)
        if not old_dir.is_dir():
            raise FileNotFoundError(f"Reference ID '{old_reference_id}' not found")
        if new_dir.exists():
            # Conflict: destination already exists
            response = UpdateReferenceResponse(
                success=False,
                message=f"Reference ID '{new_reference_id}' already exists",
                old_reference_id=old_reference_id,
                new_reference_id=new_reference_id,
            )
            return format_response(response, status_code=409)

        # Perform rename
        old_dir.rename(new_dir)

        # Move the in-memory cache entry to the new key, if present
        if old_reference_id in engine.ref_by_id:
            engine.ref_by_id[new_reference_id] = engine.ref_by_id.pop(old_reference_id)

        response = UpdateReferenceResponse(
            success=True,
            message=(
                f"Reference voice renamed from '{old_reference_id}' to '{new_reference_id}' successfully"
            ),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response)

    # FileNotFoundError is an OSError subclass, so it must be caught first.
    # The parameters are always bound, so every handler can reference them.
    except FileNotFoundError as e:
        logger.warning(str(e))
        response = UpdateReferenceResponse(
            success=False,
            message=str(e),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=404)
    except ValueError as e:
        logger.warning(f"Invalid input for update reference: {e}")
        response = UpdateReferenceResponse(
            success=False,
            message=str(e),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=400)
    except OSError as e:
        logger.error(f"File system error renaming reference: {e}")
        response = UpdateReferenceResponse(
            success=False,
            message="File system error occurred",
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=500)
    except Exception as e:
        # logger.exception attaches the traceback. loguru ignores `exc_info`
        # and would str.format() the already-interpolated message.
        logger.exception(f"Unexpected error updating reference: {e}")
        response = UpdateReferenceResponse(
            success=False,
            message="Internal server error occurred",
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=500)
================================================
FILE: tools/vqgan/create_train_split.py
================================================
import math
from pathlib import Path
from random import Random
import click
from loguru import logger
from pydub import AudioSegment
from tqdm import tqdm
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
@click.command()
@click.argument("root", type=click.Path(exists=True, path_type=Path))
@click.option("--val-ratio", type=float, default=None)
@click.option("--val-count", type=int, default=None)
@click.option("--filelist", default=None, type=Path)
@click.option("--min-duration", default=None, type=float)
@click.option("--max-duration", default=None, type=float)
def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
    """Write a shuffled train/validation split of the audio files under ROOT.

    Files come from --filelist (first field of each entry) or from a recursive
    scan of ROOT for audio extensions. They can optionally be filtered by
    duration in seconds. Paths relative to ROOT are written to
    ROOT/vq_train_filelist.txt and ROOT/vq_val_filelist.txt. The shuffle uses
    a fixed seed (42), so the split is reproducible.
    """
    # Validate the arguments that don't depend on the file list before doing
    # any (possibly slow) duration scanning.
    if val_count is not None and val_ratio is not None:
        logger.error("Cannot specify both val_count and val_ratio")
        return
    if val_ratio is not None and not 0 < val_ratio < 1:
        logger.error("val_ratio must be between 0 and 1 (exclusive)")
        return

    if filelist:
        # Normalise to Path: relative_to() below needs Path objects. This is
        # a no-op if load_filelist already returns them.
        files = [Path(i[0]) for i in load_filelist(filelist)]
    else:
        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)

    if min_duration is None and max_duration is None:
        filtered_files = [str(file.relative_to(root)) for file in files]
    else:
        filtered_files = []
        for file in tqdm(files):
            try:
                audio = AudioSegment.from_file(str(file))
                duration = len(audio) / 1000.0  # pydub lengths are in ms

                if min_duration is not None and duration < min_duration:
                    logger.info(
                        f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
                    )
                    continue

                if max_duration is not None and duration > max_duration:
                    logger.info(
                        f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
                    )
                    continue

                filtered_files.append(str(file.relative_to(root)))
            except Exception as e:
                # Best effort: an unreadable file is skipped, not fatal.
                logger.warning(f"Error processing {file}: {e}")

    logger.info(
        f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
    )

    Random(42).shuffle(filtered_files)

    if val_count is None and val_ratio is None:
        logger.info("Validation ratio and count not specified, using min(20%, 100)")
        val_size = min(100, math.ceil(len(filtered_files) * 0.2))
    elif val_count is not None:
        if val_count < 1 or val_count > len(filtered_files):
            logger.error("val_count must be between 1 and number of files")
            return
        val_size = val_count
    else:
        val_size = math.ceil(len(filtered_files) * val_ratio)

    logger.info(f"Using {val_size} files for validation")

    with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(filtered_files[val_size:]))

    with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(filtered_files[:val_size]))

    logger.info("Done")


if __name__ == "__main__":
    main()
================================================
FILE: tools/vqgan/extract_vq.py
================================================
import os
import subprocess as sp
import sys
import time
from datetime import timedelta
from functools import lru_cache
from pathlib import Path
from random import Random
import click
import numpy as np
import torch
import torchaudio
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
# Register an "eval" resolver so configs can compute values via ${eval:...}.
# NOTE(review): this evaluates arbitrary Python from config files — only use
# trusted configs.
OmegaConf.register_new_resolver("eval", eval)

# This script encodes audio files into VQ codec tokens, saving one .npy file
# next to each audio file. It is mainly used to generate training data.

# Prefer torchaudio's ffmpeg backend when it is available; otherwise fall back
# to soundfile.
backends = torchaudio.list_audio_backends()
if "ffmpeg" in backends:
    backend = "ffmpeg"
else:
    backend = "soundfile"

# Worker rank and world size. They come from SLURM, or from main()'s own
# launcher, which sets SLURM_PROCID / SLURM_NTASKS for the processes it spawns.
RANK = int(os.environ.get("SLURM_PROCID", 0))
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))

# Every log line is tagged with this worker's rank.
logger_format = (
    "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
    "{level: <8} | "
    "{name} :{function} :{line} | "
    "{extra[rank]} - {message} "
)
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
logger.remove()
logger.add(sys.stderr, format=logger_format)
@lru_cache(maxsize=1)
def get_model(
    config_name: str = "modded_dac_vq",
    checkpoint_path: str = "checkpoints/openaudio-s1-mini/codec.pth",
    device: str | torch.device = "cuda",
):
    """Instantiate the codec model from a Hydra config and load its weights.

    Cached with lru_cache(maxsize=1), so a repeat call with the same
    arguments returns the already-loaded model.

    Args:
        config_name: Hydra config name under fish_speech/configs.
        checkpoint_path: Checkpoint file. A top-level "state_dict" entry is
            unwrapped. If any key contains "generator", only keys containing
            "generator." are kept, with that substring removed.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model in eval mode on ``device``.
    """
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        cfg = compose(config_name=config_name)

    model = instantiate(cfg)
    state_dict = torch.load(
        checkpoint_path,
        map_location=device,
    )
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    if any("generator" in k for k in state_dict):
        state_dict = {
            k.replace("generator.", ""): v
            for k, v in state_dict.items()
            if "generator." in k
        }

    # strict=False tolerates key mismatches. Report them rather than ignoring
    # them silently, since a wrong checkpoint would otherwise leave weights
    # uninitialised with no visible sign.
    result = model.load_state_dict(state_dict, strict=False)
    missing = getattr(result, "missing_keys", None) or []
    unexpected = getattr(result, "unexpected_keys", None) or []
    if missing or unexpected:
        logger.warning(
            f"Checkpoint key mismatch: {len(missing)} missing, "
            f"{len(unexpected)} unexpected"
        )

    model.eval()
    model.to(device)

    logger.info("Loaded model")
    return model
@torch.inference_mode()
def process_batch(files: list[Path], model) -> float:
    """Encode a batch of audio files and save each file's codes as ``<file>.npy``.

    Each file is downmixed to mono, resampled to model.sample_rate, and
    right-padded to the longest clip in the batch. Files that fail to load are
    logged and skipped.

    Returns:
        Total duration in seconds of the audio that was encoded (0.0 if no
        file in the batch could be read).
    """
    wavs = []
    audio_lengths = []
    loaded_files = []
    max_length = 0
    total_time = 0.0

    for file in files:
        try:
            wav, sr = torchaudio.load(
                str(file), backend=backend
            )  # Need to install libsox-dev
        except Exception as e:
            logger.error(f"Error reading {file}: {e}")
            continue

        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)  # downmix to mono

        # Use the model's device rather than hard-coding .cuda(), so the
        # input matches the `lengths` tensor built below.
        wav = torchaudio.functional.resample(
            wav.to(model.device), sr, model.sample_rate
        )[0]
        total_time += len(wav) / model.sample_rate
        max_length = max(max_length, len(wav))

        wavs.append(wav)
        audio_lengths.append(len(wav))
        loaded_files.append(file)

    if not wavs:
        # Nothing in this batch was readable; torch.stack([]) would raise.
        return total_time

    # Right-pad to the longest clip and stack into (batch, 1, time).
    audios = torch.stack(
        [
            torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
            for wav in wavs
        ],
        dim=0,
    )[:, None]
    lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)

    indices, feature_lengths = model.encode(audios, lengths)

    # Save to disk, trimming the padding frames from each item. Each saved
    # array is 2-D (first axis presumably codebooks — confirm).
    outputs = indices.cpu().numpy()
    for file, length, feature in zip(loaded_files, feature_lengths, outputs):
        feature = feature[:, :length]
        with open(file.with_suffix(".npy"), "wb") as f:
            np.save(f, feature)

    return total_time
@click.command()
@click.argument("folder")
@click.option("--num-workers", default=1)
@click.option("--config-name", default="modded_dac_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/s2-pro/codec.pth",
)
@click.option("--batch-size", default=64)
@click.option("--filelist", default=None, type=Path)
def main(
    folder: str,
    num_workers: int,
    config_name: str,
    checkpoint_path: str,
    batch_size: int,
    filelist: Path,
):
    """Extract codec tokens for every audio file under FOLDER (or in --filelist).

    With --num-workers > 1 (and not already running under SLURM), this process
    acts as a launcher. It re-runs itself once per worker, setting
    SLURM_PROCID / SLURM_NTASKS and a round-robin CUDA_VISIBLE_DEVICES, then
    waits for all of them and exits with status 1 if any worker failed.

    Each worker takes every WORLD_SIZE-th file starting at RANK and skips files
    that already have a ``.npy`` next to them.
    """
    if num_workers > 1 and WORLD_SIZE != num_workers:
        # An explicit check rather than assert, which is stripped under -O.
        if WORLD_SIZE != 1:
            raise click.UsageError(
                "You should either use SLURM or this launcher, not both"
            )

        logger.info(f"Spawning {num_workers} workers")

        if torch.cuda.is_available():
            visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
            if visible_devices is None:
                visible_devices = list(range(torch.cuda.device_count()))
            else:
                visible_devices = visible_devices.split(",")
        else:
            # Set to empty string to avoid using GPU
            visible_devices = [""]

        processes = []
        for i in range(num_workers):
            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
            env["SLURM_PROCID"] = str(i)
            env["SLURM_NTASKS"] = str(num_workers)
            # Re-run this same command as a worker with the rank/size above.
            processes.append(
                sp.Popen(
                    [sys.executable] + sys.argv.copy(),
                    env=env,
                )
            )

        failed = []
        for rank, p in enumerate(processes):
            if p.wait() != 0:
                failed.append(rank)

        if failed:
            logger.error(f"Workers {failed} exited with a non-zero status")
            sys.exit(1)

        logger.info("All workers finished")
        return

    # This is a worker
    logger.info("Starting worker")
    if filelist:
        files = [i[0] for i in load_filelist(filelist)]
    else:
        # Sorted so every worker (including ones on other SLURM nodes) sees
        # the same order, which keeps the RANK::WORLD_SIZE sharding consistent.
        files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)

    logger.info(f"Found {len(files)} files")

    # Shard before skipping finished files. Workers run concurrently and write
    # .npy outputs, so filtering first could give each worker a different list
    # and shift the shards, leaving some files unassigned.
    shard = [Path(f) for f in files[RANK::WORLD_SIZE]]
    files = [f for f in shard if not f.with_suffix(".npy").exists()]
    logger.info(f"Processing {len(files)}/{len(shard)} files")

    # Batch processing
    total_time = 0
    begin_time = time.time()
    processed_files = 0
    model = get_model(config_name, checkpoint_path)

    for n_batch, idx in enumerate(range(0, len(files), batch_size)):
        batch = files[idx : idx + batch_size]
        batch_time = process_batch(batch, model)

        total_time += batch_time
        processed_files += len(batch)

        if (n_batch + 1) % 10 == 0:
            elapsed = time.time() - begin_time
            eta = elapsed / processed_files * (len(files) - processed_files)
            # timedelta renders as H:MM:SS, so no unit suffix is appended.
            logger.info(
                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
                + f"ETA: {timedelta(seconds=round(eta))}"
            )

    logger.info(
        f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
    )


if __name__ == "__main__":
    main()
================================================
FILE: tools/webui/__init__.py
================================================
from typing import Callable
import gradio as gr
from fish_speech.i18n import i18n
from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER
def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
    """Build the Gradio TTS interface.

    Args:
        inference_fct: Handler bound to the Generate button. It receives, in
            order: text, reference_id, reference_audio (a file path),
            reference_text, max_new_tokens, chunk_length, top_p,
            repetition_penalty, temperature, seed, use_memory_cache. Its two
            return values fill the audio player and the error HTML panel.
        theme: Written into the ``__theme`` URL query parameter on page load
            when the URL does not already set one. The Blocks theme object
            itself is always ``gr.themes.Base()``.

    Returns:
        The assembled ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # On load, if the URL has no __theme parameter, add one set to `theme`
        # and reload (so the default is light unless a caller passes otherwise).
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % theme,
        )

        # Inference: left column = inputs and settings, right column = outputs.
        # Gradio lays out components in creation order, so keep statement order.
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                # NOTE(review): the label says 0 disables this,
                                # but the slider's minimum is 100, so 0 cannot
                                # be selected from the UI.
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=100,
                                    maximum=400,
                                    value=300,
                                    step=8,
                                )

                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.7,
                                    maximum=0.95,
                                    value=0.8,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.2,
                                    value=1.1,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.7,
                                    maximum=1.0,
                                    value=0.8,
                                    step=0.01,
                                )
                                seed = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                reference_id = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                use_memory_cache = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["on", "off"],
                                    value="on",
                                )

                            with gr.Row():
                                # type="filepath": the handler receives a path
                                # on disk, not the decoded samples.
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                # The placeholder is a sample Chinese transcript.
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )
                with gr.Row():
                    # type="numpy": the handler's audio output is given as
                    # numpy data rather than a file path.
                    audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001f3a7 " + i18n("Generate"),
                            variant="primary",
                        )

        # Submit. The input list order must match inference_fct's positional
        # parameters; concurrency_limit=1 runs at most one generation at a time.
        generate.click(
            inference_fct,
            [
                text,
                reference_id,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                use_memory_cache,
            ],
            [audio, error],
            concurrency_limit=1,
        )

    return app
================================================
FILE: tools/webui/inference.py
================================================
import html
from functools import partial
from typing import Any, Callable
from fish_speech.i18n import i18n
from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
def inference_wrapper(
    text,
    reference_id,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    use_memory_cache,
    engine,
):
    """
    Gradio click handler: turn the UI field values into a ServeTTSRequest,
    run it on ``engine``, and return ``(audio, error)``.

    The first "final" result returns its audio with no error. The first
    "error" result returns no audio and an HTML error message. Other result
    codes are skipped. If the engine yields neither, a "No audio generated"
    message is returned.
    """
    references = (
        get_reference_audio(reference_audio, reference_text) if reference_audio else []
    )

    tts_request = ServeTTSRequest(
        text=text,
        reference_id=reference_id or None,
        references=references,
        max_new_tokens=max_new_tokens,
        chunk_length=chunk_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        seed=int(seed) if seed else None,  # 0 / empty -> randomized
        use_memory_cache=use_memory_cache,
    )

    for outcome in engine.inference(tts_request):
        status = outcome.code
        if status == "final":
            return outcome.audio, None
        if status == "error":
            return None, build_html_error_message(i18n(outcome.error))

    return None, i18n("No audio generated")
def get_reference_audio(reference_audio: str, reference_text: str) -> list:
    """
    Read the clip at ``reference_audio`` as raw bytes. Return it, with its
    transcript, as a one-element list of ServeReferenceAudio.
    """
    with open(reference_audio, "rb") as handle:
        payload = handle.read()
    reference = ServeReferenceAudio(audio=payload, text=reference_text)
    return [reference]
def build_html_error_message(error: Any) -> str:
    """Render an error as HTML-escaped text for the Gradio error panel.

    Exceptions are shown via ``str()``. Other non-empty values, such as a
    translated message string, are shown too; previously they were replaced
    by "Unknown error". ``None`` or an empty string falls back to
    "Unknown error". The text is escaped because it is displayed in a
    ``gr.HTML`` component.
    """
    if isinstance(error, Exception):
        message = str(error)
    elif error is None or str(error) == "":
        message = "Unknown error"
    else:
        message = str(error)

    return f"\n{html.escape(message)}\n"
def get_inference_wrapper(engine) -> Callable:
    """
    Return ``inference_wrapper`` with ``engine`` pre-bound as a keyword
    argument, so the caller only has to supply the UI inputs.
    """
    bound_handler = partial(inference_wrapper, engine=engine)
    return bound_handler
================================================
FILE: tools/webui/variables.py
================================================
from fish_speech.i18n import i18n
# Markdown header rendered at the top of the Gradio app by build_app.
# NOTE(review): each sentence goes through i18n(), which appears to look up
# translations by the English text itself. Editing these strings likely means
# updating the locale files too — confirm in fish_speech/i18n.
# NOTE(review): the model link points at fish-speech-1.5, while the tools in
# this repo default to newer checkpoints (e.g. s2-pro); confirm it is current.
HEADER_MD = f"""# Fish Speech
{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}
{i18n("Related code and weights are released under FISH AUDIO RESEARCH LICENSE.")}
{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
"""
# Placeholder shown in the main text input box.
TEXTBOX_PLACEHOLDER = i18n("Put your text here.")