🚀 **Try out our interactive [UI playground](https://huggingface.co/spaces/Splend1dchan/BreezyVoice-Playground) now!** 🚀
Or visit one of these resources:
- [Playground (CLI Inference)](https://www.kaggle.com/code/a24998667/breezyvoice-playground)
- [Model](https://huggingface.co/MediaTek-Research/BreezyVoice/tree/main)
- [Paper](https://arxiv.org/abs/2501.17790)
Main repo contributors: Chia-Chun Lin, Chan-Jan Hsu
## Features
🔥 BreezyVoice outperforms competing commercial services in terms of naturalness.
🔥 BreezyVoice is highly competitive in code-switching scenarios.
| Code-Switching Term Category | **BreezyVoice** | Z | Y | U | M |
|-------------|--------------|---|---|---|---|
| **General Words** | **8** | 5 | **8** | **8** | 7 |
| **Entities**| **9** | 6 | 4 | 7 | 4 |
| **Abbreviations** | **9** | 8 | 6 | 6 | 7 |
| **Toponyms**| 3 | 3 | **7** | 3 | 4 |
| **Full Sentences**| 7 | 7 | **8** | 5 | 3 |
🔥 BreezyVoice supports automatic 注音 (bopomofo) annotation, as well as manual 注音 correction (see the Inference section).
## Install
**Clone and install**
- Clone the repo
``` sh
git clone https://github.com/mtkresearch/BreezyVoice.git
# If cloning fails due to network errors, re-run the command above until it succeeds
cd BreezyVoice
```
- Install requirements (requires Python 3.10)
``` sh
pip uninstall onnxruntime # use onnxruntime-gpu instead of onnxruntime
pip install -r requirements.txt
```
(The model also runs on CPU; if you do not have a GPU, change `onnxruntime-gpu` to `onnxruntime` in `requirements.txt`.)
Depending on your CUDA version, you may also need to install cuDNN:
``` sh
sudo apt-get -y install cudnn9-cuda-11
```
## Inference
UTF-8 encoding is required:
``` sh
export PYTHONUTF8=1
```
---
**Run `single_inference.py` with the following arguments:**
- `--content_to_synthesize`:
- **Description**: Specifies the content that will be synthesized into speech. Phonetic symbols can optionally be included but should be used sparingly, as shown in the examples below:
- Simple text: `"今天天氣真好"`
- Text with phonetic symbols: `"今天天氣真好[:ㄏㄠ3]"`
- `--speaker_prompt_audio_path`:
- **Description**: Specifies the path to the prompt speech audio file for setting the style of the speaker. Use your custom audio file or our example file:
- Example audio: `./data/tc_speaker.wav`
- `--speaker_prompt_text_transcription` (optional):
- **Description**: Specifies the transcription of the speaker prompt audio. Providing this input is highly recommended for better accuracy. If not provided, the system will automatically transcribe the audio using Whisper.
- Example text for the audio file: `"在密碼學中,加密是將明文資訊改變為難以讀取的密文內容,使之不可讀的方法。只有擁有解密方法的對象,經由解密過程才能將密文還原為正常可讀的內容。"`
- `--output_path` (optional):
- **Description**: Specifies the name and path for the output `.wav` file. If not provided, the default path is used.
- **Default Value**: `results/output.wav`
- Example: `[your_file_name].wav`
- `--model_path` (optional):
- **Description**: Specifies the pre-trained model used for speech synthesis.
- **Default Value**: `MediaTek-Research/BreezyVoice`
**Example Usage:**
``` bash
bash run_single_inference.sh
```
``` sh
# python single_inference.py --content_to_synthesize [content to synthesize] --speaker_prompt_text_transcription [transcription of the prompt audio] --speaker_prompt_audio_path [path to the prompt audio]
python single_inference.py --content_to_synthesize "今天天氣真好" --speaker_prompt_text_transcription "在密碼學中,加密是將明文資訊改變為難以讀取的密文內容,使之不可讀的方法。只有擁有解密方法的對象,經由解密過程才能將密文還原為正常可讀的內容。" --speaker_prompt_audio_path "./data/example.wav"
```
``` sh
# python single_inference.py --content_to_synthesize [content to synthesize, with optional phonetic symbols] --speaker_prompt_audio_path [path to the prompt audio]
python single_inference.py --content_to_synthesize "今天天氣真好[:ㄏㄠ3]" --speaker_prompt_audio_path "./data/example.wav"
```
---
**Run `batch_inference.py` with the following arguments:**
- `--csv_file`:
- **Description**: Path to the CSV file that contains the input data for batch processing.
- **Example**: `./data/batch_files.csv`
- `--speaker_prompt_audio_folder`:
- **Description**: Path to the folder containing the speaker prompt audio files. The files in this folder are used to set the style of the speaker for each synthesis task.
- **Example**: `./data`
- `--output_audio_folder`:
- **Description**: Path to the folder where the output audio files will be saved. Each processed row in the CSV will result in a synthesized audio file stored in this folder.
- **Example**: `./results`
**CSV File Structure:**
The CSV file should contain the following columns (a minimal example follows the list):
- **`speaker_prompt_audio_filename`**:
- **Description**: The filename (without extension) of the speaker prompt audio file that will be used to guide the style of the generated speech.
- **Example**: `example`
- **`speaker_prompt_text_transcription`**:
  - **Description**: The transcription of the speaker prompt audio. This field is optional but highly recommended for better synthesis accuracy. If not provided, the system will transcribe the audio automatically using Whisper.
- **Example**: `"在密碼學中,加密是將明文資訊改變為難以讀取的密文內容,使之不可讀的方法。"`
- **`content_to_synthesize`**:
- **Description**: The content that will be synthesized into speech. You can include phonetic symbols if needed, though they should be used sparingly.
- **Example**: `"今天天氣真好"`
- **`output_audio_filename`**:
- **Description**: The filename (without extension) for the generated output audio. The audio will be saved as a `.wav` file in the output folder.
- **Example**: `output`
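A minimal `batch_files.csv` consistent with these columns might look like the sketch below (the filenames are illustrative; point them at your own prompt audio):

```csv
speaker_prompt_audio_filename,speaker_prompt_text_transcription,content_to_synthesize,output_audio_filename
example,"在密碼學中,加密是將明文資訊改變為難以讀取的密文內容,使之不可讀的方法。","今天天氣真好",output
```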
**Example Usage:**
``` bash
bash run_batch_inference.sh
```
```bash
python batch_inference.py \
--csv_file ./data/batch_files.csv \
--speaker_prompt_audio_folder ./data \
--output_audio_folder ./results
```
### Docker and OpenAI Compatible API
``` bash
$ docker compose up -d --build
# after the container is up
$ pip install openai
$ python openai_api_inference.py
```
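Once the container is up, the service exposes an OpenAI-compatible `/v1/audio/speech` endpoint (see `api.py`). As a minimal sketch, assuming the defaults from `compose.yaml` (port 8080, no API-key check), you can also call it directly with the official `openai` client; the `voice` argument is required by the client library but ignored by this server, since the speaker is fixed by `speaker_prompt_audio_path` in the server settings:

```python
# Minimal sketch: call the local BreezyVoice server through the OpenAI client.
# The base_url and port follow compose.yaml; the api_key value is a placeholder.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="unused")

response = client.audio.speech.create(
    model="MediaTek-Research/BreezyVoice",
    voice="alloy",  # required by the client library; ignored by this server
    input="今天天氣真好",
    response_format="wav",
)
response.write_to_file("output.wav")  # the server streams back a WAV file
```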
---
If you like our work, please cite:
```
@article{hsu2025breezyvoice,
title={BreezyVoice: Adapting TTS for Taiwanese Mandarin with Enhanced Polyphone Disambiguation--Challenges and Insights},
author={Hsu, Chan-Jan and Lin, Yi-Cheng and Lin, Chia-Chun and Chen, Wei-Chih and Chung, Ho Lam and Li, Chen-An and Chen, Yi-Chang and Yu, Chien-Yu and Lee, Ming-Ji and Chen, Chien-Cheng and others},
journal={arXiv preprint arXiv:2501.17790},
year={2025}
}
@article{hsu2025breeze,
title={The Breeze 2 Herd of Models: Traditional Chinese LLMs Based on Llama with Vision-Aware and Function-Calling Capabilities},
author={Hsu, Chan-Jan and Liu, Chia-Sheng and Chen, Meng-Hsi and Chen, Muxi and Hsu, Po-Chun and Chen, Yi-Chang and Shiu, Da-Shan},
journal={arXiv preprint arXiv:2501.13921},
year={2025}
}
@article{du2024cosyvoice,
title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
journal={arXiv preprint arXiv:2407.05407},
year={2024}
}
```
================================================
FILE: api.py
================================================
# OpenAI API Spec. Reference: https://platform.openai.com/docs/api-reference/audio/createSpeech
from contextlib import asynccontextmanager
from io import BytesIO
import torchaudio
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from g2pw import G2PWConverter
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from cosyvoice.utils.file_utils import load_wav
from single_inference import CustomCosyVoice, get_bopomofo_rare
class Settings(BaseSettings):
api_key: str = Field(
default="", description="Specifies the API key used to authenticate the user."
)
model_path: str = Field(
default="MediaTek-Research/BreezyVoice",
description="Specifies the model used for speech synthesis.",
)
speaker_prompt_audio_path: str = Field(
default="./data/example.wav",
description="Specifies the path to the prompt speech audio file of the speaker.",
)
speaker_prompt_text_transcription: str = Field(
default="在密碼學中,加密是將明文資訊改變為難以讀取的密文內容,使之不可讀的方法。只有擁有解密方法的對象,經由解密過程,才能將密文還原為正常可讀的內容。",
description="Specifies the transcription of the speaker prompt audio.",
)
class SpeechRequest(BaseModel):
model: str = ""
input: str = Field(
description="The content that will be synthesized into speech. You can include phonetic symbols if needed, though they should be used sparingly.",
examples=["今天天氣真好"],
)
response_format: str = ""
speed: float = 1.0
@asynccontextmanager
async def lifespan(app: FastAPI):
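    # Load the model, bopomofo converter, and speaker prompt audio once at
    # startup, and release them when the app shuts down.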
app.state.settings = Settings()
app.state.cosyvoice = CustomCosyVoice(app.state.settings.model_path)
app.state.bopomofo_converter = G2PWConverter()
app.state.prompt_speech_16k = load_wav(
app.state.settings.speaker_prompt_audio_path, 16000
)
yield
del app.state.cosyvoice
del app.state.bopomofo_converter
app = FastAPI(lifespan=lifespan, root_path="/v1")
@app.get("/models")
async def get_models(request: Request):
return {
"object": "list",
"data": [
{
"id": request.app.state.settings.model_path,
"object": "model",
"created": 0,
"owned_by": "local",
}
],
}
@app.post("/audio/speech")
async def speech_endpoint(request: Request, payload: SpeechRequest):
    # Normalize both the stored speaker-prompt transcription and the request
    # text, then annotate rare characters with bopomofo before synthesis.
speaker_prompt_text_transcription = (
request.app.state.cosyvoice.frontend.text_normalize_new(
request.app.state.settings.speaker_prompt_text_transcription, split=False
)
)
content_to_synthesize = request.app.state.cosyvoice.frontend.text_normalize_new(
payload.input, split=False
)
speaker_prompt_text_transcription_bopomo = get_bopomofo_rare(
speaker_prompt_text_transcription, request.app.state.bopomofo_converter
)
content_to_synthesize_bopomo = get_bopomofo_rare(
content_to_synthesize, request.app.state.bopomofo_converter
)
output = request.app.state.cosyvoice.inference_zero_shot_no_normalize(
content_to_synthesize_bopomo,
speaker_prompt_text_transcription_bopomo,
request.app.state.prompt_speech_16k,
)
audio_buffer = BytesIO()
torchaudio.save(audio_buffer, output["tts_speech"], 22050, format="wav")
audio_buffer.seek(0)
return StreamingResponse(
audio_buffer,
media_type="audio/wav",
headers={"Content-Disposition": "attachment; filename=output.wav"},
)
if __name__ == "__main__":
import uvicorn
uvicorn.run("api:app", host="0.0.0.0", port=8080)
================================================
FILE: batch_inference.py
================================================
import os
import time
import argparse
import pandas as pd
from datasets import Dataset
from single_inference import single_inference, CustomCosyVoice
from g2pw import G2PWConverter
def process_batch(csv_file, speaker_prompt_audio_folder, output_audio_folder, model):
# Load CSV with pandas
data = pd.read_csv(csv_file)
# Transform pandas DataFrame to HuggingFace Dataset
dataset = Dataset.from_pandas(data)
dataset = dataset.shuffle(seed = int(time.time()*1000))
cosyvoice, bopomofo_converter = model
def gen_audio(row):
speaker_prompt_audio_path = os.path.join(speaker_prompt_audio_folder, f"{row['speaker_prompt_audio_filename']}.wav")
speaker_prompt_text_transcription = row['speaker_prompt_text_transcription']
content_to_synthesize = row['content_to_synthesize']
output_audio_path = os.path.join(output_audio_folder, f"{row['output_audio_filename']}.wav")
if not os.path.exists(speaker_prompt_audio_path):
print(f"File {speaker_prompt_audio_path} does not exist")
return row #{"status": "failed", "reason": "file not found"}
        if not os.path.exists(output_audio_path):
            single_inference(speaker_prompt_audio_path, content_to_synthesize, output_audio_path, cosyvoice, bopomofo_converter, speaker_prompt_text_transcription)
        # Always return the row so that datasets.map receives a dict back.
        return row
dataset = dataset.map(gen_audio, num_proc = 1)
def main():
parser = argparse.ArgumentParser(description="Batch process audio generation.")
parser.add_argument("--csv_file", required=True, help="Path to the CSV file containing input data.")
parser.add_argument("--speaker_prompt_audio_folder", required=True, help="Path to the folder containing speaker prompt audio files.")
parser.add_argument("--output_audio_folder", required=True, help="Path to the folder where results will be stored.")
parser.add_argument("--model_path", type=str, required=False, default = "MediaTek-Research/BreezyVoice-300M",help="Specifies the model used for speech synthesis.")
args = parser.parse_args()
cosyvoice = CustomCosyVoice(args.model_path)
bopomofo_converter = G2PWConverter()
os.makedirs(args.output_audio_folder, exist_ok=True)
process_batch(
csv_file=args.csv_file,
speaker_prompt_audio_folder=args.speaker_prompt_audio_folder,
output_audio_folder=args.output_audio_folder,
model = (cosyvoice, bopomofo_converter),
)
if __name__ == "__main__":
main()
================================================
FILE: compose.yaml
================================================
services:
app:
image: breezyvoice:latest
build: .
ports:
- "8080:8080"
volumes:
- hf-cache:/root/.cache/huggingface/
command: api.py
init: true
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
volumes:
hf-cache:
================================================
FILE: cosyvoice/__init__.py
================================================
================================================
FILE: cosyvoice/bin/inference.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.dataset.dataset import Dataset
def get_args():
parser = argparse.ArgumentParser(description='inference with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
    parser.add_argument('--prompt_utt2data', required=True, help='prompt utt2data file')
parser.add_argument('--tts_text', required=True, help='tts input file')
parser.add_argument('--llm_model', required=True, help='llm model file')
parser.add_argument('--flow_model', required=True, help='flow model file')
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--mode',
default='sft',
choices=['sft', 'zero_shot'],
help='inference mode')
    parser.add_argument('--result_dir', required=True, help='directory to save synthesized results')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Init cosyvoice models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
model.load(args.llm_model, args.flow_model, args.hifigan_model)
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
with torch.no_grad():
for batch_idx, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
            assert len(utts) == 1, "inference mode only supports batch size 1"
text = batch["text"]
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
tts_text = batch["tts_text"]
tts_index = batch["tts_index"]
tts_text_token = batch["tts_text_token"].to(device)
tts_text_token_len = batch["tts_text_token_len"].to(device)
speech_token = batch["speech_token"].to(device)
speech_token_len = batch["speech_token_len"].to(device)
speech_feat = batch["speech_feat"].to(device)
speech_feat_len = batch["speech_feat_len"].to(device)
utt_embedding = batch["utt_embedding"].to(device)
spk_embedding = batch["spk_embedding"].to(device)
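            # 'sft' mode conditions only on a stored speaker embedding;
            # 'zero_shot' mode additionally conditions on the prompt text,
            # prompt speech tokens and prompt speech features.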
if args.mode == 'sft':
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
else:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': text_token, 'prompt_text_len': text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
model_output = model.inference(**model_input)
tts_key = '{}_{}'.format(utts[0], tts_index[0])
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
f.write('{} {}\n'.format(tts_key, tts_fn))
f.flush()
f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
main()
================================================
FILE: cosyvoice/bin/train.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import torch
import torch.distributed as dist
import deepspeed
from hyperpyyaml import load_hyperpyyaml
from torch.distributed.elastic.multiprocessing.errors import record
from cosyvoice.utils.executor import Executor
from cosyvoice.utils.train_utils import (
init_distributed,
init_dataset_and_dataloader,
init_optimizer_and_scheduler,
init_summarywriter, save_model,
wrap_cuda_model, check_modify_and_save_config)
def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--train_engine',
default='torch_ddp',
choices=['torch_ddp', 'deepspeed'],
                        help='Engine for parallel training')
parser.add_argument('--model', required=True, help='model which will be trained')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo'],
help='distributed backend')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
                        help='Use pinned memory buffers for reading')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
choices=['model_only', 'model+optimizer'],
help='save model/optimizer states')
parser.add_argument('--timeout',
default=30,
type=int,
help='timeout (in seconds) of cosyvoice_join.')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args
@record
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
configs['train_conf'].update(vars(args))
# Init env for ddp
init_distributed(args)
# Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
init_dataset_and_dataloader(args, configs)
    # Do some sanity checks and save config to args.model_dir
configs = check_modify_and_save_config(args, configs)
# Tensorboard summary
writer = init_summarywriter(args)
# load checkpoint
model = configs[args.model]
if args.checkpoint is not None:
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
# Dispatch model from cpu to gpu
model = wrap_cuda_model(args, model)
# Get optimizer & scheduler
model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
# Save init checkpoints
info_dict = deepcopy(configs['train_conf'])
save_model(model, 'init', info_dict)
# Get executor
executor = Executor()
# Start training loop
for epoch in range(info_dict['max_epoch']):
executor.epoch = epoch
train_dataset.set_epoch(epoch)
dist.barrier()
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
dist.destroy_process_group(group_join)
if __name__ == '__main__':
main()
================================================
FILE: cosyvoice/cli/__init__.py
================================================
================================================
FILE: cosyvoice/cli/cosyvoice.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from hyperpyyaml import load_hyperpyyaml
from huggingface_hub import snapshot_download
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel
class CosyVoice:
def __init__(self, model_dir):
instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
configs = load_hyperpyyaml(f)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          model_dir,
                                          '{}/campplus.onnx'.format(model_dir),
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
'{}/spk2info.pt'.format(model_dir),
instruct,
configs['allowed_special'])
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
del configs
def list_avaliable_spks(self):
spks = list(self.frontend.spk2info.keys())
return spks
def inference_sft(self, tts_text, spk_id):
tts_speeches = []
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_sft(i, spk_id)
model_output = self.model.inference(**model_input)
tts_speeches.append(model_output['tts_speech'])
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
tts_speeches = []
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
model_output = self.model.inference(**model_input)
tts_speeches.append(model_output['tts_speech'])
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
def inference_cross_lingual(self, tts_text, prompt_speech_16k):
if self.frontend.instruct is True:
            raise ValueError('{} does not support cross_lingual inference'.format(self.model_dir))
tts_speeches = []
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
model_output = self.model.inference(**model_input)
tts_speeches.append(model_output['tts_speech'])
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
def inference_instruct(self, tts_text, spk_id, instruct_text):
if self.frontend.instruct is False:
            raise ValueError('{} does not support instruct inference'.format(self.model_dir))
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
tts_speeches = []
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
model_output = self.model.inference(**model_input)
tts_speeches.append(model_output['tts_speech'])
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
================================================
FILE: cosyvoice/cli/frontend.py
================================================
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
import ttsfrd
use_ttsfrd = True
except ImportError:
print("failed to import ttsfrd, use WeTextProcessing instead")
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
use_ttsfrd = False
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
class CosyVoiceFrontEnd:
def __init__(self,
get_tokenizer: Callable,
feat_extractor: Callable,
model_dir: str,
campplus_model: str,
speech_tokenizer_model: str,
spk2info: str = '',
instruct: bool = False,
allowed_special: str = 'all'):
self.tokenizer = get_tokenizer()
self.feat_extractor = feat_extractor
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"])
if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device)
self.instruct = instruct
self.allowed_special = allowed_special
self.inflect_parser = inflect.engine()
self.use_ttsfrd = use_ttsfrd
if self.use_ttsfrd:
self.frd = ttsfrd.TtsFrontendEngine()
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
assert self.frd.initialize('{}/CosyVoice-ttsfrd/resource'.format(model_dir)) is True, 'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyin')
self.frd.enable_pinyin_mix(True)
self.frd.set_breakmodel_index(1)
else:
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
self.en_tn_model = EnNormalizer()
def _extract_text_token(self, text):
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
return text_token, text_token_len
def _extract_speech_token(self, speech):
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
return speech_token, speech_token_len
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device)
return embedding
def _extract_speech_feat(self, speech):
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
return speech_feat, speech_feat_len
def text_normalize(self, text, split=True):
text = text.strip()
if contains_chinese(text):
if self.use_ttsfrd:
text = self.frd.get_frd_extra_info(text, 'input')
else:
text = self.zh_tn_model.normalize(text)
text = text.replace("\n", "")
text = replace_blank(text)
text = replace_corner_mark(text)
text = text.replace(".", "、")
text = text.replace(" - ", ",")
text = remove_bracket(text)
text = re.sub(r'[,,]+$', '。', text)
texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20,
comma_split=False)]
else:
if self.use_ttsfrd:
text = self.frd.get_frd_extra_info(text, 'input')
else:
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20,
comma_split=False)]
if split is False:
return text
return texts
def frontend_sft(self, tts_text, spk_id):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
embedding = self.spk2info[spk_id]['embedding']
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
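        # The mel feature extractor expects 22.05 kHz audio, while the speech
        # tokenizer and speaker-embedding model consume the 16 kHz prompt.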
prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
embedding = self._extract_spk_embedding(prompt_speech_16k)
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input
def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
# in cross lingual mode, we remove prompt in llm
del model_input['prompt_text']
del model_input['prompt_text_len']
del model_input['llm_prompt_speech_token']
del model_input['llm_prompt_speech_token_len']
return model_input
def frontend_instruct(self, tts_text, spk_id, instruct_text):
model_input = self.frontend_sft(tts_text, spk_id)
# in instruct mode, we remove spk_embedding in llm due to information leakage
del model_input['llm_embedding']
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '